Commit e0531221 authored by nanahira

Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui

parents 254c0e03 baf6946e
extensions
extensions-disabled
repositories
venv
\ No newline at end of file
/* global module */
module.exports = {
env: {
browser: true,
es2021: true,
},
extends: "eslint:recommended",
parserOptions: {
ecmaVersion: "latest",
},
rules: {
"arrow-spacing": "error",
"block-spacing": "error",
"brace-style": "error",
"comma-dangle": ["error", "only-multiline"],
"comma-spacing": "error",
"comma-style": ["error", "last"],
"curly": ["error", "multi-line", "consistent"],
"eol-last": "error",
"func-call-spacing": "error",
"function-call-argument-newline": ["error", "consistent"],
"function-paren-newline": ["error", "consistent"],
"indent": ["error", 4],
"key-spacing": "error",
"keyword-spacing": "error",
"linebreak-style": ["error", "unix"],
"no-extra-semi": "error",
"no-mixed-spaces-and-tabs": "error",
"no-multi-spaces": "error",
"no-redeclare": ["error", {builtinGlobals: false}],
"no-trailing-spaces": "error",
"no-unused-vars": "off",
"no-whitespace-before-property": "error",
"object-curly-newline": ["error", {consistent: true, multiline: true}],
"object-curly-spacing": ["error", "never"],
"operator-linebreak": ["error", "after"],
"quote-props": ["error", "consistent-as-needed"],
"semi": ["error", "always"],
"semi-spacing": "error",
"semi-style": ["error", "last"],
"space-before-blocks": "error",
"space-before-function-paren": ["error", "never"],
"space-in-parens": ["error", "never"],
"space-infix-ops": "error",
"space-unary-ops": "error",
"switch-colon-spacing": "error",
"template-curly-spacing": ["error", "never"],
"unicode-bom": "error",
},
globals: {
//script.js
gradioApp: "readonly",
onUiLoaded: "readonly",
onUiUpdate: "readonly",
onOptionsChanged: "readonly",
uiCurrentTab: "writable",
uiElementIsVisible: "readonly",
uiElementInSight: "readonly",
executeCallbacks: "readonly",
//ui.js
opts: "writable",
all_gallery_buttons: "readonly",
selected_gallery_button: "readonly",
selected_gallery_index: "readonly",
switch_to_txt2img: "readonly",
switch_to_img2img_tab: "readonly",
switch_to_img2img: "readonly",
switch_to_sketch: "readonly",
switch_to_inpaint: "readonly",
switch_to_inpaint_sketch: "readonly",
switch_to_extras: "readonly",
get_tab_index: "readonly",
create_submit_args: "readonly",
restart_reload: "readonly",
updateInput: "readonly",
//extraNetworks.js
requestGet: "readonly",
popup: "readonly",
// from python
localization: "readonly",
// progressbar.js
randomId: "readonly",
requestProgress: "readonly",
// imageviewer.js
modalPrevImage: "readonly",
modalNextImage: "readonly",
}
};
# Apply ESLint
9c54b78d9dde5601e916f308d9a9d6953ec39430
\ No newline at end of file
...@@ -47,6 +47,15 @@ body:
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
validations:
required: true
- type: dropdown
id: py-version
attributes:
label: What Python version are you running on ?
multiple: false
options:
- Python 3.10.x
- Python 3.11.x (above, not supported yet)
- Python 3.9.x (below, not recommended)
- type: dropdown
id: platforms
attributes:
...@@ -59,6 +68,18 @@ body:
- iOS
- Android
- Other/Cloud
- type: dropdown
id: device
attributes:
label: What device are you running WebUI on?
multiple: true
options:
- Nvidia GPUs (RTX 20 series and above)
- Nvidia GPUs (GTX 16 series and below)
- AMD GPUs (RX 6000 series and above)
- AMD GPUs (RX 5000 series and below)
- CPU
- Other GPUs
- type: dropdown
id: browsers
attributes:
...
## Description
* a simple description of what you're trying to accomplish
* a summary of changes in code
* which issues it fixes, if any
## Screenshots/videos:
Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on.
## Checklist:
- [ ] I have read [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
- [ ] I have performed a self-review of my own code
- [ ] My code follows the [style guidelines](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing#code-style)
- [ ] My code passes [tests](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Tests)
More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
**Environment this was tested in**
List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
- OS: [e.g. Windows, Linux]
- Browser: [e.g. chrome, safari]
- Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
**Screenshots or videos of your changes**
If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
This is **required** for anything that touches the user interface.
\ No newline at end of file
name: Run Linting/Formatting on Pull Requests
on:
- push
- pull_request
jobs:
lint-python:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: 3.11
# NB: there's no cache: pip here since we're not installing anything
# from the requirements.txt file(s) in the repository; it's faster
# not to have GHA download an (at the time of writing) 4 GB cache
# of PyTorch and other dependencies.
- name: Install Ruff
run: pip install ruff==0.0.265
- name: Run Ruff
run: ruff .
lint-js:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Install Node.js
uses: actions/setup-node@v3
with:
node-version: 18
- run: npm i --ci
- run: npm run lint
...@@ -17,13 +17,54 @@ jobs:
cache: pip
cache-dependency-path: |
**/requirements*txt
launch.py
- name: Install test dependencies
run: pip install wait-for-it -r requirements-test.txt
env:
PIP_DISABLE_PIP_VERSION_CHECK: "1"
PIP_PROGRESS_BAR: "off"
- name: Setup environment
run: python launch.py --skip-torch-cuda-test --exit
env:
PIP_DISABLE_PIP_VERSION_CHECK: "1"
PIP_PROGRESS_BAR: "off"
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
PYTHONUNBUFFERED: "1"
- name: Start test server
run: >
python -m coverage run
--data-file=.coverage.server
launch.py
--skip-prepare-environment
--skip-torch-cuda-test
--test-server
--no-half
--disable-opt-split-attention
--use-cpu all
--add-stop-route
2>&1 | tee output.txt &
- name: Run tests
run: |
wait-for-it --service 127.0.0.1:7860 -t 600
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
- name: Kill test server
if: always()
run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10
- name: Show coverage
run: |
python -m coverage combine .coverage*
python -m coverage report -i
python -m coverage html -i
- name: Upload main app output
uses: actions/upload-artifact@v3
if: always()
with:
name: output
path: output.txt
- name: Upload coverage HTML
uses: actions/upload-artifact@v3
if: always()
with:
name: htmlcov
path: htmlcov
...@@ -32,4 +32,8 @@ notification.mp3
/extensions
/test/stdout.txt
/test/stderr.txt
/cache.json*
/config_states/
/node_modules
/package-lock.json
/.coverage*
...@@ -15,7 +15,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- Attention, specify parts of text that the model should pay more attention to
- a man in a `((tuxedo))` - will pay more attention to tuxedo
- a man in a `(tuxedo:1.21)` - alternative syntax
- select text and press `Ctrl+Up` or `Ctrl+Down` (or `Command+Up` or `Command+Down` if you're on macOS) to automatically adjust attention to selected text (code contributed by anonymous user)
- Loopback, run img2img processing multiple times
- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
- Textual Inversion
...@@ -99,8 +99,14 @@ Alternatively, use online services (like Google Colab):
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
### Installation on Windows 10/11 with NVidia-GPUs using release package
1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract its contents.
2. Run `update.bat`.
3. Run `run.bat`.
> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
### Automatic Installation on Windows
1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (newer versions of Python do not support torch), checking "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win).
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
...@@ -115,11 +121,12 @@ sudo dnf install wget git python3
# Arch-based:
sudo pacman -S wget git python3
```
2. Navigate to the directory you would like the webui to be installed in and execute the following command:
```bash
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
```
3. Run `webui.sh`.
4. Check `webui-user.sh` for options.
### Installation on Apple Silicon
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
...@@ -157,5 +164,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
- Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
...@@ -4,8 +4,8 @@ channels:
- defaults
dependencies:
- python=3.10
- pip=23.0
- cudatoolkit=11.8
- pytorch=2.0
- torchvision=0.15
- numpy=1.23
\ No newline at end of file
...@@ -88,7 +88,7 @@ class LDSR:
x_t = None
logs = None
for _ in range(n_runs):
if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
...@@ -110,7 +110,6 @@ class LDSR:
diffusion_steps = int(steps)
eta = 1.0
down_sample_method = 'Lanczos'
gc.collect()
if torch.cuda.is_available:
...@@ -158,7 +157,7 @@ class LDSR:
def get_cond(selected_path):
example = {}
up_f = 4
c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
...@@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
log = {}
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
...@@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except Exception:
pass
log["sample"] = x_sample
...
...@@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks
import sd_hijack_autoencoder  # noqa: F401
import sd_hijack_ddpm_v1  # noqa: F401
class UpscalerLDSR(Upscaler):
...@@ -25,22 +26,28 @@ class UpscalerLDSR(Upscaler):
yaml_path = os.path.join(self.model_path, "project.yaml")
old_model_path = os.path.join(self.model_path, "model.pth")
new_model_path = os.path.join(self.model_path, "model.ckpt")
safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"])
local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None)
local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None)
local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None)
if os.path.exists(yaml_path):
statinfo = os.stat(yaml_path)
if statinfo.st_size >= 10485760:
print("Removing invalid LDSR YAML file.")
os.remove(yaml_path)
if os.path.exists(old_model_path):
print("Renaming model from model.pth to model.ckpt")
os.rename(old_model_path, new_model_path)
if os.path.exists(safetensors_model_path):
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path
else:
model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
try:
return LDSR(model, yaml)
...
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from torch.optim.lr_scheduler import LambdaLR
from ldm.modules.ema import LitEma
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config
import ldm.models.autoencoder
from packaging import version
class VQModel(pl.LightningModule):
def __init__(self,
...@@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=None,
image_key="image",
colorize_nlabels=None,
monitor=None,
...@@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
...@@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
...@@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
...@@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = {}
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
...@@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
...@@ -264,7 +270,7 @@ class VQModel(pl.LightningModule):
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(*args, embed_dim=embed_dim, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
...@@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
dec = self.decoder(quant)
return dec
ldm.models.autoencoder.VQModel = VQModel
ldm.models.autoencoder.VQModelInterface = VQModelInterface
...@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
...@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
...@@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
...@@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = {}
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
...@@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
log["inputs"] = x
# get diffusion row
diffusion_row = []
x_start = x[:n_row]
for t in range(self.num_timesteps):
...@@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
...@@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
def rescale_bbox(bbox):
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
return x0, y0, w, h
return [rescale_bbox(b) for b in bboxes]
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
...@@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
[x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
...@@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
if callback:
callback(i)
if img_callback:
img_callback(img, i)
return img, intermediates
@torch.no_grad()
...@@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
if callback:
callback(i)
if img_callback:
img_callback(img, i)
if return_intermediates:
return img, intermediates
...@@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
[x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
...@@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
use_ddim = ddim_steps is not None
log = {}
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
...@@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
if plot_diffusion_rows:
# get diffusion row
diffusion_row = []
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
...@@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
if inpaint:
# make a simple center square
h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
...@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs):
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]
...@@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
logs['bbox_image'] = cond_img
return logs
ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
from modules import extra_networks, shared
import lora
class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')
...@@ -8,7 +9,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_lora
if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
...@@ -22,5 +23,23 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
lora.load_loras(names, multipliers)
if shared.opts.lora_add_hashes_to_infotext:
lora_hashes = []
for item in lora.loaded_loras:
shorthash = item.lora_on_disk.shorthash
if not shorthash:
continue
alias = item.mentioned_name
if not alias:
continue
alias = alias.replace(":", "").replace(",", "")
lora_hashes.append(f"{alias}: {shorthash}")
if lora_hashes:
p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
def deactivate(self, p):
pass
import re
import torch
import gradio as gr
from fastapi import FastAPI
import lora
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
...@@ -49,8 +51,66 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
}))
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
"lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))
def create_lora_json(obj: lora.LoraOnDisk):
return {
"name": obj.name,
"alias": obj.alias,
"path": obj.filename,
"metadata": obj.metadata,
}
def api_loras(_: gr.Blocks, app: FastAPI):
@app.get("/sdapi/v1/loras")
async def get_loras():
return [create_lora_json(obj) for obj in lora.available_loras.values()]
@app.post("/sdapi/v1/refresh-loras")
async def refresh_loras():
return lora.list_available_loras()
script_callbacks.on_app_started(api_loras)
re_lora = re.compile("<lora:([^:]+):")
def infotext_pasted(infotext, d):
hashes = d.get("Lora hashes")
if not hashes:
return
hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
def lora_replacement(m):
alias = m.group(1)
shorthash = hashes.get(alias)
if shorthash is None:
return m.group(0)
lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
if lora_on_disk is None:
return m.group(0)
return f'<lora:{lora_on_disk.get_alias()}:'
d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
script_callbacks.on_infotext_pasted(infotext_pasted)
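A small illustration (not from the repository; the alias and hash values are made up) of the `Lora hashes` parsing that `infotext_pasted` above performs before substituting aliases back into the prompt:

```python
# Mirrors the two parsing lines in infotext_pasted, on a fabricated field value.
hashes_field = "myStyle: a1b2c3d4, otherLora: deadbeef"
pairs = [x.strip().split(':', 1) for x in hashes_field.split(",")]
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in pairs}
assert hashes == {"myStyle": "a1b2c3d4", "otherLora": "deadbeef"}
```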
...@@ -15,13 +15,16 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def list_items(self):
for name, lora_on_disk in lora.available_loras.items():
path, ext = os.path.splitext(lora_on_disk.filename)
alias = lora_on_disk.get_alias()
yield {
"name": name,
"filename": path,
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(lora_on_disk.filename),
"prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
}
...
...@@ -5,11 +5,14 @@ import traceback
import PIL.Image
import numpy as np
import torch
from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader, script_callbacks
from scunet_model_arch import SCUNet as net
from modules.shared import opts
class UpscalerScuNET(modules.upscaler.Upscaler):
...@@ -42,34 +45,83 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers.append(scaler_data2)
self.scalers = scalers
@staticmethod
@torch.no_grad()
def tiled_inference(img, model):
# test the image tile by tile
h, w = img.shape[2:]
tile = opts.SCUNET_tile
tile_overlap = opts.SCUNET_tile_overlap
if tile == 0:
return model(img)
device = devices.get_device_for('scunet')
assert tile % 8 == 0, "tile size should be a multiple of window_size"
sf = 1
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
W = torch.zeros_like(E, dtype=devices.dtype, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)
return output
def do_upscale(self, img: PIL.Image.Image, selected_file):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
return img
device = devices.get_device_for('scunet')
tile = opts.SCUNET_tile
h, w = img.height, img.width
np_img = np.array(img)
np_img = np_img[:, :, ::-1]  # RGB to BGR
np_img = np_img.transpose((2, 0, 1)) / 255  # HWC to CHW
torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)  # type: ignore
if tile > h or tile > w:
_img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
_img[:, :, :h, :w] = torch_img  # pad image
torch_img = _img
torch_output = self.tiled_inference(torch_img, model).squeeze(0)
torch_output = torch_output[:, :h * 1, :w * 1]  # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output
torch.cuda.empty_cache()
output = np_output.transpose((1, 2, 0))  # CHW to HWC
output = output[:, :, ::-1]  # BGR to RGB
return PIL.Image.fromarray((output * 255).astype(np.uint8))
def load_model(self, path: str):
device = devices.get_device_for('scunet')
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
else:
filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
...@@ -79,9 +131,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for _, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
return model
def on_ui_settings():
import gradio as gr
from modules import shared
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
script_callbacks.on_ui_settings(on_ui_settings)
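A minimal sketch (not from the repository; the image height is illustrative) of how the tiling indices computed in `tiled_inference` above cover an image, using the defaults registered in `on_ui_settings` (tile size 256, overlap 8):

```python
# Stride is tile minus overlap; appending h - tile guarantees the final tile
# ends exactly at the image edge, so every row is covered despite the overlap.
tile, tile_overlap, h = 256, 8, 600
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
print(h_idx_list)                  # [0, 248, 344]
assert h_idx_list[-1] + tile == h
```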
...@@ -61,7 +61,9 @@ class WMSA(nn.Module):
Returns:
output: tensor shape [b h w c]
"""
if self.type != 'W':
x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
h_windows = x.size(1)
w_windows = x.size(2)
...@@ -85,8 +87,9 @@ class WMSA(nn.Module):
output = self.linear(output)
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
if self.type != 'W':
output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
return output
def relative_embedding(self):
...
import contextlib
import os
import numpy as np
...@@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
...@@ -45,14 +44,14 @@ class UpscalerSwinIR(Upscaler):
img = upscale(img, model)
try:
torch.cuda.empty_cache()
except Exception:
pass
return img
def load_model(self, path, scale=4):
if "http" in path:
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
else:
filename = path
if filename is None or not os.path.exists(filename):
...
...@@ -644,7 +644,7 @@ class SwinIR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
...@@ -844,7 +844,7 @@ class SwinIR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
...
...@@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
pretrained_window_size=(0, 0)):
super().__init__()
self.dim = dim
...@@ -698,7 +698,7 @@ class Swin2SR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
...@@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
...
// Stable Diffusion WebUI - Bracket checker
// By Hingashi no Florin/Bwin4L & @akx
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
function checkBrackets(textArea, counterElt) {
var counts = {};
(textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
counts[bracket] = (counts[bracket] || 0) + 1;
});
var errors = [];
closeBracketRegExp = /\)/g;
openSquareBracketRegExp = /\[/g;
closeSquareBracketRegExp = /\]/g;
openCurlyBracketRegExp = /\{/g;
closeCurlyBracketRegExp = /\}/g;
totalOpenBracketMatches = 0;
totalCloseBracketMatches = 0;
totalOpenSquareBracketMatches = 0;
totalCloseSquareBracketMatches = 0;
totalOpenCurlyBracketMatches = 0;
totalCloseCurlyBracketMatches = 0;
openBracketMatches = textArea.value.match(openBracketRegExp);
if(openBracketMatches) {
totalOpenBracketMatches = openBracketMatches.length;
}
closeBracketMatches = textArea.value.match(closeBracketRegExp);
if(closeBracketMatches) {
totalCloseBracketMatches = closeBracketMatches.length;
}
openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
if(openSquareBracketMatches) {
totalOpenSquareBracketMatches = openSquareBracketMatches.length;
}
closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
if(closeSquareBracketMatches) {
totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
}
openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
if(openCurlyBracketMatches) {
totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
}
closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
if(closeCurlyBracketMatches) {
totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
}
if(totalOpenBracketMatches != totalCloseBracketMatches) {
if(!counterElt.title.includes(errorStringParen)) {
counterElt.title += errorStringParen;
}
} else {
counterElt.title = counterElt.title.replace(errorStringParen, '');
}
if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
if(!counterElt.title.includes(errorStringSquare)) {
counterElt.title += errorStringSquare;
}
} else {
counterElt.title = counterElt.title.replace(errorStringSquare, '');
}
if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) { function checkPair(open, close, kind) {
if(!counterElt.title.includes(errorStringCurly)) { if (counts[open] !== counts[close]) {
counterElt.title += errorStringCurly; errors.push(
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
);
} }
} else {
counterElt.title = counterElt.title.replace(errorStringCurly, '');
} }
if(counterElt.title != '') { checkPair('(', ')', 'round brackets');
counterElt.classList.add('error'); checkPair('[', ']', 'square brackets');
} else { checkPair('{', '}', 'curly brackets');
counterElt.classList.remove('error'); counterElt.title = errors.join('\n');
} counterElt.classList.toggle('error', errors.length !== 0);
} }
function setupBracketChecking(id_prompt, id_counter){ function setupBracketChecking(id_prompt, id_counter) {
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
var counter = gradioApp().getElementById(id_counter) var counter = gradioApp().getElementById(id_counter);
textarea.addEventListener("input", function(evt){ if (textarea && counter) {
checkBrackets(evt, textarea, counter) textarea.addEventListener("input", () => checkBrackets(textarea, counter));
}); }
} }
onUiLoaded(function(){ onUiLoaded(function() {
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter') setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter') setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
setupBracketChecking('img2img_prompt', 'img2img_token_counter') setupBracketChecking('img2img_prompt', 'img2img_token_counter');
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter') setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
}) });
\ No newline at end of file
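A minimal standalone sketch of the counting approach used by the rewritten checkBrackets above (illustrative only, not part of this commit):

// Count each bracket character, then report the pairs whose open/close counts disagree.
function countBracketMismatches(prompt) {
    var counts = {};
    (prompt.match(/[(){}[\]]/g) || []).forEach(function(b) {
        counts[b] = (counts[b] || 0) + 1;
    });
    return [['(', ')'], ['[', ']'], ['{', '}']].filter(function(pair) {
        return (counts[pair[0]] || 0) !== (counts[pair[1]] || 0);
    });
}

// e.g. countBracketMismatches("a (cat:1.2) [dog") -> [['[', ']']]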
<div class='card' style={style} onclick={card_clicked}>
    {background_image}
    {metadata_button}
    <div class='actions'>
        <div class='additional'>
            <ul>
                <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
            </ul>
            <span style="display:none" class='search_term{search_only}'>{search_term}</span>
        </div>
        <span class='name'>{name}</span>
        <span class='description'>{description}</span>
    </div>
</div>
@@ -662,3 +662,29 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
</pre>
<h2><a href="https://github.com/madebyollin/taesd/blob/main/LICENSE">TAESD</a></h2>
<small>Tiny AutoEncoder for Stable Diffusion option for live previews</small>
<pre>
MIT License
Copyright (c) 2023 Ollin Boer Bohan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>
\ No newline at end of file
let currentWidth = null;
let currentHeight = null;
let arFrameTimeout = setTimeout(function() {}, 0);

function dimensionChange(e, is_width, is_height) {

    if (is_width) {
        currentWidth = e.target.value * 1.0;
    }
    if (is_height) {
        currentHeight = e.target.value * 1.0;
    }

    var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";

    if (!inImg2img) {
        return;
    }

    var targetElement = null;

    var tabIndex = get_tab_index('mode_img2img');
    if (tabIndex == 0) { // img2img
        targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
    } else if (tabIndex == 1) { //Sketch
        targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
    } else if (tabIndex == 2) { // Inpaint
        targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
    } else if (tabIndex == 3) { // Inpaint sketch
        targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
    }

    if (targetElement) {

        var arPreviewRect = gradioApp().querySelector('#imageARPreview');
        if (!arPreviewRect) {
            arPreviewRect = document.createElement('div');
            arPreviewRect.id = "imageARPreview";
            gradioApp().appendChild(arPreviewRect);
        }

        var viewportOffset = targetElement.getBoundingClientRect();

        var viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight);

        var scaledx = targetElement.naturalWidth * viewportscale;
        var scaledy = targetElement.naturalHeight * viewportscale;

        var cleintRectTop = (viewportOffset.top + window.scrollY);
        var cleintRectLeft = (viewportOffset.left + window.scrollX);
        var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2);
        var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2);

        var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);
        var arscaledx = currentWidth * arscale;
        var arscaledy = currentHeight * arscale;

        var arRectTop = cleintRectCentreY - (arscaledy / 2);
        var arRectLeft = cleintRectCentreX - (arscaledx / 2);
        var arRectWidth = arscaledx;
        var arRectHeight = arscaledy;

        arPreviewRect.style.top = arRectTop + 'px';
        arPreviewRect.style.left = arRectLeft + 'px';
        arPreviewRect.style.width = arRectWidth + 'px';
        arPreviewRect.style.height = arRectHeight + 'px';

        clearTimeout(arFrameTimeout);
        arFrameTimeout = setTimeout(function() {
            arPreviewRect.style.display = 'none';
        }, 2000);
        arPreviewRect.style.display = 'block';

@@ -86,31 +81,33 @@ function dimensionChange(e, is_width, is_height){
    }
}

onUiUpdate(function() {
    var arPreviewRect = gradioApp().querySelector('#imageARPreview');
    if (arPreviewRect) {
        arPreviewRect.style.display = 'none';
    }
    var tabImg2img = gradioApp().querySelector("#tab_img2img");
    if (tabImg2img) {
        var inImg2img = tabImg2img.style.display == "block";
        if (inImg2img) {
            let inputs = gradioApp().querySelectorAll('input');
            inputs.forEach(function(e) {
                var is_width = e.parentElement.id == "img2img_width";
                var is_height = e.parentElement.id == "img2img_height";

                if ((is_width || is_height) && !e.classList.contains('scrollwatch')) {
                    e.addEventListener('input', function(e) {
                        dimensionChange(e, is_width, is_height);
                    });
                    e.classList.add('scrollwatch');
                }
                if (is_width) {
                    currentWidth = e.value * 1.0;
                }
                if (is_height) {
                    currentHeight = e.value * 1.0;
                }
            });
        }
    }
});
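A minimal sketch of the rectangle-fitting math dimensionChange performs when drawing the aspect-ratio preview, isolated as a pure function (illustrative only, not part of this commit):

function fitPreviewRect(displayedWidth, displayedHeight, targetWidth, targetHeight) {
    // scale the requested generation size until it fits inside the displayed image
    var scale = Math.min(displayedWidth / targetWidth, displayedHeight / targetHeight);
    return {width: targetWidth * scale, height: targetHeight * scale};
}

// e.g. fitPreviewRect(512, 512, 768, 512) -> {width: 512, height: 341.33...}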
// allows drag-dropping files into gradio image elements, and also pasting images from clipboard

function isValidImageList(files) {
    return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
}

function dropReplaceImage(imgWrap, files) {
    if (!isValidImageList(files)) {
        return;
    }

@@ -14,8 +14,8 @@ function dropReplaceImage( imgWrap, files ) {
    imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
    const callback = () => {
        const fileInput = imgWrap.querySelector('input[type="file"]');
        if (fileInput) {
            if (files.length === 0) {
                files = new DataTransfer();
                files.items.add(tmpFile);
                fileInput.files = files.files;
@@ -26,32 +26,32 @@ function dropReplaceImage( imgWrap, files ) {
        }
    };

    if (imgWrap.closest('#pnginfo_image')) {
        // special treatment for PNG Info tab, wait for fetch request to finish
        const oldFetch = window.fetch;
        window.fetch = async(input, options) => {
            const response = await oldFetch(input, options);
            if ('api/predict/' === input) {
                const content = await response.text();
                window.fetch = oldFetch;
                window.requestAnimationFrame(() => callback());
                return new Response(content, {
                    status: response.status,
                    statusText: response.statusText,
                    headers: response.headers
                });
            }
            return response;
        };
    } else {
        window.requestAnimationFrame(() => callback());
    }
}

window.document.addEventListener('dragover', e => {
    const target = e.composedPath()[0];
    const imgWrap = target.closest('[data-testid="image"]');
    if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
        return;
    }
    e.stopPropagation();
@@ -65,24 +65,27 @@ window.document.addEventListener('drop', e => {
        return;
    }
    const imgWrap = target.closest('[data-testid="image"]');
    if (!imgWrap) {
        return;
    }
    e.stopPropagation();
    e.preventDefault();
    const files = e.dataTransfer.files;
    dropReplaceImage(imgWrap, files);
});

window.addEventListener('paste', e => {
    const files = e.clipboardData.files;
    if (!isValidImageList(files)) {
        return;
    }

    const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
        .filter(el => uiElementIsVisible(el))
        .sort((a, b) => uiElementInSight(b) - uiElementInSight(a));

    if (!visibleImageFields.length) {
        return;
    }
@@ -93,5 +96,6 @@ window.addEventListener('paste', e => {
        firstFreeImageField ?
            firstFreeImageField :
            visibleImageFields[visibleImageFields.length - 1]
        , files
    );
});
function keyupEditAttention(event) {
    let target = event.originalTarget || event.composedPath()[0];
    if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
    if (!(event.metaKey || event.ctrlKey)) return;

    let isPlus = event.key == "ArrowUp";
    let isMinus = event.key == "ArrowDown";
    if (!isPlus && !isMinus) return;

    let selectionStart = target.selectionStart;
    let selectionEnd = target.selectionEnd;
    let text = target.value;

    function selectCurrentParenthesisBlock(OPEN, CLOSE) {
        if (selectionStart !== selectionEnd) return false;

        // Find opening parenthesis around current cursor
@@ -44,27 +44,45 @@ function keyupEditAttention(event){
        return true;
    }

    function selectCurrentWord() {
        if (selectionStart !== selectionEnd) return false;
        const delimiters = opts.keyedit_delimiters + " \r\n\t";

        // seek backward until to find beggining
        while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
            selectionStart--;
        }

        // seek forward to find end
        while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
            selectionEnd++;
        }

        target.setSelectionRange(selectionStart, selectionEnd);
        return true;
    }

    // If the user hasn't selected anything, let's select their current parenthesis block or word
    if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
        selectCurrentWord();
    }

    event.preventDefault();

    var closeCharacter = ')';
    var delta = opts.keyedit_precision_attention;

    if (selectionStart > 0 && text[selectionStart - 1] == '<') {
        closeCharacter = '>';
        delta = opts.keyedit_precision_extra;
    } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {

        // do not include spaces at the end
        while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
            selectionEnd -= 1;
        }
        if (selectionStart == selectionEnd) {
            return;
        }

        text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
@@ -73,22 +91,28 @@ function keyupEditAttention(event){
        selectionEnd += 1;
    }

    var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
    var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
    if (isNaN(weight)) return;

    weight += isPlus ? delta : -delta;
    weight = parseFloat(weight.toPrecision(12));
    if (String(weight).length == 1) weight += ".0";

    if (closeCharacter == ')' && weight == 1) {
        text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
        selectionStart--;
        selectionEnd--;
    } else {
        text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
    }

    target.focus();
    target.value = text;
    target.selectionStart = selectionStart;
    target.selectionEnd = selectionEnd;

    updateInput(target);
}

addEventListener('keydown', (event) => {
...
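A minimal sketch of the weight-bumping idea behind keyupEditAttention, reduced to a pure function over a single "(token:weight)" group (illustrative only, not part of this commit; the real handler works directly on the textarea selection):

function bumpWeight(group, delta) {
    var m = group.match(/^\((.+):([\d.]+)\)$/);
    if (!m) return group;
    var weight = parseFloat(m[2]) + delta;
    weight = parseFloat(weight.toPrecision(12)); // avoid float noise such as 1.0000000000000002
    // like the real handler, drop the parentheses entirely when the weight returns to 1
    return weight == 1 ? m[1] : '(' + m[1] + ':' + weight + ')';
}

// e.g. bumpWeight("(cat:1.0)", 0.1)  -> "(cat:1.1)"
//      bumpWeight("(cat:1.1)", -0.1) -> "cat"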
function extensions_apply(_disabled_list, _update_list, disable_all) {
    var disable = [];
    var update = [];

    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
        if (x.name.startsWith("enable_") && !x.checked) {
            disable.push(x.name.substring(7));
        }

        if (x.name.startsWith("update_") && x.checked) {
            update.push(x.name.substring(7));
        }
    });

    restart_reload();

    return [JSON.stringify(disable), JSON.stringify(update), disable_all];
}

function extensions_check() {
    var disable = [];

    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
        if (x.name.startsWith("enable_") && !x.checked) {
            disable.push(x.name.substring(7));
        }
    });

    gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
        x.innerHTML = "Loading...";
    });

    var id = randomId();
    requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() {
    });

    return [id, JSON.stringify(disable)];
}

function install_extension_from_index(button, url) {
    button.disabled = "disabled";
    button.value = "Installing...";
    var textarea = gradioApp().querySelector('#extension_to_install textarea');
    textarea.value = url;
    updateInput(textarea);

    gradioApp().querySelector('#install_extension_button').click();
}

function config_state_confirm_restore(_, config_state_name, config_restore_type) {
    if (config_state_name == "Current") {
        return [false, config_state_name, config_restore_type];
    }
    let restored = "";
    if (config_restore_type == "extensions") {
        restored = "all saved extension versions";
    } else if (config_restore_type == "webui") {
        restored = "the webui version";
    } else {
        restored = "the webui version and all saved extension versions";
    }
    let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
    if (confirmed) {
        restart_reload();
        gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
            x.innerHTML = "Loading...";
        });
    }
    return [confirmed, config_state_name, config_restore_type];
}
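A small sketch of the checkbox-name convention extensions_apply and extensions_check rely on (illustrative only, not part of this commit):

function extensionNameFromCheckbox(name) {
    // "enable_foo" and "update_foo" both map to the extension name "foo"
    return name.startsWith("enable_") || name.startsWith("update_") ? name.substring(7) : name;
}

// e.g. extensionNameFromCheckbox("enable_sd-webui-controlnet") -> "sd-webui-controlnet"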
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes

let txt2img_gallery, img2img_gallery, modal = undefined;
onUiUpdate(function() {
    if (!txt2img_gallery) {
        txt2img_gallery = attachGalleryListeners("txt2img");
    }
    if (!img2img_gallery) {
        img2img_gallery = attachGalleryListeners("img2img");
    }
    if (!modal) {
        modal = gradioApp().getElementById('lightboxModal');
        modalObserver.observe(modal, {attributes: true, attributeFilter: ['style']});
    }
});

let modalObserver = new MutationObserver(function(mutations) {
    mutations.forEach(function(mutationRecord) {
        let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText;
        if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) {
            gradioApp().getElementById(selectedTab + "_generation_info_button")?.click();
        }
    });
});

function attachGalleryListeners(tab_name) {
    var gallery = gradioApp().querySelector('#' + tab_name + '_gallery');
    gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name + "_generation_info_button").click());
    gallery?.addEventListener('keydown', (e) => {
        if (e.keyCode == 37 || e.keyCode == 39) { // left or right arrow
            gradioApp().getElementById(tab_name + "_generation_info_button").click();
        }
    });
    return gallery;
}
// mouseover tooltips for various UI elements

var titles = {
    "Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
    "Sampling method": "Which algorithm to use to produce the image",
    "GFPGAN": "Restore low quality faces using GFPGAN neural network",
@@ -9,6 +9,7 @@ titles = {
    "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
    "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",

    "\u{1F4D0}": "Auto detect size from img2img",
    "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
    "Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)",
    "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
@@ -22,6 +23,7 @@ titles = {
    "\u{1f4cb}": "Apply selected styles to current prompt",
    "\u{1f4d2}": "Paste available values into the field",
    "\u{1f3b4}": "Show/hide extra networks",
    "\u{1f300}": "Restore progress",

    "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
    "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
@@ -65,8 +67,8 @@ titles = {

    "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",

    "Images filename pattern": "Use tags like [seed] and [date] to define how filenames for images are chosen. Leave empty for default.",
    "Directory name pattern": "Use tags like [seed] and [date] to define how subdirectories for images and grids are chosen. Leave empty for default.",
    "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",

    "Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
@@ -85,7 +87,6 @@ titles = {
    "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
    "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",

    "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
    "Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
@@ -111,37 +112,57 @@ titles = {
    "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
    "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
    "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
    "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.",
    "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
};

function updateTooltipForSpan(span) {
    if (span.title) return; // already has a title

    let tooltip = localization[titles[span.textContent]] || titles[span.textContent];

    if (!tooltip) {
        tooltip = localization[titles[span.value]] || titles[span.value];
    }

    if (!tooltip) {
        for (const c of span.classList) {
            if (c in titles) {
                tooltip = localization[titles[c]] || titles[c];
                break;
            }
        }
    }

    if (tooltip) {
        span.title = tooltip;
    }
}

function updateTooltipForSelect(select) {
    if (select.onchange != null) return;

    select.onchange = function() {
        select.title = localization[titles[select.value]] || titles[select.value] || "";
    };
}

var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1};

onUiUpdate(function(m) {
    m.forEach(function(record) {
        record.addedNodes.forEach(function(node) {
            if (observedTooltipElements[node.tagName]) {
                updateTooltipForSpan(node);
            }
            if (node.tagName == "SELECT") {
                updateTooltipForSelect(node);
            }

            if (node.querySelectorAll) {
                node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan);
                node.querySelectorAll('select').forEach(updateTooltipForSelect);
            }
        });
    });
});
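A small sketch of the lookup order updateTooltipForSpan applies, assuming the same titles and localization globals used above (illustrative only, not part of this commit):

function lookupTooltip(key) {
    var raw = titles[key];
    if (!raw) return undefined;
    return localization[raw] || raw; // prefer the translated hint when one exists
}

// e.g. lookupTooltip("Sampling steps") returns the English hint above,
// or its translation if the loaded localization provides one.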
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) {
    function setInactive(elem, inactive) {
        elem.classList.toggle('inactive', !!inactive);
    }

    var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale');
    var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x');
    var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y');

    gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "";

    setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0);
    setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0);
    setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0);

    return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y];
}
@@ -2,20 +2,18 @@
 * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
 * @see https://github.com/gradio-app/gradio/issues/1721
 */
function imageMaskResize() {
    const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
    if (!canvases.length) {
        window.removeEventListener('resize', imageMaskResize);
        return;
    }

    const wrapper = canvases[0].closest('.touch-none');
    const previewImage = wrapper.previousElementSibling;

    if (!previewImage.complete) {
        previewImage.addEventListener('load', imageMaskResize);
        return;
    }
@@ -24,22 +22,22 @@ function imageMaskResize() {
    const nw = previewImage.naturalWidth;
    const nh = previewImage.naturalHeight;
    const portrait = nh > nw;

    const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw);
    const wH = Math.min(h, portrait ? h / nh * nh : w / nw * nh);

    wrapper.style.width = `${wW}px`;
    wrapper.style.height = `${wH}px`;
    wrapper.style.left = `0px`;
    wrapper.style.top = `0px`;

    canvases.forEach(c => {
        c.style.width = c.style.height = '';
        c.style.maxWidth = '100%';
        c.style.maxHeight = '100%';
        c.style.objectFit = 'contain';
    });
}

onUiUpdate(imageMaskResize);
window.addEventListener('resize', imageMaskResize);
window.onload = (function() {
    window.addEventListener('drop', e => {
        const target = e.composedPath()[0];
        if (target.placeholder.indexOf("Prompt") == -1) return;

        let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
@@ -11,7 +10,7 @@ window.onload = (function(){
        const imgParent = gradioApp().getElementById(prompt_target);
        const files = e.dataTransfer.files;
        const fileInput = imgParent.querySelector('input[type="file"]');
        if (fileInput) {
            fileInput.files = files;
            fileInput.dispatchEvent(new Event('change'));
        }
...
window.addEventListener('gamepadconnected', (e) => {
    const index = e.gamepad.index;
    let isWaiting = false;
    setInterval(async() => {
        if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
        const gamepad = navigator.getGamepads()[index];
        const xValue = gamepad.axes[0];
        if (xValue <= -0.3) {
            modalPrevImage(e);
            isWaiting = true;
        } else if (xValue >= 0.3) {
            modalNextImage(e);
            isWaiting = true;
        }
        if (isWaiting) {
            await sleepUntil(() => {
                const xValue = navigator.getGamepads()[index].axes[0];
                if (xValue < 0.3 && xValue > -0.3) {
                    return true;
                }
            }, opts.js_modal_lightbox_gamepad_repeat);
            isWaiting = false;
        }
    }, 10);
});

/*
Primarily for vr controller type pointer devices.
I use the wheel event because there's currently no way to do it properly with web xr.
*/
let isScrolling = false;
window.addEventListener('wheel', (e) => {
    if (!opts.js_modal_lightbox_gamepad || isScrolling) return;
    isScrolling = true;

    if (e.deltaX <= -0.6) {
        modalPrevImage(e);
    } else if (e.deltaX >= 0.6) {
        modalNextImage(e);
    }

    setTimeout(() => {
        isScrolling = false;
    }, opts.js_modal_lightbox_gamepad_repeat);
});

function sleepUntil(f, timeout) {
    return new Promise((resolve) => {
        const timeStart = new Date();
        const wait = setInterval(function() {
            if (f() || new Date() - timeStart > timeout) {
                clearInterval(wait);
                resolve();
            }
        }, 20);
    });
}
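A small usage sketch for sleepUntil, assuming the gradioApp helper from script.js (illustrative only, not part of this commit): sleepUntil polls the predicate every 20 ms and resolves as soon as it returns a truthy value, or once the timeout has elapsed.

async function waitForGalleryImage() {
    // resolves once an image shows up in the txt2img gallery, or after 2 seconds
    await sleepUntil(() => gradioApp().querySelector('#txt2img_gallery img'), 2000);
}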
// localization = {} -- the dict with translations is created by the backend

var ignore_ids_for_localization = {
    setting_sd_hypernetwork: 'OPTION',
    setting_sd_model_checkpoint: 'OPTION',
    modelmerger_primary_model_name: 'OPTION',
    modelmerger_secondary_model_name: 'OPTION',
    modelmerger_tertiary_model_name: 'OPTION',
@@ -17,119 +16,145 @@ ignore_ids_for_localization={
    setting_realesrgan_enabled_models: 'SPAN',
    extras_upscaler_1: 'SPAN',
    extras_upscaler_2: 'SPAN',
};

var re_num = /^[.\d]+$/;
var re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u;

var original_lines = {};
var translated_lines = {};

function hasLocalization() {
    return window.localization && Object.keys(window.localization).length > 0;
}

function textNodesUnder(el) {
    var n, a = [], walk = document.createTreeWalker(el, NodeFilter.SHOW_TEXT, null, false);
    while ((n = walk.nextNode())) a.push(n);
    return a;
}

function canBeTranslated(node, text) {
    if (!text) return false;
    if (!node.parentElement) return false;

    var parentType = node.parentElement.nodeName;
    if (parentType == 'SCRIPT' || parentType == 'STYLE' || parentType == 'TEXTAREA') return false;

    if (parentType == 'OPTION' || parentType == 'SPAN') {
        var pnode = node;
        for (var level = 0; level < 4; level++) {
            pnode = pnode.parentElement;
            if (!pnode) break;

            if (ignore_ids_for_localization[pnode.id] == parentType) return false;
        }
    }

    if (re_num.test(text)) return false;
    if (re_emoji.test(text)) return false;
    return true;
}

function getTranslation(text) {
    if (!text) return undefined;

    if (translated_lines[text] === undefined) {
        original_lines[text] = 1;
    }

    var tl = localization[text];
    if (tl !== undefined) {
        translated_lines[tl] = 1;
    }

    return tl;
}

function processTextNode(node) {
    var text = node.textContent.trim();

    if (!canBeTranslated(node, text)) return;

    var tl = getTranslation(text);
    if (tl !== undefined) {
        node.textContent = tl;
    }
}

function processNode(node) {
    if (node.nodeType == 3) {
        processTextNode(node);
        return;
    }

    if (node.title) {
        let tl = getTranslation(node.title);
        if (tl !== undefined) {
            node.title = tl;
        }
    }

    if (node.placeholder) {
        let tl = getTranslation(node.placeholder);
        if (tl !== undefined) {
            node.placeholder = tl;
        }
    }

    textNodesUnder(node).forEach(function(node) {
        processTextNode(node);
    });
}

function dumpTranslations() {
    if (!hasLocalization()) {
        // If we don't have any localization,
        // we will not have traversed the app to find
        // original_lines, so do that now.
        processNode(gradioApp());
    }
    var dumped = {};
    if (localization.rtl) {
        dumped.rtl = true;
    }

    for (const text in original_lines) {
        if (dumped[text] !== undefined) continue;
        dumped[text] = localization[text] || text;
    }

    return dumped;
}

function download_localization() {
    var text = JSON.stringify(dumpTranslations(), null, 4);

    var element = document.createElement('a');
    element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
    element.setAttribute('download', "localization.json");
    element.style.display = 'none';

    document.body.appendChild(element);
    element.click();

    document.body.removeChild(element);
}

document.addEventListener("DOMContentLoaded", function() {
    if (!hasLocalization()) {
        return;
    }

    onUiUpdate(function(m) {
        m.forEach(function(mutation) {
            mutation.addedNodes.forEach(function(node) {
                processNode(node);
            });
        });
    });

    processNode(gradioApp());

    if (localization.rtl) { // if the language is from right to left,
        (new MutationObserver((mutations, observer) => { // wait for the style to load
@@ -144,22 +169,8 @@ document.addEventListener("DOMContentLoaded", function() {
                            }
                        }
                    }
                });
            });
        })).observe(gradioApp(), {childList: true});
    }
});
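A sketch of the kind of object download_localization writes out; keys are the original UI strings and values come from the loaded localization, falling back to the original text. The translations below are made up (illustrative only, not part of this commit):

var exampleDump = {
    "Sampling steps": "Schritte", // hypothetical translated entry taken from `localization`
    "Generate": "Generate"        // untranslated entry falls back to the original text
};
console.log(JSON.stringify(exampleDump, null, 4));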
@@ -2,16 +2,16 @@

let lastHeadImg = null;

let notificationButton = null;

onUiUpdate(function() {
    if (notificationButton == null) {
        notificationButton = gradioApp().getElementById('request_notifications');

        if (notificationButton != null) {
            notificationButton.addEventListener('click', () => {
                void Notification.requestPermission();
            }, true);
        }
    }
@@ -42,7 +42,7 @@ onUiUpdate(function(){
        }
    );

    notification.onclick = function(_) {
        parent.focus();
        this.close();
    };
...
// code related to showing and updating progressbar shown as the image is being made

function rememberGallerySelection() {
}

function getGallerySelectedIndex() {
}

function request(url, data, handler, errorHandler) {
    var xhr = new XMLHttpRequest();
    xhr.open("POST", url, true);
    xhr.setRequestHeader("Content-Type", "application/json");
    xhr.onreadystatechange = function() {
        if (xhr.readyState === 4) {
            if (xhr.status === 200) {
                try {
                    var js = JSON.parse(xhr.responseText);
                    handler(js);
                } catch (error) {
                    console.error(error);
                    errorHandler();
                }
            } else {
                errorHandler();
            }
        }
    };
@@ -32,147 +31,147 @@ function request(url, data, handler, errorHandler){
    xhr.send(js);
}

function pad2(x) {
    return x < 10 ? '0' + x : x;
}

function formatTime(secs) {
    if (secs > 3600) {
        return pad2(Math.floor(secs / 60 / 60)) + ":" + pad2(Math.floor(secs / 60) % 60) + ":" + pad2(Math.floor(secs) % 60);
    } else if (secs > 60) {
        return pad2(Math.floor(secs / 60)) + ":" + pad2(Math.floor(secs) % 60);
    } else {
        return Math.floor(secs) + "s";
    }
}
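// A few representative outputs of formatTime, for reference (illustrative only, not part of this commit):
//   formatTime(42)   -> "42s"
//   formatTime(90)   -> "01:30"
//   formatTime(3700) -> "01:01:40"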
function setTitle(progress) {
    var title = 'Stable Diffusion';

    if (opts.show_progress_in_title && progress) {
        title = '[' + progress.trim() + '] ' + title;
    }

    if (document.title != title) {
        document.title = title;
    }
}

function randomId() {
    return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + ")";
}

// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
// calls onProgress every time there is a progress update
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout = 40) {
    var dateStart = new Date();
    var wasEverActive = false;
    var parentProgressbar = progressbarContainer.parentNode;
    var parentGallery = gallery ? gallery.parentNode : null;

    var divProgress = document.createElement('div');
    divProgress.className = 'progressDiv';
    divProgress.style.display = opts.show_progressbar ? "block" : "none";
    var divInner = document.createElement('div');
    divInner.className = 'progress';

    divProgress.appendChild(divInner);
    parentProgressbar.insertBefore(divProgress, progressbarContainer);

    if (parentGallery) {
        var livePreview = document.createElement('div');
        livePreview.className = 'livePreview';
        parentGallery.insertBefore(livePreview, gallery);
    }

    var removeProgressBar = function() {
        setTitle("");
        parentProgressbar.removeChild(divProgress);
        if (parentGallery) parentGallery.removeChild(livePreview);
        atEnd();
    };

    var fun = function(id_task, id_live_preview) {
        request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
            if (res.completed) {
                removeProgressBar();
                return;
            }

            var rect = progressbarContainer.getBoundingClientRect();

            if (rect.width) {
                divProgress.style.width = rect.width + "px";
            }

            let progressText = "";

            divInner.style.width = ((res.progress || 0) * 100.0) + '%';
            divInner.style.background = res.progress ? "" : "transparent";

            if (res.progress > 0) {
                progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%';
            }

            if (res.eta) {
                progressText += " ETA: " + formatTime(res.eta);
            }

            setTitle(progressText);

            if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
                progressText = res.textinfo + " " + progressText;
            }

            divInner.textContent = progressText;

            var elapsedFromStart = (new Date() - dateStart) / 1000;

            if (res.active) wasEverActive = true;

            if (!res.active && wasEverActive) {
                removeProgressBar();
                return;
            }

            if (elapsedFromStart > inactivityTimeout && !res.queued && !res.active) {
                removeProgressBar();
return return;
} }
if(res.live_preview && gallery){ if (res.live_preview && gallery) {
var rect = gallery.getBoundingClientRect() rect = gallery.getBoundingClientRect();
if(rect.width){ if (rect.width) {
livePreview.style.width = rect.width + "px" livePreview.style.width = rect.width + "px";
livePreview.style.height = rect.height + "px" livePreview.style.height = rect.height + "px";
} }
var img = new Image(); var img = new Image();
img.onload = function() { img.onload = function() {
livePreview.appendChild(img) livePreview.appendChild(img);
if(livePreview.childElementCount > 2){ if (livePreview.childElementCount > 2) {
livePreview.removeChild(livePreview.firstElementChild) livePreview.removeChild(livePreview.firstElementChild);
}
} }
};
img.src = res.live_preview; img.src = res.live_preview;
} }
if(onProgress){ if (onProgress) {
onProgress(res) onProgress(res);
} }
setTimeout(() => { setTimeout(() => {
fun(id_task, res.id_live_preview); fun(id_task, res.id_live_preview);
}, opts.live_preview_refresh_period || 500) }, opts.live_preview_refresh_period || 500);
}, function(){ }, function() {
removeProgressBar() removeProgressBar();
}) });
} };
fun(id_task, 0) fun(id_task, 0);
} }
function start_training_textual_inversion(){ function start_training_textual_inversion() {
gradioApp().querySelector('#ti_error').innerHTML='' gradioApp().querySelector('#ti_error').innerHTML = '';
var id = randomId() var id = randomId();
requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){ requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function() {}, function(progress) {
gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo;
}) });
var res = args_to_array(arguments) var res = Array.from(arguments);
res[0] = id res[0] = id;
return res return res;
} }
// various hints and extra info for the settings tab
var settingsHintsSetup = false;
onOptionsChanged(function() {
if (settingsHintsSetup) return;
settingsHintsSetup = true;
gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div) {
var name = div.id.substr(8);
var commentBefore = opts._comments_before[name];
var commentAfter = opts._comments_after[name];
if (!commentBefore && !commentAfter) return;
var span = null;
if (div.classList.contains('gradio-checkbox')) span = div.querySelector('label span');
else if (div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild;
else if (div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild;
else span = div.querySelector('label span').firstChild;
if (!span) return;
if (commentBefore) {
var comment = document.createElement('DIV');
comment.className = 'settings-comment';
comment.innerHTML = commentBefore;
span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
span.parentElement.insertBefore(comment, span);
span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
}
if (commentAfter) {
comment = document.createElement('DIV');
comment.className = 'settings-comment';
comment.innerHTML = commentAfter;
span.parentElement.insertBefore(comment, span.nextSibling);
span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling);
}
});
});
function settingsHintsShowQuicksettings() {
requestGet("./internal/quicksettings-hint", {}, function(data) {
var table = document.createElement('table');
table.className = 'settings-value-table';
data.forEach(function(obj) {
var tr = document.createElement('tr');
var td = document.createElement('td');
td.textContent = obj.name;
tr.appendChild(td);
td = document.createElement('td');
td.textContent = obj.label;
tr.appendChild(td);
table.appendChild(tr);
});
popup(table);
});
}
@@ -223,8 +223,9 @@ for key in _options:
    if(_options[key].dest != 'help'):
        flag = _options[key]
        _type = str
        if _options[key].default is not None:
            _type = type(_options[key].default)
        flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})

FlagsModel = create_model("Flags", **flags)
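For readers unfamiliar with the pattern above: each argparse flag is turned into a (type, Field) pair and handed to pydantic's create_model so the API can expose launch options with types and descriptions. Below is a minimal, self-contained sketch of the same idea; the stand-in parser, its flags, and the DemoFlags name are illustrative and not part of the webui code.

# Illustration only (not repository code): the same argparse -> pydantic
# conversion as above, applied to a tiny stand-in parser.
import argparse
from pydantic import Field, create_model

demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint")
demo_parser.add_argument("--port", type=int, default=7860, help="server port")

demo_flags = {}
for action in demo_parser._actions:
    if action.dest == 'help':
        continue
    _type = type(action.default) if action.default is not None else str
    demo_flags[action.dest] = (_type, Field(default=action.default, description=action.help))

DemoFlags = create_model("DemoFlags", **demo_flags)
print(DemoFlags().dict())  # {'ckpt': None, 'port': 7860}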
@@ -286,6 +287,23 @@ class MemoryResponse(BaseModel):
    ram: dict = Field(title="RAM", description="System memory stats")
    cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")

class ScriptsList(BaseModel):
    txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
    img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
class ScriptArg(BaseModel):
    label: str = Field(default=None, title="Label", description="Name of the argument in UI")
    value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
    minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argument in UI")
    maximum: Optional[Any] = Field(default=None, title="Maximum", description="Maximum allowed value for the argument in UI")
    step: Optional[Any] = Field(default=None, title="Step", description="Step for changing value of the argument in UI")
    choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")

class ScriptInfo(BaseModel):
    name: str = Field(default=None, title="Name", description="Script name")
    is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
    is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
    args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
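A quick sketch of how ScriptArg and ScriptInfo compose when building a script listing. The script name, argument labels, and values below are placeholders rather than repository data, and the import path is assumed to be modules.api.models, where the classes above are defined.

# Assumed import path; values are illustrative only.
from modules.api.models import ScriptArg, ScriptInfo

# Build one ScriptInfo entry by hand and serialize it to JSON for inspection.
info = ScriptInfo(
    name="example script",
    is_alwayson=False,
    is_img2img=False,
    args=[
        ScriptArg(label="Strength", value=0.5, minimum=0.0, maximum=1.0, step=0.05),
        ScriptArg(label="Mode", choices=["fast", "accurate"]),
    ],
)
print(info.json(indent=2))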
@@ -35,6 +35,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
        try:
            res = func(*args, **kwargs)
            progress.record_results(id_task, res)
        finally:
            progress.finish_task(id_task)
@@ -59,7 +60,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
            max_debug_str_len = 131072  # (1024*1024)/8

            print("Error completing request", file=sys.stderr)
            argStr = f"Arguments: {args} {kwargs}"
            print(argStr[:max_debug_str_len], file=sys.stderr)
            if len(argStr) > max_debug_str_len:
                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
@@ -72,7 +73,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
            if extra_outputs_array is None:
                extra_outputs_array = [None, '']

            error_message = f'{type(e).__name__}: {e}'
            res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]

        shared.state.skipped = False
        shared.state.interrupted = False
import argparse
import json
import os
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file  # noqa: F401

parser = argparse.ArgumentParser()
@@ -11,8 +12,8 @@ parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
@@ -39,7 +40,8 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
@@ -51,16 +53,16 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
@@ -75,6 +77,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication l
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
@@ -95,9 +98,12 @@ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional

from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
from basicsr.utils.registry import ARCH_REGISTRY


def calc_mean_std(feat, eps=1e-5):
@@ -163,8 +161,8 @@ class Fuse_sft_block(nn.Module):
class CodeFormer(VQAutoEncoder):
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                codebook_size=1024, latent_size=256,
                connect_list=('32', '64', '128', '256'),
                fix_modules=('quantize', 'generator')):
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

        if fix_modules is not None:
@@ -5,11 +5,9 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
@@ -328,7 +326,7 @@ class Generator(nn.Module):
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
                beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
@@ -339,7 +337,7 @@ class VQAutoEncoder(nn.Module):
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions or [16]
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
@@ -33,11 +33,9 @@ def setup_model(dirname):
    try:
        from torchvision.transforms.functional import normalize
        from modules.codeformer.codeformer_arch import CodeFormer
        from basicsr.utils import img2tensor, tensor2img
        from facelib.utils.face_restoration_helper import FaceRestoreHelper
        from facelib.detection.retinaface import retinaface

        net_class = CodeFormer
@@ -96,7 +94,7 @@ def setup_model(dirname):
            self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
            self.face_helper.align_warp_face()

            for cropped_face in self.face_helper.cropped_faces:
                cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
@@ -2,7 +2,6 @@ import os
import re

import torch
import numpy as np

from modules import modelloader, paths, deepbooru_model, devices, images, shared
@@ -79,7 +78,7 @@ class DeepDanbooru:
        res = []

        filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}

        for tag in [x for x in tags if x not in filtertags]:
            probability = probability_dict[tag]
@@ -65,7 +65,7 @@ def enable_tf32():
        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
        if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
@@ -92,14 +92,18 @@ def cond_cast_float(input):
def randn(seed, shape):
    from modules.shared import opts

    torch.manual_seed(seed)

    if opts.randn_source == "CPU" or device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)

    return torch.randn(shape, device=device)


def randn_without_seed(shape):
    from modules.shared import opts

    if opts.randn_source == "CPU" or device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)

    return torch.randn(shape, device=device)
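The opts.randn_source == "CPU" branch above is what makes seeded noise device-independent: the tensor is drawn by the CPU generator and only then moved to the target device. A standalone sketch of that trick in plain PyTorch follows; the function and variable names are made up for illustration and are not webui code.

import torch

def seeded_noise(seed, shape, device, source="CPU"):
    # Seed the RNGs, then draw either on the CPU (same values on every machine)
    # or directly on the target device (faster, but can differ between GPUs).
    torch.manual_seed(seed)
    if source == "CPU":
        return torch.randn(shape, device="cpu").to(device)
    return torch.randn(shape, device=device)

# With CPU-sourced noise, the same seed yields the same latent regardless of
# where the tensor finally lands.
a = seeded_noise(1234, (4, 64, 64), device="cpu")
b = seeded_noise(1234, (4, 64, 64), device="cuda" if torch.cuda.is_available() else "cpu")
print(torch.allclose(a, b.cpu()))  # True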
from modules import extra_networks, shared
from modules.hypernetworks import hypernetwork
@@ -9,8 +9,9 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
    def activate(self, p, params_list):
        additional = shared.opts.sd_hypernetwork

        if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
            hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
            p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

        names = []
from .sampler import UniPCSampler  # noqa: F401