Commit e79b7db4 authored by unknown

Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui into gamepad

parents b921a520 e8a41df4
@@ -37,20 +37,20 @@ body:
     id: what-should
     attributes:
       label: What should have happened?
-      description: tell what you think the normal behavior should be
+      description: Tell what you think the normal behavior should be
     validations:
       required: true
   - type: input
     id: commit
     attributes:
       label: Commit where the problem happens
-      description: Which commit are you running? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+      description: Which commit are you running? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
     validations:
       required: true
   - type: dropdown
     id: platforms
     attributes:
-      label: What platforms do you use to access the UI?
+      label: What platforms do you use to access the UI?
       multiple: true
       options:
         - Windows
@@ -74,10 +74,27 @@ body:
     id: cmdargs
     attributes:
       label: Command Line Arguments
-      description: Are you using any launching parameters/command line arguments (modified webui-user.py)? If yes, please write them below
+      description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh)? If yes, please write them below. Write "No" otherwise.
       render: Shell
+    validations:
+      required: true
+  - type: textarea
+    id: extensions
+    attributes:
+      label: List of extensions
+      description: Are you using any extensions other than built-ins? If yes, provide a list; you can copy it from the "Extensions" tab. Write "No" otherwise.
+    validations:
+      required: true
+  - type: textarea
+    id: logs
+    attributes:
+      label: Console logs
+      description: Please provide **full** cmd/terminal logs from the moment you started the UI to the end of the session, after your bug happened. If they are very long, provide a link to pastebin or a similar service.
+      render: Shell
+    validations:
+      required: true
   - type: textarea
     id: misc
     attributes:
-      label: Additional information, context and logs
-      description: Please provide us with any relevant additional info, context or log output.
+      label: Additional information
+      description: Please provide us with any relevant additional info or context.
 name: Feature request
 description: Suggest an idea for this project
 title: "[Feature Request]: "
-labels: ["suggestion"]
+labels: ["enhancement"]
 body:
   - type: checkboxes
@@ -18,8 +18,8 @@ More technical discussion about your changes go here, plus anything that a maint
 List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
  - OS: [e.g. Windows, Linux]
- - Browser [e.g. chrome, safari]
- - Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
+ - Browser: [e.g. chrome, safari]
+ - Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]

 **Screenshots or videos of your changes**
@@ -19,22 +19,19 @@ jobs:
     - name: Checkout Code
       uses: actions/checkout@v3
     - name: Set up Python 3.10
-      uses: actions/setup-python@v3
+      uses: actions/setup-python@v4
       with:
         python-version: 3.10.6
-    - uses: actions/cache@v2
-      with:
-        path: ~/.cache/pip
-        key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
-        restore-keys: |
-          ${{ runner.os }}-pip-
+        cache: pip
+        cache-dependency-path: |
+          **/requirements*txt
     - name: Install PyLint
       run: |
         python -m pip install --upgrade pip
         pip install pylint
     # This lets PyLint check to see if it can resolve imports
     - name: Install dependencies
-      run : |
+      run: |
         export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
         python launch.py
     - name: Analysing the code with pylint
@@ -14,13 +14,11 @@ jobs:
       uses: actions/setup-python@v4
       with:
         python-version: 3.10.6
-    - uses: actions/cache@v3
-      with:
-        path: ~/.cache/pip
-        key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
-        restore-keys: ${{ runner.os }}-pip-
+        cache: pip
+        cache-dependency-path: |
+          **/requirements*txt
     - name: Run tests
-      run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
+      run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
     - name: Upload main app stdout-stderr
       uses: actions/upload-artifact@v3
       if: always()
@@ -32,3 +32,4 @@ notification.mp3
 /extensions
 /test/stdout.txt
 /test/stderr.txt
+/cache.json
 # Stable Diffusion web UI
 A browser interface based on Gradio library for Stable Diffusion.

-![](txt2img_Screenshot.png)
+![](screenshot.png)

-Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.

 ## Features
 [Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
@@ -19,7 +17,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - a man in a (tuxedo:1.21) - alternative syntax
 - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
 - Loopback, run img2img processing multiple times
-- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
+- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
 - Textual Inversion
   - have as many embeddings as you want and use any names you like for them
   - use multiple embeddings with different numbers of vectors per token
@@ -51,9 +49,9 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - Running arbitrary python code from UI (must run with --allow-code to enable)
 - Mouseover hints for most UI elements
 - Possible to change defaults/min/max/step values for UI elements via text config
-- Random artist button
 - Tiling support, a checkbox to create images that can be tiled like textures
 - Progress bar and live image generation preview
+- Can use a separate neural network to produce previews with almost no VRAM or compute requirement
 - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
 - Styles, a way to save part of prompt and easily apply them via dropdown later
 - Variations, a way to generate same image but with tiny differences
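A concrete illustration of the `(text:weight)` attention syntax from the feature list above, as a minimal sketch rather than the webui's actual parser (the real grammar also supports nesting, `[...]` de-emphasis, and escaping):

```python
import re

# Hypothetical simplified extraction of explicit "(text:1.21)"-style weights.
re_attention = re.compile(r"\(([^:()]+):([\d.]+)\)")

prompt = "a man in a (tuxedo:1.21) walking down the street"
print(re_attention.findall(prompt))  # [('tuxedo', '1.21')]
```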
@@ -78,13 +76,22 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - hypernetworks and embeddings options
 - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
 - Clip skip
-- Use Hypernetworks
-- Use VAEs
+- Hypernetworks
+- Loras (same as Hypernetworks but more pretty)
+- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt.
+- Can select to load a different VAE from settings screen
 - Estimated completion time in progress bar
 - API
 - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
 - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
 - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
+- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
+- Now without any bad letters!
+- Load checkpoints in safetensors format
+- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
+- Now with a license!
+- Reorder elements in the UI from settings screen

 ## Installation and Running
 Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
@@ -97,9 +104,8 @@ Alternatively, use online services (like Google Colab):
 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
 2. Install [git](https://git-scm.com/download/win).
 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
-4. Place `model.ckpt` in the `models` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
-5. _*(Optional)*_ Place `GFPGANv1.4.pth` in the base directory, alongside `webui.py` (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
-6. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
+4. Place stable diffusion checkpoint (`model.ckpt`) in the `models/Stable-diffusion` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
+5. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.

 ### Automatic Installation on Linux
 1. Install the dependencies:
@@ -127,6 +133,8 @@ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC
 The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).

 ## Credits
+Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
+
 - Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
 - GFPGAN - https://github.com/TencentARC/GFPGAN.git
@@ -139,6 +147,7 @@ The documentation was moved from this README over to the project's [wiki](https:
 - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
 - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
 - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
+- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)
 - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
 - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
 - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
@@ -146,6 +155,8 @@ The documentation was moved from this README over to the project's [wiki](https:
 - Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
 - xformers - https://github.com/facebookresearch/xformers
 - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
+- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
+- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
 - Security advice - RyotaK
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
 - (You)
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: modules.xlmr.BertSeriesModelWithTransformation
      params:
        name: "XLMR-Large"
\ No newline at end of file
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
# See more details in LICENSE.

model:
  base_learning_rate: 1.0e-04
  target: modules.models.diffusion.ddpm_edit.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: edited
    cond_stage_key: edit
    # image_size: 64
    # image_size: 32
    image_size: 16
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: hybrid
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: true
    load_ema: true

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 0 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 128
    num_workers: 1
    wrap: false
    validation:
      target: edit_dataset.EditDataset
      params:
        path: data/clip-filtered-dataset
        cache_dir: data/
        cache_name: data_10k
        split: val
        min_text_sim: 0.2
        min_image_sim: 0.75
        min_direction_sim: 0.2
        max_samples_per_prompt: 1
        min_resize_res: 512
        max_resize_res: 512
        crop_res: 512
        output_as_edit: False
        real_input: True
model:
  base_learning_rate: 7.5e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: hybrid   # important
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    finetune_keys: null

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 9  # 4 data + 4 downscaled image + 1 mask
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
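The `in_channels: 9` comment above ("4 data + 4 downscaled image + 1 mask") describes how the inpainting UNet's input is assembled. A sketch of that concatenation; the channel order shown is an assumption for illustration:

```python
import torch

# Illustrative latents for one 512x512 image (v1 models use a 64x64 latent grid).
noisy_latent = torch.randn(1, 4, 64, 64)         # "4 data": the latent being denoised
masked_image_latent = torch.randn(1, 4, 64, 64)  # "4 downscaled image": VAE-encoded masked image
mask = torch.zeros(1, 1, 64, 64)                 # "1 mask": the downscaled inpainting mask

# The hybrid-conditioned UNet receives all three stacked along the channel axis:
unet_input = torch.cat([noisy_latent, masked_image_latent, mask], dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 64, 64])
```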
 import os
 import gc
 import time
-import warnings

 import numpy as np
 import torch
@@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import instantiate_from_config, ismap
 from modules import shared, sd_hijack

-warnings.filterwarnings("ignore", category=UserWarning)

 cached_ldsr_model: torch.nn.Module = None
from modules import extra_networks
import lora


class ExtraNetworkLora(extra_networks.ExtraNetwork):
    def __init__(self):
        super().__init__('lora')

    def activate(self, p, params_list):
        names = []
        multipliers = []
        for params in params_list:
            assert len(params.items) > 0

            names.append(params.items[0])
            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)

        lora.load_loras(names, multipliers)

    def deactivate(self, p):
        pass
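For context, `params_list` comes from `<lora:...>` tags in the prompt (the same syntax that `ui_extra_networks_lora.py` below generates). A sketch of the data flow, using a hypothetical network name:

```python
# A prompt such as "a photo of a cat <lora:myLora:0.8>" (myLora is hypothetical)
# reaches activate() with params.items == ["myLora", "0.8"], so:
#   names       == ["myLora"]   # params.items[0]
#   multipliers == [0.8]        # float(params.items[1]); defaults to 1.0 if omitted
# load_loras() then loads the file and applies it at 80% strength on every matched layer.
```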
import glob
import os
import re
import torch

from modules import shared, devices, sd_models

re_digits = re.compile(r"\d+")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")


def convert_diffusers_name_to_compvis(key):
    def match(match_list, regex):
        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, re_unet_down_blocks):
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_unet_mid_blocks):
        return f"diffusion_model_middle_block_1_{m[1]}"

    if match(m, re_unet_up_blocks):
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"

    if match(m, re_text_block):
        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    return key
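To make the index arithmetic concrete, a worked example (the sample key is illustrative; the result follows directly from `re_unet_down_blocks` and the `1 + m[0] * 3 + m[1]` formula):

```python
key = "lora_unet_down_blocks_0_attentions_1_proj_in.lora_up.weight"
# re_unet_down_blocks captures m = [0, 1, "proj_in.lora_up.weight"], so the result is
# f"diffusion_model_input_blocks_{1 + 0 * 3 + 1}_1_proj_in.lora_up.weight":
print(convert_diffusers_name_to_compvis(key))
# diffusion_model_input_blocks_2_1_proj_in.lora_up.weight
```

`load_lora` below then splits the converted key at the first `.`, yielding the layer key and the `lora_up.weight` / `lora_down.weight` / `alpha` suffix.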
class LoraOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename


class LoraModule:
    def __init__(self, name):
        self.name = name
        self.multiplier = 1.0
        self.modules = {}
        self.mtime = None


class LoraUpDownModule:
    def __init__(self):
        self.up = None
        self.down = None
        self.alpha = None


def assign_lora_names_to_compvis_modules(sd_model):
    lora_layer_mapping = {}

    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    for name, module in shared.sd_model.model.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    sd_model.lora_layer_mapping = lora_layer_mapping


def load_lora(name, filename):
    lora = LoraModule(name)
    lora.mtime = os.path.getmtime(filename)

    sd = sd_models.read_state_dict(filename)

    keys_failed_to_match = []

    for key_diffusers, weight in sd.items():
        fullkey = convert_diffusers_name_to_compvis(key_diffusers)
        key, lora_key = fullkey.split(".", 1)

        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
        if sd_module is None:
            keys_failed_to_match.append(key_diffusers)
            continue

        lora_module = lora.modules.get(key, None)
        if lora_module is None:
            lora_module = LoraUpDownModule()
            lora.modules[key] = lora_module

        if lora_key == "alpha":
            lora_module.alpha = weight.item()
            continue

        # Mirror the matched layer's type with a bias-free projection holding the lora weight
        if type(sd_module) == torch.nn.Linear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.Conv2d:
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'

        with torch.no_grad():
            module.weight.copy_(weight)

        module.to(device=devices.device, dtype=devices.dtype)

        if lora_key == "lora_up.weight":
            lora_module.up = module
        elif lora_key == "lora_down.weight":
            lora_module.down = module
        else:
            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'

    if len(keys_failed_to_match) > 0:
        print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")

    return lora


def load_loras(names, multipliers=None):
    already_loaded = {}

    for lora in loaded_loras:
        if lora.name in names:
            already_loaded[lora.name] = lora

    loaded_loras.clear()

    loras_on_disk = [available_loras.get(name, None) for name in names]
    if any([x is None for x in loras_on_disk]):
        list_available_loras()

        loras_on_disk = [available_loras.get(name, None) for name in names]

    for i, name in enumerate(names):
        lora = already_loaded.get(name, None)
        lora_on_disk = loras_on_disk[i]
        if lora_on_disk is not None:
            # Reload from disk if never loaded or if the file changed since last load
            if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
                lora = load_lora(name, lora_on_disk.filename)

        if lora is None:
            print(f"Couldn't find Lora with name {name}")
            continue

        lora.multiplier = multipliers[i] if multipliers else 1.0
        loaded_loras.append(lora)


def lora_forward(module, input, res):
    if len(loaded_loras) == 0:
        return res

    lora_layer_name = getattr(module, 'lora_layer_name', None)
    for lora in loaded_loras:
        module = lora.modules.get(lora_layer_name, None)
        if module is not None:
            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
            else:
                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

    return res
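The expression in `lora_forward` is the standard low-rank update: the frozen layer's output plus `up(down(x))`, scaled by the user multiplier and by `alpha / rank`, where the rank is `module.up.weight.shape[1]` (the in-features of the up projection). A self-contained sketch with illustrative shapes; the 768/rank-4 sizes are assumptions for the example:

```python
import torch

base = torch.nn.Linear(768, 768)             # the frozen original layer
down = torch.nn.Linear(768, 4, bias=False)   # "lora_down": project to rank-4 space
up = torch.nn.Linear(4, 768, bias=False)     # "lora_up": project back to 768
alpha, multiplier = 4.0, 0.8                 # alpha from the checkpoint, strength from <lora:name:0.8>

x = torch.randn(1, 768)
res = base(x)
# Same arithmetic as lora_forward; rank = up.weight.shape[1] == 4
res = res + up(down(x)) * multiplier * (alpha / up.weight.shape[1])
```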
def lora_Linear_forward(self, input):
    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))


def lora_Conv2d_forward(self, input):
    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))


def list_available_loras():
    available_loras.clear()

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)

    for filename in sorted(candidates):
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]

        available_loras[name] = LoraOnDisk(name, filename)


available_loras = {}
loaded_loras = []

list_available_loras()
import os
from modules import paths


def preload(parser):
    parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
import torch

import lora
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared


def unload():
    # Restore the unpatched forwards when the script is unloaded
    torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora


def before_ui():
    ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
    extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())


# Save the original forwards only once, so reloading this script doesn't stack wrappers
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward

if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward

torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward

script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)


shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
}))
import json
import os
import lora

from modules import shared, ui_extra_networks


class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('Lora')

    def refresh(self):
        lora.list_available_loras()

    def list_items(self):
        for name, lora_on_disk in lora.available_loras.items():
            path, ext = os.path.splitext(lora_on_disk.filename)
            previews = [path + ".png", path + ".preview.png"]

            preview = None
            for file in previews:
                if os.path.isfile(file):
                    preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
                    break

            yield {
                "name": name,
                "filename": path,
                "preview": preview,
                "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                "local_preview": path + ".png",
            }

    def allowed_directories_for_previews(self):
        return [shared.cmd_opts.lora_dir]
@@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm

 from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts
+from modules.shared import cmd_opts, opts, state
 from swinir_model_arch import SwinIR as net
 from swinir_model_arch_v2 import Swin2SR as net2
 from modules.upscaler import Upscaler, UpscalerData
@@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
     with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
         for h_idx in h_idx_list:
+            if state.interrupted or state.skipped:
+                break
+
             for w_idx in w_idx_list:
+                if state.interrupted or state.skipped:
+                    break
+
                 in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                 out_patch = model(in_patch)
                 out_patch_mask = torch.ones_like(out_patch)
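For context, `h_idx_list` and `w_idx_list` hold the top-left coordinates of overlapping tiles covering the image. A sketch of how such lists are typically built, as an assumption about the surrounding code this hunk does not show:

```python
# stride < tile, so neighbouring tiles overlap by tile_overlap pixels;
# the final index is appended so the last tile ends flush with the image edge.
# h, w: image height/width; tile: tile size; tile_overlap: overlap in pixels.
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
```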
@@ -4,16 +4,10 @@
 // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
 // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.

-function checkBrackets(evt) {
-    textArea = evt.target;
-    tabName = evt.target.parentElement.parentElement.id.split("_")[0];
-    counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
-    promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
-    errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
-    errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
-    errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
+function checkBrackets(evt, textArea, counterElt) {
+    errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
+    errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
+    errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';

     openBracketRegExp = /\(/g;
     closeBracketRegExp = /\)/g;
@@ -86,22 +80,31 @@ function checkBrackets(evt) {
     }

     if(counterElt.title != '') {
-        counterElt.style = 'color: #FF5555;';
+        counterElt.classList.add('error');
     } else {
-        counterElt.style = '';
+        counterElt.classList.remove('error');
     }
 }

+function setupBracketChecking(id_prompt, id_counter){
+    var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
+    var counter = gradioApp().getElementById(id_counter)
+    textarea.addEventListener("input", function(evt){
+        checkBrackets(evt, textarea, counter)
+    });
+}
+
 var shadowRootLoaded = setInterval(function() {
-    var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
-    if(shadowTextArea.length < 1) {
-        return false;
-    }
+    var shadowRoot = document.querySelector('gradio-app').shadowRoot;
+    if(! shadowRoot) return false;
+
+    var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
+    if(shadowTextArea.length < 1) return false;

     clearInterval(shadowRootLoaded);

-    document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
-    document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
-    document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
-    document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
+    setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
+    setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
+    setupBracketChecking('img2img_prompt', 'img2img_token_counter')
+    setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
 }, 1000);
<div class='card' {preview_html} onclick={card_clicked}>
    <div class='actions'>
        <div class='additional'>
            <ul>
                <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
            </ul>
        </div>
        <span class='name'>{name}</span>
    </div>
</div>
<div class='nocards'>
    <h1>Nothing here. Add some content to the following directories:</h1>
    <ul>
        {dirs}
    </ul>
</div>
<div>
    <a href="/docs">API</a>
     • 
    <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
     • 
    <a href="https://gradio.app">Gradio</a>
     • 
    <a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
</div>
<br />
<div class="versions">
{versions}
</div>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
    <filter id='shadow' color-interpolation-filters="sRGB">
        <feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
        <feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
    </filter>
    <path style="filter:url(#shadow);" fill="#FFFFFF" d="M13.18 19C13.35 19.72 13.64 20.39 14.03 21H5C3.9 21 3 20.11 3 19V5C3 3.9 3.9 3 5 3H19C20.11 3 21 3.9 21 5V11.18C20.5 11.07 20 11 19.5 11C19.33 11 19.17 11 19 11.03V5H5V19H13.18M11.21 15.83L9.25 13.47L6.5 17H13.03C13.14 15.54 13.73 14.22 14.64 13.19L13.96 12.29L11.21 15.83M19 13.5V12L16.75 14.25L19 16.5V15C20.38 15 21.5 16.12 21.5 17.5C21.5 17.9 21.41 18.28 21.24 18.62L22.33 19.71C22.75 19.08 23 18.32 23 17.5C23 15.29 21.21 13.5 19 13.5M19 20C17.62 20 16.5 18.88 16.5 17.5C16.5 17.1 16.59 16.72 16.76 16.38L15.67 15.29C15.25 15.92 15 16.68 15 17.5C15 19.71 16.79 21.5 19 21.5V23L21.25 20.75L19 18.5V20Z" />
</svg>
@@ -21,12 +21,17 @@ function dimensionChange(e, is_width, is_height){
     var targetElement = null;

     var tabIndex = get_tab_index('mode_img2img')
-    if(tabIndex == 0){
+    if(tabIndex == 0){ // img2img
         targetElement = gradioApp().querySelector('div[data-testid=image] img');
-    } else if(tabIndex == 1){
+    } else if(tabIndex == 1){ //Sketch
+        targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
+    } else if(tabIndex == 2){ // Inpaint
         targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
+    } else if(tabIndex == 3){ // Inpaint sketch
+        targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
     }

     if(targetElement){
         var arPreviewRect = gradioApp().querySelector('#imageARPreview');
@@ -9,11 +9,19 @@ function dropReplaceImage( imgWrap, files ) {
         return;
     }

+    const tmpFile = files[0];
+
     imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
     const callback = () => {
         const fileInput = imgWrap.querySelector('input[type="file"]');
         if ( fileInput ) {
-            fileInput.files = files;
+            if ( files.length === 0 ) {
+                files = new DataTransfer();
+                files.items.add(tmpFile);
+                fileInput.files = files.files;
+            } else {
+                fileInput.files = files;
+            }
             fileInput.dispatchEvent(new Event('change'));
         }
     };
-addEventListener('keydown', (event) => {
-    let target = event.originalTarget || event.composedPath()[0];
-    if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
-    if (! (event.metaKey || event.ctrlKey)) return;
-
-    let plus = "ArrowUp"
-    let minus = "ArrowDown"
-    if (event.key != plus && event.key != minus) return;
-
-    let selectionStart = target.selectionStart;
-    let selectionEnd = target.selectionEnd;
-    // If the user hasn't selected anything, let's select their current parenthesis block
-    if (selectionStart === selectionEnd) {
-        // Find opening parenthesis around current cursor
-        const before = target.value.substring(0, selectionStart);
-        let beforeParen = before.lastIndexOf("(");
-        if (beforeParen == -1) return;
-        let beforeParenClose = before.lastIndexOf(")");
-        while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
-            beforeParen = before.lastIndexOf("(", beforeParen - 1);
-            beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
-        }
-
-        // Find closing parenthesis around current cursor
-        const after = target.value.substring(selectionStart);
-        let afterParen = after.indexOf(")");
-        if (afterParen == -1) return;
-        let afterParenOpen = after.indexOf("(");
-        while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
-            afterParen = after.indexOf(")", afterParen + 1);
-            afterParenOpen = after.indexOf("(", afterParenOpen + 1);
-        }
-        if (beforeParen === -1 || afterParen === -1) return;
-
-        // Set the selection to the text between the parenthesis
-        const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
-        const lastColon = parenContent.lastIndexOf(":");
-        selectionStart = beforeParen + 1;
-        selectionEnd = selectionStart + lastColon;
-        target.setSelectionRange(selectionStart, selectionEnd);
-    }
-
-    event.preventDefault();
-
-    if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
-        target.value = target.value.slice(0, selectionStart) +
-            "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
-            target.value.slice(selectionEnd);
-
-        target.focus();
-        target.selectionStart = selectionStart + 1;
-        target.selectionEnd = selectionEnd + 1;
-
-    } else {
-        end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
-        weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
-        if (isNaN(weight)) return;
-        if (event.key == minus) weight -= 0.1;
-        if (event.key == plus) weight += 0.1;
-
-        weight = parseFloat(weight.toPrecision(12));
-
-        target.value = target.value.slice(0, selectionEnd + 1) +
-            weight +
-            target.value.slice(selectionEnd + 1 + end - 1);
-
-        target.focus();
-        target.selectionStart = selectionStart;
-        target.selectionEnd = selectionEnd;
-    }
-    // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
-    // internal Svelte data binding remains in sync.
-    target.dispatchEvent(new Event("input", { bubbles: true }));
-});
+function keyupEditAttention(event){
+    let target = event.originalTarget || event.composedPath()[0];
+    if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
+    if (! (event.metaKey || event.ctrlKey)) return;
+
+    let isPlus = event.key == "ArrowUp"
+    let isMinus = event.key == "ArrowDown"
+    if (!isPlus && !isMinus) return;
+
+    let selectionStart = target.selectionStart;
+    let selectionEnd = target.selectionEnd;
+    let text = target.value;
+
+    function selectCurrentParenthesisBlock(OPEN, CLOSE){
+        if (selectionStart !== selectionEnd) return false;
+
+        // Find opening parenthesis around current cursor
+        const before = text.substring(0, selectionStart);
+        let beforeParen = before.lastIndexOf(OPEN);
+        if (beforeParen == -1) return false;
+        let beforeParenClose = before.lastIndexOf(CLOSE);
+        while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
+            beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
+            beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
+        }
+
+        // Find closing parenthesis around current cursor
+        const after = text.substring(selectionStart);
+        let afterParen = after.indexOf(CLOSE);
+        if (afterParen == -1) return false;
+        let afterParenOpen = after.indexOf(OPEN);
+        while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
+            afterParen = after.indexOf(CLOSE, afterParen + 1);
+            afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
+        }
+        if (beforeParen === -1 || afterParen === -1) return false;
+
+        // Set the selection to the text between the parenthesis
+        const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
+        const lastColon = parenContent.lastIndexOf(":");
+        selectionStart = beforeParen + 1;
+        selectionEnd = selectionStart + lastColon;
+        target.setSelectionRange(selectionStart, selectionEnd);
+        return true;
+    }
+
+    // If the user hasn't selected anything, let's select their current parenthesis block
+    if(! selectCurrentParenthesisBlock('<', '>')){
+        selectCurrentParenthesisBlock('(', ')')
+    }
+
+    event.preventDefault();
+
+    closeCharacter = ')'
+    delta = opts.keyedit_precision_attention
+
+    if (selectionStart > 0 && text[selectionStart - 1] == '<'){
+        closeCharacter = '>'
+        delta = opts.keyedit_precision_extra
+    } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
+
+        // do not include spaces at the end
+        while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
+            selectionEnd -= 1;
+        }
+        if(selectionStart == selectionEnd){
+            return
+        }
+
+        text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
+
+        selectionStart += 1;
+        selectionEnd += 1;
+    }
+
+    end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
+    weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
+    if (isNaN(weight)) return;
+
+    weight += isPlus ? delta : -delta;
+    weight = parseFloat(weight.toPrecision(12));
+    if(String(weight).length == 1) weight += ".0"
+
+    text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+
+    target.focus();
+    target.value = text;
+    target.selectionStart = selectionStart;
+    target.selectionEnd = selectionEnd;
+
+    updateInput(target)
+}
+
+addEventListener('keydown', (event) => {
+    keyupEditAttention(event);
+});
\ No newline at end of file
@@ -29,7 +29,7 @@ function install_extension_from_index(button, url){
     textarea = gradioApp().querySelector('#extension_to_install textarea')
     textarea.value = url
-    textarea.dispatchEvent(new Event("input", { bubbles: true }))
+    updateInput(textarea)

     gradioApp().querySelector('#install_extension_button').click()
 }
function setupExtraNetworksForTab(tabname){
    gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')

    var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
    var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
    var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
    var close = gradioApp().getElementById(tabname+'_extra_close')

    search.classList.add('search')
    tabs.appendChild(search)
    tabs.appendChild(refresh)
    tabs.appendChild(close)

    search.addEventListener("input", function(evt){
        searchTerm = search.value.toLowerCase()

        gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
            text = elem.querySelector('.name').textContent.toLowerCase()
            elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
        })
    });
}

var activePromptTextarea = {};

function setupExtraNetworks(){
    setupExtraNetworksForTab('txt2img')
    setupExtraNetworksForTab('img2img')

    function registerPrompt(tabname, id){
        var textarea = gradioApp().querySelector("#" + id + " > label > textarea");

        if (! activePromptTextarea[tabname]){
            activePromptTextarea[tabname] = textarea
        }

        textarea.addEventListener("focus", function(){
            activePromptTextarea[tabname] = textarea;
        });
    }

    registerPrompt('txt2img', 'txt2img_prompt')
    registerPrompt('txt2img', 'txt2img_neg_prompt')
    registerPrompt('img2img', 'img2img_prompt')
    registerPrompt('img2img', 'img2img_neg_prompt')
}

onUiLoaded(setupExtraNetworks)

function cardClicked(tabname, textToAdd, allowNegativePrompt){
    var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")

    textarea.value = textarea.value + " " + textToAdd
    updateInput(textarea)
}

function saveCardPreview(event, tabname, filename){
    var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
    var button = gradioApp().getElementById(tabname + '_save_preview')

    textarea.value = filename
    updateInput(textarea)

    button.click()

    event.stopPropagation()
    event.preventDefault()
}
@@ -4,7 +4,7 @@ titles = {
     "Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
     "Sampling method": "Which algorithm to use to produce the image",
     "GFPGAN": "Restore low quality faces using GFPGAN neural network",
-    "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
+    "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
     "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
     "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
@@ -14,12 +14,14 @@ titles = {
     "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
     "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
     "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
-    "\u{1f3a8}": "Add a random artist to the prompt.",
     "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
     "\u{1f4c2}": "Open images output directory",
     "\u{1f4be}": "Save style",
     "\U0001F5D1": "Clear prompt",
     "\u{1f4cb}": "Apply selected styles to current prompt",
+    "\u{1f4d2}": "Paste available values into the field",
+    "\u{1f3b4}": "Show extra networks",

     "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
     "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
@@ -48,7 +50,7 @@ titles = {
     "None": "Do not do anything special",
     "Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)",
-    "X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
+    "X/Y/Z plot": "Create grid(s) where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
     "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",

     "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
@@ -74,16 +76,13 @@ titles = {
     "Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
     "Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
     "Apply style": "Insert selected styles into prompt fields",
-    "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
+    "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style uses that as a placeholder for your prompt when you use the style in the future.",

     "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
     "Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
     "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",

-    "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
-    "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
-
     "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
     "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
@@ -94,13 +93,24 @@ titles = {
     "Weighted sum": "Result = A * (1 - M) + B * M",
     "Add difference": "Result = A + (B - C) * M",
+    "No interpolation": "Result = A",

-    "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n  rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
+    "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
+    "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n  rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.", "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
"Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.", "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resolution and lower quality.",
"Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality." "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resolution and extremely low quality.",
"Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
"Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.",
"Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.",
"Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.",
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited."
} }
......
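The learning-rate schedule syntax described in the "Learning rate" tooltip above ("rate_1:max_steps_1, rate_2:max_steps_2, ...") is simple enough to sketch in a few lines. This is a minimal illustration of the format, not the trainer's actual parser, and the helper name is invented:

```
def parse_lr_schedule(text):
    # "0.005:100, 1e-3:1000, 1e-5" -> [(0.005, 100), (0.001, 1000), (1e-05, None)]
    # A rate without a step count applies to all remaining steps.
    pairs = []
    for part in text.split(","):
        part = part.strip()
        if ":" in part:
            rate, steps = part.split(":")
            pairs.append((float(rate), int(steps)))
        else:
            pairs.append((float(part), None))
    return pairs
```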
function setInactive(elem, inactive){
if(inactive){
elem.classList.add('inactive')
} else{
elem.classList.remove('inactive')
}
}
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
}
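The precedence that onCalcResolutionHires reflects in the UI (explicit resize values win over the multiplier, and a single zero is inferred from the other slider to keep aspect ratio) can be summarized in a short sketch; the function name is illustrative and not part of the codebase:

```
def hires_target_size(width, height, hr_scale, hr_resize_x, hr_resize_y):
    # "Upscale by" only applies when both explicit resize fields are zero.
    if hr_resize_x == 0 and hr_resize_y == 0:
        return int(width * hr_scale), int(height * hr_scale)
    if hr_resize_x == 0:
        # width inferred from the requested height, keeping aspect ratio
        return int(width * hr_resize_y / height), hr_resize_y
    if hr_resize_y == 0:
        return hr_resize_x, int(height * hr_resize_x / width)
    return hr_resize_x, hr_resize_y
```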
...@@ -148,9 +148,18 @@ function showGalleryImage() { ...@@ -148,9 +148,18 @@ function showGalleryImage() {
if(e && e.parentElement.tagName == 'DIV'){ if(e && e.parentElement.tagName == 'DIV'){
e.style.cursor='pointer' e.style.cursor='pointer'
e.style.userSelect='none' e.style.userSelect='none'
e.addEventListener('click', function (evt) {
if(!opts.js_modal_lightbox) return; var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
// For Firefox, listening on click first switches to the next image, then shows the lightbox.
// If you know how to fix this without switching to the mousedown event, please do.
// For other browsers the event is click, to make it possible to drag the picture.
var event = isFirefox ? 'mousedown' : 'click'
e.addEventListener(event, function (evt) {
if(!opts.js_modal_lightbox || evt.button != 0) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
evt.preventDefault()
showModal(evt) showModal(evt)
}, true); }, true);
} }
......
...@@ -10,10 +10,8 @@ ignore_ids_for_localization={ ...@@ -10,10 +10,8 @@ ignore_ids_for_localization={
modelmerger_tertiary_model_name: 'OPTION', modelmerger_tertiary_model_name: 'OPTION',
train_embedding: 'OPTION', train_embedding: 'OPTION',
train_hypernetwork: 'OPTION', train_hypernetwork: 'OPTION',
txt2img_style_index: 'OPTION', txt2img_styles: 'OPTION',
txt2img_style2_index: 'OPTION', img2img_styles: 'OPTION',
img2img_style_index: 'OPTION',
img2img_style2_index: 'OPTION',
setting_random_artist_categories: 'SPAN', setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN', setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN', setting_realesrgan_enabled_models: 'SPAN',
......
function start_training_textual_inversion(){ function start_training_textual_inversion(){
requestProgress('ti')
gradioApp().querySelector('#ti_error').innerHTML='' gradioApp().querySelector('#ti_error').innerHTML=''
return args_to_array(arguments) var id = randomId()
requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
})
var res = args_to_array(arguments)
res[0] = id
return res
} }
// various functions for interation with ui.py not large enough to warrant putting them in separate files // various functions for interaction with ui.py not large enough to warrant putting them in separate files
function set_theme(theme){ function set_theme(theme){
gradioURL = window.location.href gradioURL = window.location.href
...@@ -19,7 +19,7 @@ function selected_gallery_index(){ ...@@ -19,7 +19,7 @@ function selected_gallery_index(){
function extract_image_from_gallery(gallery){ function extract_image_from_gallery(gallery){
if(gallery.length == 1){ if(gallery.length == 1){
return gallery[0] return [gallery[0]]
} }
index = selected_gallery_index() index = selected_gallery_index()
...@@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){ ...@@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){
return [null] return [null]
} }
return gallery[index]; return [gallery[index]];
} }
function args_to_array(args){ function args_to_array(args){
...@@ -45,16 +45,33 @@ function switch_to_txt2img(){ ...@@ -45,16 +45,33 @@ function switch_to_txt2img(){
return args_to_array(arguments); return args_to_array(arguments);
} }
function switch_to_img2img(){ function switch_to_img2img_tab(no){
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click(); gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click(); gradioApp().getElementById('mode_img2img').querySelectorAll('button')[no].click();
}
function switch_to_img2img(){
switch_to_img2img_tab(0);
return args_to_array(arguments);
}
function switch_to_sketch(){
switch_to_img2img_tab(1);
return args_to_array(arguments);
}
function switch_to_inpaint(){
switch_to_img2img_tab(2);
return args_to_array(arguments);
}
function switch_to_inpaint_sketch(){
switch_to_img2img_tab(3);
return args_to_array(arguments); return args_to_array(arguments);
} }
function switch_to_inpaint(){ function switch_to_inpaint(){
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click(); gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click(); gradioApp().getElementById('mode_img2img').querySelectorAll('button')[2].click();
return args_to_array(arguments); return args_to_array(arguments);
} }
...@@ -87,9 +104,11 @@ function create_tab_index_args(tabId, args){ ...@@ -87,9 +104,11 @@ function create_tab_index_args(tabId, args){
return res return res
} }
function get_extras_tab_index(){ function get_img2img_tab_index() {
const [,,...args] = [...arguments] let res = args_to_array(arguments)
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args] res.splice(-2)
res[0] = get_tab_index('mode_img2img')
return res
} }
function create_submit_args(args){ function create_submit_args(args){
...@@ -109,22 +128,54 @@ function create_submit_args(args){ ...@@ -109,22 +128,54 @@ function create_submit_args(args){
return res return res
} }
function showSubmitButtons(tabname, show){
gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block"
gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
}
function submit(){ function submit(){
requestProgress('txt2img') rememberGallerySelection('txt2img_gallery')
showSubmitButtons('txt2img', false)
var id = randomId()
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
showSubmitButtons('txt2img', true)
})
var res = create_submit_args(arguments)
res[0] = id
return create_submit_args(arguments) return res
} }
function submit_img2img(){ function submit_img2img(){
requestProgress('img2img') rememberGallerySelection('img2img_gallery')
showSubmitButtons('img2img', false)
res = create_submit_args(arguments) var id = randomId()
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
showSubmitButtons('img2img', true)
})
res[0] = get_tab_index('mode_img2img') var res = create_submit_args(arguments)
res[0] = id
res[1] = get_tab_index('mode_img2img')
return res return res
} }
function modelmerger(){
var id = randomId()
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
var res = create_submit_args(arguments)
res[0] = id
return res
}
function ask_for_style_name(_, prompt_text, negative_prompt_text) { function ask_for_style_name(_, prompt_text, negative_prompt_text) {
name_ = prompt('Style name:') name_ = prompt('Style name:')
...@@ -140,27 +191,17 @@ function confirm_clear_prompt(prompt, negative_prompt) { ...@@ -140,27 +191,17 @@ function confirm_clear_prompt(prompt, negative_prompt) {
return [prompt, negative_prompt] return [prompt, negative_prompt]
} }
opts = {} opts = {}
function apply_settings(jsdata){
console.log(jsdata)
opts = JSON.parse(jsdata)
return jsdata
}
onUiUpdate(function(){ onUiUpdate(function(){
if(Object.keys(opts).length != 0) return; if(Object.keys(opts).length != 0) return;
json_elem = gradioApp().getElementById('settings_json') json_elem = gradioApp().getElementById('settings_json')
if(json_elem == null) return; if(json_elem == null) return;
textarea = json_elem.querySelector('textarea') var textarea = json_elem.querySelector('textarea')
jsdata = textarea.value var jsdata = textarea.value
opts = JSON.parse(jsdata) opts = JSON.parse(jsdata)
executeCallbacks(optionsChangedCallbacks);
Object.defineProperty(textarea, 'value', { Object.defineProperty(textarea, 'value', {
set: function(newValue) { set: function(newValue) {
...@@ -171,6 +212,8 @@ onUiUpdate(function(){ ...@@ -171,6 +212,8 @@ onUiUpdate(function(){
if (oldValue != newValue) { if (oldValue != newValue) {
opts = JSON.parse(textarea.value) opts = JSON.parse(textarea.value)
} }
executeCallbacks(optionsChangedCallbacks);
}, },
get: function() { get: function() {
var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value'); var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
...@@ -180,13 +223,51 @@ onUiUpdate(function(){ ...@@ -180,13 +223,51 @@ onUiUpdate(function(){
json_elem.parentElement.style.display="none" json_elem.parentElement.style.display="none"
if (!txt2img_textarea) { function registerTextarea(id, id_counter, id_button){
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea"); var prompt = gradioApp().getElementById(id)
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button")); var counter = gradioApp().getElementById(id_counter)
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
if(counter.parentElement == prompt.parentElement){
return
}
prompt.parentElement.insertBefore(counter, prompt)
counter.classList.add("token-counter")
prompt.parentElement.style.position = "relative"
textarea.addEventListener("input", function(){
update_token_counter(id_button);
});
}
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button')
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
settings_tabs = gradioApp().querySelector('#settings div')
if(show_all_pages && settings_tabs){
settings_tabs.appendChild(show_all_pages)
show_all_pages.onclick = function(){
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
elem.style.display = "block";
})
} }
if (!img2img_textarea) { }
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); })
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
onOptionsChanged(function(){
elem = gradioApp().getElementById('sd_checkpoint_hash')
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
shorthash = sd_checkpoint_hash.substr(0,10)
if(elem && elem.textContent != shorthash){
elem.textContent = shorthash
elem.title = sd_checkpoint_hash
elem.href = "https://google.com/search?q=" + sd_checkpoint_hash
} }
}) })
...@@ -220,3 +301,11 @@ function restart_reload(){ ...@@ -220,3 +301,11 @@ function restart_reload(){
return [] return []
} }
// Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits
// will only be visible on the web page and not sent to python.
function updateInput(target){
let e = new Event("input", { bubbles: true })
Object.defineProperty(e, "target", {value: target})
target.dispatchEvent(e);
}
...@@ -13,6 +13,53 @@ dir_extensions = "extensions" ...@@ -13,6 +13,53 @@ dir_extensions = "extensions"
python = sys.executable python = sys.executable
git = os.environ.get('GIT', "git") git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "") index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
skip_install = False
def check_python_version():
is_windows = platform.system() == "Windows"
major = sys.version_info.major
minor = sys.version_info.minor
micro = sys.version_info.micro
if is_windows:
supported_minors = [10]
else:
supported_minors = [7, 8, 9, 10, 11]
if not (major == 3 and minor in supported_minors):
import modules.errors
modules.errors.print_error_explanation(f"""
INCOMPATIBLE PYTHON VERSION
This program is tested with Python 3.10.6, but you have {major}.{minor}.{micro}.
If you encounter an error with a "RuntimeError: Couldn't install torch." message,
or any other error regarding an unsuccessful package (library) installation,
please downgrade (or upgrade) to the latest version of Python 3.10,
and delete the current Python and "venv" folder in WebUI's directory.
You can download Python 3.10 from here: https://www.python.org/downloads/release/python-3109/
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
Use --skip-python-version-check to suppress this warning.
""")
def commit_hash():
global stored_commit_hash
if stored_commit_hash is not None:
return stored_commit_hash
try:
stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
except Exception:
stored_commit_hash = "<none>"
return stored_commit_hash
def extract_arg(args, name): def extract_arg(args, name):
...@@ -32,10 +79,19 @@ def extract_opt(args, name): ...@@ -32,10 +79,19 @@ def extract_opt(args, name):
return args, is_present, opt return args, is_present, opt
def run(command, desc=None, errdesc=None, custom_env=None): def run(command, desc=None, errdesc=None, custom_env=None, live=False):
if desc is not None: if desc is not None:
print(desc) print(desc)
if live:
result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
if result.returncode != 0:
raise RuntimeError(f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {result.returncode}""")
return ""
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env) result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
if result.returncode != 0: if result.returncode != 0:
...@@ -74,6 +130,9 @@ def run_python(code, desc=None, errdesc=None): ...@@ -74,6 +130,9 @@ def run_python(code, desc=None, errdesc=None):
def run_pip(args, desc=None): def run_pip(args, desc=None):
if skip_install:
return
index_url_line = f' --index-url {index_url}' if index_url != '' else '' index_url_line = f' --index-url {index_url}' if index_url != '' else ''
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
...@@ -89,18 +148,18 @@ def git_clone(url, dir, name, commithash=None): ...@@ -89,18 +148,18 @@ def git_clone(url, dir, name, commithash=None):
if commithash is None: if commithash is None:
return return
current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
if current_hash == commithash: if current_hash == commithash:
return return
run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
return return
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
if commithash is not None: if commithash is not None:
run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
def version_check(commit): def version_check(commit):
...@@ -158,7 +217,9 @@ def run_extensions_installers(settings_file): ...@@ -158,7 +217,9 @@ def run_extensions_installers(settings_file):
def prepare_environment(): def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") global skip_install
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
...@@ -166,8 +227,6 @@ def prepare_environment(): ...@@ -166,8 +227,6 @@ def prepare_environment():
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
...@@ -188,22 +247,25 @@ def prepare_environment(): ...@@ -188,22 +247,25 @@ def prepare_environment():
sys.argv, _ = extract_arg(sys.argv, '-f') sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests') sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
xformers = '--xformers' in sys.argv xformers = '--xformers' in sys.argv
ngrok = '--ngrok' in sys.argv ngrok = '--ngrok' in sys.argv
try: if not skip_python_version_check:
commit = run(f"{git} rev-parse HEAD").strip() check_python_version()
except Exception:
commit = "<none>" commit = commit_hash()
print(f"Python {sys.version}") print(f"Python {sys.version}")
print(f"Commit hash: {commit}") print(f"Commit hash: {commit}")
if not is_installed("torch") or not is_installed("torchvision"): if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
if not skip_torch_cuda_test: if not skip_torch_cuda_test:
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'") run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
...@@ -220,14 +282,14 @@ def prepare_environment(): ...@@ -220,14 +282,14 @@ def prepare_environment():
if (not is_installed("xformers") or reinstall_xformers) and xformers: if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows": if platform.system() == "Windows":
if platform.python_version().startswith("3.10"): if platform.python_version().startswith("3.10"):
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers") run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers")
else: else:
print("Installation of xformers is not supported in this version of Python.") print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness") print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
if not is_installed("xformers"): if not is_installed("xformers"):
exit(0) exit(0)
elif platform.system() == "Linux": elif platform.system() == "Linux":
run_pip("install xformers", "xformers") run_pip("install xformers==0.0.16rc425", "xformers")
if not is_installed("pyngrok") and ngrok: if not is_installed("pyngrok") and ngrok:
run_pip("install pyngrok", "ngrok") run_pip("install pyngrok", "ngrok")
...@@ -267,9 +329,12 @@ def tests(test_dir): ...@@ -267,9 +329,12 @@ def tests(test_dir):
sys.argv.append("./test/test_files/empty.pt") sys.argv.append("./test/test_files/empty.pt")
if "--skip-torch-cuda-test" not in sys.argv: if "--skip-torch-cuda-test" not in sys.argv:
sys.argv.append("--skip-torch-cuda-test") sys.argv.append("--skip-torch-cuda-test")
if "--disable-nan-check" not in sys.argv:
sys.argv.append("--disable-nan-check")
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}") print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
os.environ['COMMANDLINE_ARGS'] = ""
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr: with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr) proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
......
...@@ -100,13 +100,13 @@ class PydanticModelGenerator: ...@@ -100,13 +100,13 @@ class PydanticModelGenerator:
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img", "StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img, StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}] [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model() ).generate_model()
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img", "StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img, StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}] [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model() ).generate_model()
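With script_name and script_args exposed on both processing models, an API caller can route a generation through a script. A hedged example against a local instance; the URL, script name, and argument list are assumptions to adapt to your setup:

```
import requests

payload = {
    "prompt": "a photo of a cat",
    "steps": 20,
    "script_name": "prompt matrix",   # must match an installed script
    "script_args": [],                # positional arguments for that script
}
r = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
r.raise_for_status()
images = r.json()["images"]          # base64-encoded results
```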
class TextToImageResponse(BaseModel): class TextToImageResponse(BaseModel):
...@@ -125,7 +125,7 @@ class ExtrasBaseRequest(BaseModel): ...@@ -125,7 +125,7 @@ class ExtrasBaseRequest(BaseModel):
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
...@@ -157,7 +157,8 @@ class PNGInfoRequest(BaseModel): ...@@ -157,7 +157,8 @@ class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image") image: str = Field(title="Image", description="The base64 encoded PNG image")
class PNGInfoResponse(BaseModel): class PNGInfoResponse(BaseModel):
info: str = Field(title="Image info", description="A string with all the info the image had") info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
items: dict = Field(title="Items", description="An object containing all the info the image had")
class ProgressRequest(BaseModel): class ProgressRequest(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization") skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
...@@ -167,6 +168,7 @@ class ProgressResponse(BaseModel): ...@@ -167,6 +168,7 @@ class ProgressResponse(BaseModel):
eta_relative: float = Field(title="ETA in secs") eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot") state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
class InterrogateRequest(BaseModel): class InterrogateRequest(BaseModel):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
...@@ -218,13 +220,15 @@ class UpscalerItem(BaseModel): ...@@ -218,13 +220,15 @@ class UpscalerItem(BaseModel):
model_name: Optional[str] = Field(title="Model Name") model_name: Optional[str] = Field(title="Model Name")
model_path: Optional[str] = Field(title="Path") model_path: Optional[str] = Field(title="Path")
model_url: Optional[str] = Field(title="URL") model_url: Optional[str] = Field(title="URL")
scale: Optional[float] = Field(title="Scale")
class SDModelItem(BaseModel): class SDModelItem(BaseModel):
title: str = Field(title="Title") title: str = Field(title="Title")
model_name: str = Field(title="Model Name") model_name: str = Field(title="Model Name")
hash: str = Field(title="Hash") hash: Optional[str] = Field(title="Short hash")
sha256: Optional[str] = Field(title="sha256 hash")
filename: str = Field(title="Filename") filename: str = Field(title="Filename")
config: str = Field(title="Config file") config: Optional[str] = Field(title="Config file")
class HypernetworkItem(BaseModel): class HypernetworkItem(BaseModel):
name: str = Field(title="Name") name: str = Field(title="Name")
...@@ -249,3 +253,17 @@ class ArtistItem(BaseModel): ...@@ -249,3 +253,17 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score") score: float = Field(title="Score")
category: str = Field(title="Category") category: str = Field(title="Category")
class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class EmbeddingsResponse(BaseModel):
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
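A short sketch of consuming the new embeddings models over HTTP; the /sdapi/v1/embeddings route serving EmbeddingsResponse is assumed here, as the route registration is not part of this hunk:

```
import requests

data = requests.get("http://127.0.0.1:7860/sdapi/v1/embeddings").json()
for name, item in data["loaded"].items():
    # fields come straight from EmbeddingItem above
    print(f"{name}: {item['vectors']} vector(s) of length {item['shape']}")
print("skipped:", sorted(data["skipped"]))
```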
import os.path
import csv
from collections import namedtuple
Artist = namedtuple("Artist", ['name', 'weight', 'category'])
class ArtistsDatabase:
def __init__(self, filename):
self.cats = set()
self.artists = []
if not os.path.exists(filename):
return
with open(filename, "r", newline='', encoding="utf8") as file:
reader = csv.DictReader(file)
for row in reader:
artist = Artist(row["artist"], float(row["score"]), row["category"])
self.artists.append(artist)
self.cats.add(artist.category)
def categories(self):
return sorted(self.cats)
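Usage is straightforward, assuming a CSV with artist, score and category columns; the file name here is illustrative:

```
db = ArtistsDatabase("artists.csv")
print(db.categories())                            # sorted unique categories
heaviest = max(db.artists, key=lambda a: a.weight, default=None)
print(heaviest.name if heaviest else "no artists loaded")
```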
...@@ -4,7 +4,7 @@ import threading ...@@ -4,7 +4,7 @@ import threading
import traceback import traceback
import time import time
from modules import shared from modules import shared, progress
queue_lock = threading.Lock() queue_lock = threading.Lock()
...@@ -22,10 +22,21 @@ def wrap_queued_call(func): ...@@ -22,10 +22,21 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None): def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs): def f(*args, **kwargs):
shared.state.begin() # if the first argument is a string that says "task(...)", it is treated as a job id
if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
id_task = args[0]
progress.add_task_to_queue(id_task)
else:
id_task = None
with queue_lock: with queue_lock:
shared.state.begin()
progress.start_task(id_task)
try:
res = func(*args, **kwargs) res = func(*args, **kwargs)
finally:
progress.finish_task(id_task)
shared.state.end() shared.state.end()
......
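The job-id convention used by wrap_gradio_gpu_call is purely by shape: the front end swaps a string of the form task(...) into the first Gradio argument, and the wrapper recognizes it. A minimal sketch of both sides; the helper names are illustrative, and the id format mirrors the JS randomId(), which is not shown in this diff:

```
import random
import string

def random_task_id():
    # assumed format: "task(" + random alphanumerics + ")"
    chars = string.ascii_lowercase + string.digits
    return "task(" + "".join(random.choice(chars) for _ in range(15)) + ")"

def looks_like_task_id(arg):
    # the same test wrap_gradio_gpu_call performs on args[0]
    return type(arg) == str and arg[0:5] == "task(" and arg[-1] == ")"
```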
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import modules.face_restoration import modules.face_restoration
import modules.shared import modules.shared
from modules import shared, devices, modelloader from modules import shared, devices, modelloader
from modules.paths import script_path, models_path from modules.paths import models_path
# codeformer people made a choice to include modified basicsr library to their project which makes # codeformer people made a choice to include modified basicsr library to their project which makes
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
......
...@@ -2,6 +2,8 @@ import torch ...@@ -2,6 +2,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from modules import devices
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
...@@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module): ...@@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module):
t_358, = inputs t_358, = inputs
t_359 = t_358.permute(*[0, 3, 1, 2]) t_359 = t_358.permute(*[0, 3, 1, 2])
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0) t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
t_360 = self.n_Conv_0(t_359_padded) t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
t_361 = F.relu(t_360) t_361 = F.relu(t_360)
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf')) t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
t_362 = self.n_MaxPool_0(t_361) t_362 = self.n_MaxPool_0(t_361)
......
...@@ -34,14 +34,18 @@ def get_cuda_device_string(): ...@@ -34,14 +34,18 @@ def get_cuda_device_string():
return "cuda" return "cuda"
def get_optimal_device(): def get_optimal_device_name():
if torch.cuda.is_available(): if torch.cuda.is_available():
return torch.device(get_cuda_device_string()) return get_cuda_device_string()
if has_mps(): if has_mps():
return torch.device("mps") return "mps"
return "cpu"
return cpu
def get_optimal_device():
return torch.device(get_optimal_device_name())
def get_device_for(task): def get_device_for(task):
...@@ -79,6 +83,8 @@ cpu = torch.device("cpu") ...@@ -79,6 +83,8 @@ cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16 dtype = torch.float16
dtype_vae = torch.float16 dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False
def randn(seed, shape): def randn(seed, shape):
...@@ -106,6 +112,42 @@ def autocast(disable=False): ...@@ -106,6 +112,42 @@ def autocast(disable=False):
return torch.autocast("cuda") return torch.autocast("cuda")
def without_autocast(disable=False):
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
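A self-contained usage sketch for without_autocast (requires a CUDA device; values illustrative): inside an autocast region, the wrapped block falls back to full precision:

```
import torch

with torch.autocast("cuda"):
    a = torch.ones(64, 64, device="cuda")
    with without_autocast():
        # this matmul runs in float32 even though the outer region
        # would normally autocast it to float16
        b = a @ a
        assert b.dtype == torch.float32
```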
class NansException(Exception):
pass
def test_for_nans(x, where):
from modules import shared
if shared.cmd_opts.disable_nan_check:
return
if not torch.all(torch.isnan(x)).item():
return
if where == "unet":
message = "A tensor with all NaNs was produced in Unet."
if not shared.cmd_opts.no_half:
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
elif where == "vae":
message = "A tensor with all NaNs was produced in VAE."
if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
else:
message = "A tensor with all NaNs was produced."
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message)
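Note the check only fires when every element is NaN. A quick way to see the message it produces, assuming the check has not been disabled via --disable-nan-check:

```
import torch

try:
    test_for_nans(torch.full((2, 2), float("nan")), "unet")
except NansException as e:
    print(e)   # precision hints plus the --disable-nan-check pointer
```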
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to orig_tensor_to = torch.Tensor.to
def tensor_to_fix(self, *args, **kwargs): def tensor_to_fix(self, *args, **kwargs):
...@@ -133,8 +175,30 @@ def numpy_fix(self, *args, **kwargs): ...@@ -133,8 +175,30 @@ def numpy_fix(self, *args, **kwargs):
return orig_tensor_numpy(self, *args, **kwargs) return orig_tensor_numpy(self, *args, **kwargs)
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): orig_cumsum = torch.cumsum
orig_Tensor_cumsum = torch.Tensor.cumsum
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':
output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs)
if has_mps():
if version.parse(torch.__version__) < version.parse("1.13"):
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
torch.Tensor.to = tensor_to_fix torch.Tensor.to = tensor_to_fix
torch.nn.functional.layer_norm = layer_norm_fix torch.nn.functional.layer_norm = layer_norm_fix
torch.Tensor.numpy = numpy_fix torch.Tensor.numpy = numpy_fix
elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
...@@ -2,9 +2,42 @@ import sys ...@@ -2,9 +2,42 @@ import sys
import traceback import traceback
def print_error_explanation(message):
lines = message.strip().split("\n")
max_len = max([len(x) for x in lines])
print('=' * max_len, file=sys.stderr)
for line in lines:
print(line, file=sys.stderr)
print('=' * max_len, file=sys.stderr)
def display(e: Exception, task):
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
print_error_explanation("""
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
""")
already_displayed = {}
def display_once(e: Exception, task):
if task in already_displayed:
return
display(e, task)
already_displayed[task] = 1
def run(code, task): def run(code, task):
try: try:
code() code()
except Exception as e: except Exception as e:
print(f"{task}: {type(e).__name__}", file=sys.stderr) display(task, e)
print(traceback.format_exc(), file=sys.stderr)
...@@ -7,9 +7,11 @@ import git ...@@ -7,9 +7,11 @@ import git
from modules import paths, shared from modules import paths, shared
extensions = [] extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions") extensions_dir = os.path.join(paths.data_path, "extensions")
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin") extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
if not os.path.exists(extensions_dir):
os.makedirs(extensions_dir)
def active(): def active():
return [x for x in extensions if x.enabled] return [x for x in extensions if x.enabled]
......
import re
from collections import defaultdict
from modules import errors
extra_network_registry = {}
def initialize():
extra_network_registry.clear()
def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network
class ExtraNetworkParams:
def __init__(self, items=None):
self.items = items or []
class ExtraNetwork:
def __init__(self, name):
self.name = name
def activate(self, p, params_list):
"""
Called by processing on every run. Whatever the extra network is meant to do should be activated here.
Passes arguments related to this extra network in params_list.
The user passes arguments by specifying them in the prompt:
<name:arg1:arg2:arg3>
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
separated by colons.
Even if the user does not mention this ExtraNetwork in the prompt, the call will still be made, with an empty params_list -
in this case, all effects of this extra network should be disabled.
Can be called multiple times before deactivate() - each new call should override the previous call completely.
For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
> "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
params_list will be:
[
ExtraNetworkParams(items=["agm", "1.1"]),
ExtraNetworkParams(items=["ray"])
]
"""
raise NotImplementedError
def deactivate(self, p):
"""
Called at the end of processing for housekeeping. No need to do anything here.
"""
raise NotImplementedError
def activate(p, extra_network_data):
"""call activate for extra networks in extra_network_data in specified order, then call
activate for all remaining registered networks with an empty argument list"""
for extra_network_name, extra_network_args in extra_network_data.items():
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
print(f"Skipping unknown extra network: {extra_network_name}")
continue
try:
extra_network.activate(p, extra_network_args)
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
for extra_network_name, extra_network in extra_network_registry.items():
args = extra_network_data.get(extra_network_name, None)
if args is not None:
continue
try:
extra_network.activate(p, [])
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name}")
def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
deactivate for all remaining registered networks"""
for extra_network_name, extra_network_args in extra_network_data.items():
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
continue
try:
extra_network.deactivate(p)
except Exception as e:
errors.display(e, f"deactivating extra network {extra_network_name}")
for extra_network_name, extra_network in extra_network_registry.items():
args = extra_network_data.get(extra_network_name, None)
if args is not None:
continue
try:
extra_network.deactivate(p)
except Exception as e:
errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
def parse_prompt(prompt):
res = defaultdict(list)
def found(m):
name = m.group(1)
args = m.group(2)
res[name].append(ExtraNetworkParams(items=args.split(":")))
return ""
prompt = re.sub(re_extra_net, found, prompt)
return prompt, res
def parse_prompts(prompts):
res = []
extra_data = None
for prompt in prompts:
updated_prompt, parsed_extra_data = parse_prompt(prompt)
if extra_data is None:
extra_data = parsed_extra_data
res.append(updated_prompt)
return res, extra_data
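A quick round-trip of the syntax through parse_prompt, with the expected output shown in comments:

```
prompt, data = parse_prompt("1girl, <hypernet:agm:1.1> <lora:detail>")
print(prompt)   # "1girl,  " - the network tags are stripped out
print({name: [p.items for p in params] for name, params in data.items()})
# {'hypernet': [['agm', '1.1']], 'lora': [['detail']]}
```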
from modules import extra_networks
from modules.hypernetworks import hypernetwork
class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('hypernet')
def activate(self, p, params_list):
names = []
multipliers = []
for params in params_list:
assert len(params.items) > 0
names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
hypernetwork.load_hypernetworks(names, multipliers)
def deactivate(self, p):
pass
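Wiring this class up follows the registry API defined earlier; a minimal sketch, where the module path is assumed from the import style above and the exact startup location is outside this diff:

```
from modules import extra_networks
from modules.extra_networks_hypernet import ExtraNetworkHypernet

extra_networks.initialize()
extra_networks.register_extra_network(ExtraNetworkHypernet())
```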
import base64 import base64
import io import io
import math
import os import os
import re import re
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
from modules.shared import script_path from modules.paths import data_path
from modules import shared from modules import shared, ui_tempdir, script_callbacks
import tempfile import tempfile
from PIL import Image from PIL import Image
...@@ -36,9 +37,15 @@ def quote(text): ...@@ -36,9 +37,15 @@ def quote(text):
def image_from_url_text(filedata): def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]: if filedata is None:
return None
if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
filedata = filedata[0]
if type(filedata) == dict and filedata.get("is_file", False):
filename = filedata["name"] filename = filedata["name"]
is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs) is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories' assert is_in_right_dir, 'trying to open image file outside of allowed directories'
return Image.open(filename) return Image.open(filename)
...@@ -72,8 +79,6 @@ def integrate_settings_paste_fields(component_dict): ...@@ -72,8 +79,6 @@ def integrate_settings_paste_fields(component_dict):
from modules import ui from modules import ui
settings_map = { settings_map = {
'sd_hypernetwork': 'Hypernet',
'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip', 'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight', 'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash', 'sd_model_checkpoint': 'Model hash',
...@@ -93,7 +98,7 @@ def integrate_settings_paste_fields(component_dict): ...@@ -93,7 +98,7 @@ def integrate_settings_paste_fields(component_dict):
def create_buttons(tabs_list): def create_buttons(tabs_list):
buttons = {} buttons = {}
for tab in tabs_list: for tab in tabs_list:
buttons[tab] = gr.Button(f"Send to {tab}") buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
return buttons return buttons
...@@ -102,35 +107,57 @@ def bind_buttons(buttons, send_image, send_generate_info): ...@@ -102,35 +107,57 @@ def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info]) bind_list.append([buttons, send_image, send_generate_info])
def send_image_and_dimensions(x):
if isinstance(x, Image.Image):
img = x
else:
img = image_from_url_text(x)
if shared.opts.send_size and isinstance(img, Image.Image):
w = img.width
h = img.height
else:
w = gr.update()
h = gr.update()
return img, w, h
 def run_bind():
-    for buttons, send_image, send_generate_info in bind_list:
+    for buttons, source_image_component, send_generate_info in bind_list:
         for tab in buttons:
             button = buttons[tab]

-            if send_image and paste_fields[tab]["init_img"]:
-                if type(send_image) == gr.Gallery:
-                    button.click(
-                        fn=lambda x: image_from_url_text(x),
-                        _js="extract_image_from_gallery",
-                        inputs=[send_image],
-                        outputs=[paste_fields[tab]["init_img"]],
-                    )
-                else:
-                    button.click(
-                        fn=lambda x: x,
-                        inputs=[send_image],
-                        outputs=[paste_fields[tab]["init_img"]],
-                    )
+            destination_image_component = paste_fields[tab]["init_img"]
+            fields = paste_fields[tab]["fields"]
+
+            destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
+            destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
+
+            if source_image_component and destination_image_component:
+                if isinstance(source_image_component, gr.Gallery):
+                    func = send_image_and_dimensions if destination_width_component else image_from_url_text
+                    jsfunc = "extract_image_from_gallery"
+                else:
+                    func = send_image_and_dimensions if destination_width_component else lambda x: x
+                    jsfunc = None
+
+                button.click(
+                    fn=func,
+                    _js=jsfunc,
+                    inputs=[source_image_component],
+                    outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+                )

-            if send_generate_info and paste_fields[tab]["fields"] is not None:
+            if send_generate_info and fields is not None:
                 if send_generate_info in paste_fields:
-                    paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else [])
+                    paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
                     button.click(
                         fn=lambda *x: x,
                         inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
-                        outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
+                        outputs=[field for field, name in fields if name in paste_field_names],
                     )
                 else:
-                    connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
+                    connect_paste(button, fields, send_generate_info)

             button.click(
                 fn=None,
@@ -164,6 +191,39 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
     return None
+def restore_old_hires_fix_params(res):
+    """for infotexts that specify old First pass size parameter, convert it into
+    width, height, and hr scale"""
+
+    firstpass_width = res.get('First pass size-1', None)
+    firstpass_height = res.get('First pass size-2', None)
+
+    if shared.opts.use_old_hires_fix_width_height:
+        hires_width = int(res.get("Hires resize-1", 0))
+        hires_height = int(res.get("Hires resize-2", 0))
+
+        if hires_width and hires_height:
+            res['Size-1'] = hires_width
+            res['Size-2'] = hires_height
+            return
+
+    if firstpass_width is None or firstpass_height is None:
+        return
+
+    firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
+    width = int(res.get("Size-1", 512))
+    height = int(res.get("Size-2", 512))
+
+    if firstpass_width == 0 or firstpass_height == 0:
+        from modules import processing
+        firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
+
+    res['Size-1'] = firstpass_width
+    res['Size-2'] = firstpass_height
+    res['Hires resize-1'] = width
+    res['Hires resize-2'] = height
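To make the conversion concrete, here is a hedged, dependency-free sketch of the same mapping (simplified: it skips the use_old_hires_fix_width_height branch and assumes both first-pass values are non-zero):

def convert_old_hires(res: dict) -> None:
    # Old infotexts: "First pass size" = low-res pass, "Size" = final target.
    # New scheme: "Size" = first pass, "Hires resize" = final target.
    firstpass_w = int(res.pop("First pass size-1"))
    firstpass_h = int(res.pop("First pass size-2"))
    width = int(res.get("Size-1", 512))
    height = int(res.get("Size-2", 512))
    res["Size-1"], res["Size-2"] = firstpass_w, firstpass_h
    res["Hires resize-1"], res["Hires resize-2"] = width, height

params = {"Size-1": "1024", "Size-2": "1024", "First pass size-1": "512", "First pass size-2": "512"}
convert_old_hires(params)
# params == {'Size-1': 512, 'Size-2': 512, 'Hires resize-1': 1024, 'Hires resize-2': 1024}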
 def parse_generation_parameters(x: str):
     """parses generation parameters string, the one you see in text field under the picture in UI:
 ```
@@ -213,13 +273,15 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Clip skip" not in res:
         res["Clip skip"] = "1"

-    if "Hypernet strength" not in res:
-        res["Hypernet strength"] = "1"
-
-    if "Hypernet" in res:
-        hypernet_name = res["Hypernet"]
-        hypernet_hash = res.get("Hypernet hash", None)
-        res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+    hypernet = res.get("Hypernet", None)
+    if hypernet is not None:
+        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
+
+    if "Hires resize-1" not in res:
+        res["Hires resize-1"] = 0
+        res["Hires resize-2"] = 0
+
+    restore_old_hires_fix_params(res)

     return res
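The hypernet handling above no longer maps to a separate setting; it rewrites the prompt to the inline extra-networks syntax. A minimal reproduction with a hypothetical parsed result:

res = {"Prompt": "a photo of a cat", "Hypernet": "anime_3", "Hypernet strength": "0.8"}
hypernet = res.get("Hypernet", None)
if hypernet is not None:
    res["Prompt"] += f"<hypernet:{hypernet}:{res.get('Hypernet strength', '1.0')}>"
print(res["Prompt"])  # a photo of a cat<hypernet:anime_3:0.8>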
@@ -227,12 +289,13 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
 def connect_paste(button, paste_fields, input_comp, jsfunc=None):
     def paste_func(prompt):
         if not prompt and not shared.cmd_opts.hide_ui_dir_config:
-            filename = os.path.join(script_path, "params.txt")
+            filename = os.path.join(data_path, "params.txt")
             if os.path.exists(filename):
                 with open(filename, "r", encoding="utf8") as file:
                     prompt = file.read()

         params = parse_generation_parameters(prompt)
+        script_callbacks.infotext_pasted_callback(prompt, params)
         res = []

         for output, key in paste_fields:
...
@@ -6,12 +6,11 @@ import facexlib
 import gfpgan

 import modules.face_restoration
-from modules import shared, devices, modelloader
-from modules.paths import models_path
+from modules import paths, shared, devices, modelloader

 model_dir = "GFPGAN"
 user_path = None
-model_path = os.path.join(models_path, model_dir)
+model_path = os.path.join(paths.models_path, model_dir)
 model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
 have_gfpgan = False
 loaded_gfpgan_model = None
...
import hashlib
import json
import os.path

import filelock

from modules.paths import data_path


cache_filename = os.path.join(data_path, "cache.json")
cache_data = None


def dump_cache():
    with filelock.FileLock(cache_filename + ".lock"):
        with open(cache_filename, "w", encoding="utf8") as file:
            json.dump(cache_data, file, indent=4)


def cache(subsection):
    global cache_data

    if cache_data is None:
        with filelock.FileLock(cache_filename + ".lock"):
            if not os.path.isfile(cache_filename):
                cache_data = {}
            else:
                with open(cache_filename, "r", encoding="utf8") as file:
                    cache_data = json.load(file)

    s = cache_data.get(subsection, {})
    cache_data[subsection] = s

    return s


def calculate_sha256(filename):
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(blksize), b""):
            hash_sha256.update(chunk)

    return hash_sha256.hexdigest()


def sha256_from_cache(filename, title):
    hashes = cache("hashes")
    ondisk_mtime = os.path.getmtime(filename)

    if title not in hashes:
        return None

    cached_sha256 = hashes[title].get("sha256", None)
    cached_mtime = hashes[title].get("mtime", 0)

    if ondisk_mtime > cached_mtime or cached_sha256 is None:
        return None

    return cached_sha256


def sha256(filename, title):
    hashes = cache("hashes")

    sha256_value = sha256_from_cache(filename, title)
    if sha256_value is not None:
        return sha256_value

    print(f"Calculating sha256 for {filename}: ", end='')
    sha256_value = calculate_sha256(filename)
    print(f"{sha256_value}")

    hashes[title] = {
        "mtime": os.path.getmtime(filename),
        "sha256": sha256_value,
    }

    dump_cache()

    return sha256_value
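A possible usage sketch for the new module (the file path and cache title are illustrative): the first call computes and stores the digest; later calls validate the file's mtime against the cache entry and skip re-hashing.

from modules import hashes

path = "models/Stable-diffusion/example.ckpt"                  # hypothetical file
digest = hashes.sha256(path, "checkpoint/example")             # computes, then caches
cached = hashes.sha256_from_cache(path, "checkpoint/example")  # cache hit, no hashing
assert cached == digest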
@@ -9,15 +9,15 @@ from modules import devices, sd_hijack, shared

 not_available = ["hardswish", "multiheadattention"]
 keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)

-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
-    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
+    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)

     return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""

 def train_hypernetwork(*args):
-    initial_hypernetwork = shared.loaded_hypernetwork
+    shared.loaded_hypernetworks = []

     assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
@@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)}
     except Exception:
         raise
     finally:
-        shared.loaded_hypernetwork = initial_hypernetwork
         shared.sd_model.cond_stage_model.to(devices.device)
         shared.sd_model.first_stage_model.to(devices.device)
         sd_hijack.apply_optimizations()
...
@@ -39,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
     cols = math.ceil(len(imgs) / rows)

+    params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
+    script_callbacks.image_grid_callback(params)
+
     w, h = imgs[0].size
-    grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')

-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
+    for i, img in enumerate(params.imgs):
+        grid.paste(img, box=(i % params.cols * w, i // params.cols * h))

     return grid
@@ -192,7 +195,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
     ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]

-    pad_top = max(hor_text_heights) + line_spacing * 2
+    pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2

     result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
     result.paste(im, (pad_left, pad_top))
@@ -227,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
     return draw_grid_annotations(im, width, height, hor_texts, ver_texts)

-def resize_image(resize_mode, im, width, height):
+def resize_image(resize_mode, im, width, height, upscaler_name=None):
+    """
+    Resizes an image with the specified resize_mode, width, and height.
+
+    Args:
+        resize_mode: The mode to use when resizing the image.
+            0: Resize the image to the specified width and height.
+            1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+            2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+        im: The image to resize.
+        width: The width to resize the image to.
+        height: The height to resize the image to.
+        upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
+    """
+
+    upscaler_name = upscaler_name or opts.upscaler_for_img2img
+
     def resize(im, w, h):
-        if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
+        if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
             return im.resize((w, h), resample=LANCZOS)

         scale = max(w / im.width, h / im.height)

         if scale > 1.0:
-            upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
-            assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
+            upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
+            assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"

             upscaler = upscalers[0]
             im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
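A hedged usage sketch of the new parameter (assumes a running webui session where an upscaler with this name is installed): callers can now pick an upscaler per call instead of always inheriting opts.upscaler_for_img2img.

from PIL import Image
from modules import images

im = Image.new("RGB", (512, 512))
# Mode 0 = plain resize to the given dimensions; "Lanczos" is illustrative.
out = images.resize_image(0, im, 768, 768, upscaler_name="Lanczos")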
@@ -525,6 +544,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

     elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+        if image_to_save.mode == 'RGBA':
+            image_to_save = image_to_save.convert("RGB")
+
         image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)

         if opts.enable_pnginfo and info is not None:
@@ -583,6 +605,7 @@ def read_info_from_image(image):
             except ValueError:
                 exif_comment = exif_comment.decode('utf8', errors="ignore")

+        if exif_comment:
             items['exif comment'] = exif_comment
             geninfo = exif_comment
...
@@ -16,11 +16,16 @@ import modules.images as images
 import modules.scripts

-def process_batch(p, input_dir, output_dir, args):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
     processing.fix_seed(p)

     images = shared.listfiles(input_dir)

+    inpaint_masks = shared.listfiles(inpaint_mask_dir)
+    is_inpaint_batch = inpaint_mask_dir and len(inpaint_masks) > 0
+    if is_inpaint_batch:
+        print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
+
     print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")

     save_normally = output_dir == ''
@@ -43,6 +48,15 @@ def process_batch(p, input_dir, output_dir, args):
         img = ImageOps.exif_transpose(img)
         p.init_images = [img] * p.batch_size

+        if is_inpaint_batch:
+            # try to find corresponding mask for an image using simple filename matching
+            mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
+            # if not found use first one ("same mask for all images" use-case)
+            if not mask_image_path in inpaint_masks:
+                mask_image_path = inpaint_masks[0]
+            mask_image = Image.open(mask_image_path)
+            p.image_mask = mask_image
+
         proc = modules.scripts.scripts_img2img.run(p, *args)
         if proc is None:
             proc = process_images(p)
@@ -59,38 +73,34 @@ def process_batch(p, input_dir, output_dir, args):
             processed_image.save(os.path.join(output_dir, filename))
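The mask lookup above reduces to a simple rule: prefer a mask whose filename matches the image, otherwise fall back to the first mask (the "same mask for all images" case). A standalone sketch (function name and paths are illustrative; printed results assume POSIX path separators):

import os

def pick_mask(image_path: str, mask_dir: str, masks: list) -> str:
    # Prefer the mask that shares the image's filename; else use the first one.
    candidate = os.path.join(mask_dir, os.path.basename(image_path))
    return candidate if candidate in masks else masks[0]

masks = [os.path.join("masks", "shared.png"), os.path.join("masks", "a.png")]
print(pick_mask(os.path.join("in", "a.png"), "masks", masks))  # masks/a.png (exact match)
print(pick_mask(os.path.join("in", "b.png"), "masks", masks))  # masks/shared.png (fallback)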
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
-    is_inpaint = mode == 1
-    is_batch = mode == 2
-
-    if is_inpaint:
-        # Drawn mask
-        if mask_mode == 0:
-            is_mask_sketch = isinstance(init_img_with_mask, dict)
-            is_mask_paint = not is_mask_sketch
-            if is_mask_sketch:
-                # Sketch: mask iff. not transparent
-                image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
-                alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
-                mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
-            else:
-                # Color-sketch: mask iff. painted over
-                image = init_img_with_mask
-                orig = init_img_with_mask_orig or init_img_with_mask
-                pred = np.any(np.array(image) != np.array(orig), axis=-1)
-                mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
-                mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
-                blur = ImageFilter.GaussianBlur(mask_blur)
-                image = Image.composite(image.filter(blur), orig, mask.filter(blur))
-                image = image.convert("RGB")
-        # Uploaded mask
-        else:
-            image = init_img_inpaint
-            mask = init_mask_inpaint
-    # No mask
-    else:
-        image = init_img
-        mask = None
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args):
+    is_batch = mode == 5
+
+    if mode == 0:  # img2img
+        image = init_img.convert("RGB")
+        mask = None
+    elif mode == 1:  # img2img sketch
+        image = sketch.convert("RGB")
+        mask = None
+    elif mode == 2:  # inpaint
+        image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
+        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
+        mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+        image = image.convert("RGB")
+    elif mode == 3:  # inpaint sketch
+        image = inpaint_color_sketch
+        orig = inpaint_color_sketch_orig or inpaint_color_sketch
+        pred = np.any(np.array(image) != np.array(orig), axis=-1)
+        mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
+        mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
+        blur = ImageFilter.GaussianBlur(mask_blur)
+        image = Image.composite(image.filter(blur), orig, mask.filter(blur))
+        image = image.convert("RGB")
+    elif mode == 4:  # inpaint upload mask
+        image = init_img_inpaint
+        mask = init_mask_inpaint
+    else:
+        image = None
+        mask = None

     # Use the EXIF orientation of photos taken by smartphones.
@@ -105,7 +115,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
         outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
         prompt=prompt,
         negative_prompt=negative_prompt,
-        styles=[prompt_style, prompt_style2],
+        styles=prompt_styles,
         seed=seed,
         subseed=subseed,
         subseed_strength=subseed_strength,
@@ -143,7 +153,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
     if is_batch:
         assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"

-        process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args)
+        process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)

         processed = Processed(p, [], p.seed, "")
     else:
@@ -162,4 +172,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
     if opts.do_not_show_images:
         processed.images = []

-    return processed.images, generation_info_js, plaintext_to_html(processed.info)
+    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
@@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
     def read(self):
         if not self.disabled:
             free, total = torch.cuda.mem_get_info()
+            self.data["free"] = free
             self.data["total"] = total

             torch_stats = torch.cuda.memory_stats(self.device)
+            self.data["active"] = torch_stats["active.all.current"]
             self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
             self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
             self.data["system_peak"] = total - self.data["min_free"]
...
@@ -4,7 +4,15 @@ import sys
 import modules.safe

 script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-models_path = os.path.join(script_path, "models")
+
+# Parse the --data-dir flag first so we can use it as a base for our other argument default values
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
+cmd_opts_pre = parser.parse_known_args()[0]
+data_path = cmd_opts_pre.data_dir
+models_path = os.path.join(data_path, "models")
+# data_path = cmd_opts_pre.data
+
 sys.path.insert(0, script_path)

 # search for directory of stable diffusion in following places
@@ -38,3 +46,17 @@ for d, must_exist, what, options in path_dirs:
     else:
         sys.path.append(d)
         paths[what] = d
+
+class Prioritize:
+    def __init__(self, name):
+        self.name = name
+        self.path = None
+
+    def __enter__(self):
+        self.path = sys.path.copy()
+        sys.path = [paths[self.name]] + sys.path
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.path = self.path
+        self.path = None
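A possible use of the new context manager (the "BLIP" key and the import are illustrative; the key must match one of the names registered in `paths` above): the named repository's directory is pushed to the front of sys.path for the duration of the block, so its modules win any import-name collisions, and the original search order is restored afterwards.

from modules import paths

with paths.Prioritize("BLIP"):      # hypothetical key; must exist in `paths`
    import models.blip              # hypothetical module that collides elsewhere

# sys.path is back to its previous value here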