Commit ea2cc26f authored by nanahira

Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui

parents c3fa4486 0cc0ee1b
@@ -37,20 +37,20 @@ body:
 id: what-should
 attributes:
 label: What should have happened?
-description: tell what you think the normal behavior should be
+description: Tell what you think the normal behavior should be
 validations:
 required: true
 - type: input
 id: commit
 attributes:
 label: Commit where the problem happens
-description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
 validations:
 required: true
 - type: dropdown
 id: platforms
 attributes:
-label: What platforms do you use to access UI ?
+label: What platforms do you use to access the UI ?
 multiple: true
 options:
 - Windows
@@ -74,10 +74,27 @@ body:
 id: cmdargs
 attributes:
 label: Command Line Arguments
-description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
+description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
 render: Shell
validations:
required: true
- type: textarea
id: extensions
attributes:
label: List of extensions
description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise.
validations:
required: true
- type: textarea
id: logs
attributes:
label: Console logs
description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
render: Shell
validations:
required: true
 - type: textarea
 id: misc
 attributes:
-label: Additional information, context and logs
+label: Additional information
-description: Please provide us with any relevant additional info, context or log output.
+description: Please provide us with any relevant additional info or context.
 name: Feature request
 description: Suggest an idea for this project
 title: "[Feature Request]: "
-labels: ["suggestion"]
+labels: ["enhancement"]
 body:
 - type: checkboxes
 ...
@@ -18,8 +18,8 @@ More technical discussion about your changes go here, plus anything that a maint
 List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
 - OS: [e.g. Windows, Linux]
-- Browser [e.g. chrome, safari]
+- Browser: [e.g. chrome, safari]
-- Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
+- Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
 **Screenshots or videos of your changes**
 ...
@@ -19,22 +19,19 @@ jobs:
 - name: Checkout Code
 uses: actions/checkout@v3
 - name: Set up Python 3.10
-uses: actions/setup-python@v3
+uses: actions/setup-python@v4
 with:
 python-version: 3.10.6
-- uses: actions/cache@v2
-with:
-path: ~/.cache/pip
-key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
-restore-keys: |
-${{ runner.os }}-pip-
+cache: pip
+cache-dependency-path: |
+**/requirements*txt
 - name: Install PyLint
 run: |
 python -m pip install --upgrade pip
 pip install pylint
 # This lets PyLint check to see if it can resolve imports
 - name: Install dependencies
-run : |
+run: |
 export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
 python launch.py
 - name: Analysing the code with pylint
 ...
name: Run basic features tests on CPU with empty SD model
on:
- push
- pull_request
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: 3.10.6
cache: pip
cache-dependency-path: |
**/requirements*txt
- name: Run tests
run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
- name: Upload main app stdout-stderr
uses: actions/upload-artifact@v3
if: always()
with:
name: stdout-stderr
path: |
test/stdout.txt
test/stderr.txt
 __pycache__
 *.ckpt
+*.safetensors
 *.pth
 /ESRGAN/*
 /SwinIR/*
@@ -31,3 +32,4 @@ notification.mp3
 /extensions
 /test/stdout.txt
 /test/stderr.txt
+/cache.json
 # Stable Diffusion web UI
 A browser interface based on Gradio library for Stable Diffusion.
-![](txt2img_Screenshot.png)
+![](screenshot.png)
-Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.
 ## Features
 [Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
@@ -19,7 +17,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - a man in a (tuxedo:1.21) - alternative syntax
 - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
 - Loopback, run img2img processing multiple times
-- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
+- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
 - Textual Inversion
 - have as many embeddings as you want and use any names you like for them
 - use multiple embeddings with different numbers of vectors per token
@@ -51,9 +49,9 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - Running arbitrary python code from UI (must run with --allow-code to enable)
 - Mouseover hints for most UI elements
 - Possible to change defaults/min/max/step values for UI elements via text config
-- Random artist button
 - Tiling support, a checkbox to create images that can be tiled like textures
 - Progress bar and live image generation preview
+- Can use a separate neural network to produce previews with almost no VRAM or compute requirement
 - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
 - Styles, a way to save part of prompt and easily apply them via dropdown later
 - Variations, a way to generate same image but with tiny differences
@@ -78,32 +76,22 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - hypernetworks and embeddings options
 - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
 - Clip skip
-- Use Hypernetworks
+- Hypernetworks
-- Use VAEs
+- Loras (same as Hypernetworks but more pretty)
+- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt.
+- Can select to load a different VAE from settings screen
 - Estimated completion time in progress bar
 - API
 - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
-- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
+- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
-## Where are Aesthetic Gradients?!?!
-Aesthetic Gradients are now an extension. You can install it using git:
-```commandline
-git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients
-```
-
-After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart
-the UI. The interface for Aesthetic Gradients should appear exactly the same as it was.
-## Where is History/Image browser?!?!
-Image browser is now an extension. You can install it using git:
-```commandline
-git clone https://github.com/yfszzx/stable-diffusion-webui-images-browser extensions/images-browser
-```
-After running this command, make sure that you have `images-browser` dir in webui's `extensions` directory and restart
-the UI. The interface for Image browser should appear exactly the same as it was.
+- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
+- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
+- Now without any bad letters!
+- Load checkpoints in safetensors format
+- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
+- Now with a license!
+- Reorder elements in the UI from settings screen
 ## Installation and Running
 Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
@@ -116,9 +104,7 @@ Alternatively, use online services (like Google Colab):
 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
 2. Install [git](https://git-scm.com/download/win).
 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
-4. Place `model.ckpt` in the `models` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
-5. _*(Optional)*_ Place `GFPGANv1.4.pth` in the base directory, alongside `webui.py` (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
-6. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
+4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
 ### Automatic Installation on Linux
 1. Install the dependencies:
@@ -134,7 +120,7 @@ sudo pacman -S wget git python3
 ```bash
 bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
 ```
+3. Run `webui.sh`.
 ### Installation on Apple Silicon
 Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
@@ -146,6 +132,8 @@ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC
 The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
 ## Credits
+Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
 - Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
 - GFPGAN - https://github.com/TencentARC/GFPGAN.git
@@ -154,9 +142,11 @@ The documentation was moved from this README over to the project's [wiki](https:
 - SwinIR - https://github.com/JingyunLiang/SwinIR
 - Swin2SR - https://github.com/mv-lab/swin2sr
 - LDSR - https://github.com/Hafiidz/latent-diffusion
+- MiDaS - https://github.com/isl-org/MiDaS
 - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
 - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
 - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
+- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)
 - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
 - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
 - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
@@ -164,6 +154,8 @@ The documentation was moved from this README over to the project's [wiki](https:
 - Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
 - xformers - https://github.com/facebookresearch/xformers
 - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
+- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
+- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
 - Security advice - RyotaK
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
 - (You)
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: modules.xlmr.BertSeriesModelWithTransformation
params:
name: "XLMR-Large"
\ No newline at end of file
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
# See more details in LICENSE.
model:
base_learning_rate: 1.0e-04
target: modules.models.diffusion.ddpm_edit.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: edited
cond_stage_key: edit
# image_size: 64
# image_size: 32
image_size: 16
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: false
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 0 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 128
num_workers: 1
wrap: false
validation:
target: edit_dataset.EditDataset
params:
path: data/clip-filtered-dataset
cache_dir: data/
cache_name: data_10k
split: val
min_text_sim: 0.2
min_image_sim: 0.75
min_direction_sim: 0.2
max_samples_per_prompt: 1
min_resize_res: 512
max_resize_res: 512
crop_res: 512
output_as_edit: False
real_input: True
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
model:
base_learning_rate: 7.5e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid # important
monitor: val/loss_simple_ema
scale_factor: 0.18215
finetune_keys: null
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
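These LDM YAML configs are not used on their own; they are turned into model objects with OmegaConf and `instantiate_from_config`, the same pattern the LDSR loader further down uses. A minimal sketch follows; the file path is a placeholder for illustration, not necessarily the repository's actual location:

```python
# Minimal sketch: build a model from one of these configs the way the LDSR
# loader below does. Requires the ldm package and the YAML file to be present.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("v1-inpainting-inference.yaml")  # hypothetical local path
model = instantiate_from_config(config.model)            # -> LatentInpaintDiffusion per the "target" key
model.eval()
```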
import os
 import gc
 import time
-import warnings
 import numpy as np
 import torch
@@ -8,27 +8,47 @@ import torchvision
 from PIL import Image
 from einops import rearrange, repeat
 from omegaconf import OmegaConf
+import safetensors.torch
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import instantiate_from_config, ismap
+from modules import shared, sd_hijack
-warnings.filterwarnings("ignore", category=UserWarning)
+cached_ldsr_model: torch.nn.Module = None
 # Create LDSR Class
 class LDSR:
 def load_model_from_config(self, half_attention):
-print(f"Loading model from {self.modelPath}")
-pl_sd = torch.load(self.modelPath, map_location="cpu")
-sd = pl_sd["state_dict"]
-config = OmegaConf.load(self.yamlPath)
-model = instantiate_from_config(config.model)
-model.load_state_dict(sd, strict=False)
-model.cuda()
-if half_attention:
-model = model.half()
-model.eval()
+global cached_ldsr_model
+if shared.opts.ldsr_cached and cached_ldsr_model is not None:
+print("Loading model from cache")
+model: torch.nn.Module = cached_ldsr_model
+else:
+print(f"Loading model from {self.modelPath}")
+_, extension = os.path.splitext(self.modelPath)
+if extension.lower() == ".safetensors":
+pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
+else:
+pl_sd = torch.load(self.modelPath, map_location="cpu")
+sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
+config = OmegaConf.load(self.yamlPath)
+config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
+model: torch.nn.Module = instantiate_from_config(config.model)
+model.load_state_dict(sd, strict=False)
+model = model.to(shared.device)
+if half_attention:
+model = model.half()
+if shared.cmd_opts.opt_channelslast:
+model = model.to(memory_format=torch.channels_last)
+sd_hijack.model_hijack.hijack(model) # apply optimization
+model.eval()
+if shared.opts.ldsr_cached:
+cached_ldsr_model = model
 return {"model": model}
 def __init__(self, model_path, yaml_path):
@@ -93,7 +113,8 @@ class LDSR:
 down_sample_method = 'Lanczos'
 gc.collect()
-torch.cuda.empty_cache()
+if torch.cuda.is_available():
+torch.cuda.empty_cache()
 im_og = image
 width_og, height_og = im_og.size
@@ -130,7 +151,9 @@ class LDSR:
 del model
 gc.collect()
-torch.cuda.empty_cache()
+if torch.cuda.is_available():
+torch.cuda.empty_cache()
 return a
@@ -145,7 +168,7 @@ def get_cond(selected_path):
 c = rearrange(c, '1 c h w -> 1 h w c')
 c = 2. * c - 1.
-c = c.to(torch.device("cuda"))
+c = c.to(shared.device)
 example["LR_image"] = c
 example["image"] = c_up
 ...
import os
from modules import paths
def preload(parser):
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
@@ -5,8 +5,9 @@ import traceback
 from basicsr.utils.download_util import load_file_from_url
 from modules.upscaler import Upscaler, UpscalerData
-from modules.ldsr_model_arch import LDSR
+from ldsr_model_arch import LDSR
-from modules import shared
+from modules import shared, script_callbacks
+import sd_hijack_autoencoder, sd_hijack_ddpm_v1
 class UpscalerLDSR(Upscaler):
@@ -24,6 +25,7 @@ class UpscalerLDSR(Upscaler):
 yaml_path = os.path.join(self.model_path, "project.yaml")
 old_model_path = os.path.join(self.model_path, "model.pth")
 new_model_path = os.path.join(self.model_path, "model.ckpt")
+safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
 if os.path.exists(yaml_path):
 statinfo = os.stat(yaml_path)
 if statinfo.st_size >= 10485760:
@@ -32,8 +34,11 @@ class UpscalerLDSR(Upscaler):
 if os.path.exists(old_model_path):
 print("Renaming model from model.pth to model.ckpt")
 os.rename(old_model_path, new_model_path)
-model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
-file_name="model.ckpt", progress=True)
+if os.path.exists(safetensors_model_path):
+model = safetensors_model_path
+else:
+model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+file_name="model.ckpt", progress=True)
 yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
 file_name="project.yaml", progress=True)
@@ -52,3 +57,13 @@ class UpscalerLDSR(Upscaler):
 return img
 ddim_steps = shared.opts.ldsr_steps
 return ldsr.super_resolution(img, ddim_steps, self.scale)
def on_ui_settings():
import gradio as gr
shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
script_callbacks.on_ui_settings(on_ui_settings)
from modules import extra_networks, shared
import lora
class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')
def activate(self, p, params_list):
additional = shared.opts.sd_lora
if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
multipliers = []
for params in params_list:
assert len(params.items) > 0
names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
lora.load_loras(names, multipliers)
def deactivate(self, p):
pass
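For illustration, a small sketch of how `activate()` turns parsed `<lora:name:multiplier>` tags into the `names`/`multipliers` passed to `lora.load_loras()`. `FakeParams` and the tag values are hypothetical stand-ins, not part of the codebase:

```python
# Hypothetical params as they would arrive from a prompt such as
# "<lora:myLora:0.7> <lora:otherLora>".
class FakeParams:  # stand-in for extra_networks.ExtraNetworkParams
    def __init__(self, items):
        self.items = items

params_list = [FakeParams(["myLora", "0.7"]), FakeParams(["otherLora"])]

# Same extraction as the loop in activate() above: first item is the name,
# second (if present) is the multiplier, defaulting to 1.0.
names = [p.items[0] for p in params_list]
multipliers = [float(p.items[1]) if len(p.items) > 1 else 1.0 for p in params_list]
print(names, multipliers)  # ['myLora', 'otherLora'] [0.7, 1.0]
```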
import glob
import os
import re
import torch
from modules import shared, devices, sd_models
re_digits = re.compile(r"\d+")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
def convert_diffusers_name_to_compvis(key):
def match(match_list, regex):
r = re.match(regex, key)
if not r:
return False
match_list.clear()
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
return True
m = []
if match(m, re_unet_down_blocks):
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
if match(m, re_unet_mid_blocks):
return f"diffusion_model_middle_block_1_{m[1]}"
if match(m, re_unet_up_blocks):
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
if match(m, re_text_block):
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
return key
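As a worked example of the mapping above (the input key is made up but follows the diffusers-style naming the regexes expect):

```python
# "lora_unet_down_blocks_0_attentions_1_proj_in" matches re_unet_down_blocks
# with m = [0, 1, "proj_in"], so the block index becomes 1 + 0*3 + 1 = 2:
print(convert_diffusers_name_to_compvis("lora_unet_down_blocks_0_attentions_1_proj_in"))
# -> diffusion_model_input_blocks_2_1_proj_in
```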
class LoraOnDisk:
def __init__(self, name, filename):
self.name = name
self.filename = filename
class LoraModule:
def __init__(self, name):
self.name = name
self.multiplier = 1.0
self.modules = {}
self.mtime = None
class LoraUpDownModule:
def __init__(self):
self.up = None
self.down = None
self.alpha = None
def assign_lora_names_to_compvis_modules(sd_model):
lora_layer_mapping = {}
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
lora_name = name.replace(".", "_")
lora_layer_mapping[lora_name] = module
module.lora_layer_name = lora_name
for name, module in shared.sd_model.model.named_modules():
lora_name = name.replace(".", "_")
lora_layer_mapping[lora_name] = module
module.lora_layer_name = lora_name
sd_model.lora_layer_mapping = lora_layer_mapping
def load_lora(name, filename):
lora = LoraModule(name)
lora.mtime = os.path.getmtime(filename)
sd = sd_models.read_state_dict(filename)
keys_failed_to_match = []
for key_diffusers, weight in sd.items():
fullkey = convert_diffusers_name_to_compvis(key_diffusers)
key, lora_key = fullkey.split(".", 1)
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
if sd_module is None:
keys_failed_to_match.append(key_diffusers)
continue
lora_module = lora.modules.get(key, None)
if lora_module is None:
lora_module = LoraUpDownModule()
lora.modules[key] = lora_module
if lora_key == "alpha":
lora_module.alpha = weight.item()
continue
if type(sd_module) == torch.nn.Linear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
else:
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
with torch.no_grad():
module.weight.copy_(weight)
module.to(device=devices.device, dtype=devices.dtype)
if lora_key == "lora_up.weight":
lora_module.up = module
elif lora_key == "lora_down.weight":
lora_module.down = module
else:
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
return lora
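For instance, with a hypothetical state-dict key, the `fullkey.split(".", 1)` above separates the module path (looked up in `lora_layer_mapping`) from the Lora-specific suffix:

```python
# Hypothetical key from a Lora state dict, after conversion to compvis naming.
fullkey = "diffusion_model_input_blocks_2_1_proj_in.lora_up.weight"
key, lora_key = fullkey.split(".", 1)
print(key)       # diffusion_model_input_blocks_2_1_proj_in  (module to patch)
print(lora_key)  # lora_up.weight                            (selects lora_module.up)
```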
def load_loras(names, multipliers=None):
already_loaded = {}
for lora in loaded_loras:
if lora.name in names:
already_loaded[lora.name] = lora
loaded_loras.clear()
loras_on_disk = [available_loras.get(name, None) for name in names]
if any([x is None for x in loras_on_disk]):
list_available_loras()
loras_on_disk = [available_loras.get(name, None) for name in names]
for i, name in enumerate(names):
lora = already_loaded.get(name, None)
lora_on_disk = loras_on_disk[i]
if lora_on_disk is not None:
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
lora = load_lora(name, lora_on_disk.filename)
if lora is None:
print(f"Couldn't find Lora with name {name}")
continue
lora.multiplier = multipliers[i] if multipliers else 1.0
loaded_loras.append(lora)
def lora_forward(module, input, res):
if len(loaded_loras) == 0:
return res
lora_layer_name = getattr(module, 'lora_layer_name', None)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None:
if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
else:
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return res
def lora_Linear_forward(self, input):
return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
def lora_Conv2d_forward(self, input):
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
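A self-contained sketch of the low-rank update that `lora_forward` applies; the layer sizes, alpha and multiplier are arbitrary illustration values and none of the webui classes are used:

```python
import torch

# Arbitrary sizes for illustration: a 16->16 Linear patched by a rank-4 Lora.
base = torch.nn.Linear(16, 16, bias=False)
down = torch.nn.Linear(16, 4, bias=False)  # plays the role of lora_down.weight
up = torch.nn.Linear(4, 16, bias=False)    # plays the role of lora_up.weight
alpha, multiplier = 4.0, 0.8

x = torch.randn(2, 16)
res = base(x)
# Same formula as in lora_forward(): res + up(down(x)) * multiplier * (alpha / rank),
# where rank == up.weight.shape[1].
res = res + up(down(x)) * multiplier * (alpha / up.weight.shape[1])
print(res.shape)  # torch.Size([2, 16])
```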
def list_available_loras():
available_loras.clear()
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
candidates = \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
for filename in sorted(candidates):
if os.path.isdir(filename):
continue
name = os.path.splitext(os.path.basename(filename))[0]
available_loras[name] = LoraOnDisk(name, filename)
available_loras = {}
loaded_loras = []
list_available_loras()
import os
from modules import paths
def preload(parser):
parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
import torch
import gradio as gr
import lora
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
def before_ui():
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
"lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
}))
import json
import os
import lora
from modules import shared, ui_extra_networks
class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def __init__(self):
super().__init__('Lora')
def refresh(self):
lora.list_available_loras()
def list_items(self):
for name, lora_on_disk in lora.available_loras.items():
path, ext = os.path.splitext(lora_on_disk.filename)
previews = [path + ".png", path + ".preview.png"]
preview = None
for file in previews:
if os.path.isfile(file):
preview = self.link_preview(file)
break
yield {
"name": name,
"filename": path,
"preview": preview,
"search_term": self.search_terms_from_path(lora_on_disk.filename),
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": path + ".png",
}
def allowed_directories_for_previews(self):
return [shared.cmd_opts.lora_dir]
import os
from modules import paths
def preload(parser):
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
@@ -9,7 +9,7 @@ from basicsr.utils.download_util import load_file_from_url
 import modules.upscaler
 from modules import devices, modelloader
-from modules.scunet_model_arch import SCUNet as net
+from scunet_model_arch import SCUNet as net
 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -49,12 +49,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
 if model is None:
 return img
-device = devices.device_scunet
+device = devices.get_device_for('scunet')
 img = np.array(img)
 img = img[:, :, ::-1]
 img = np.moveaxis(img, 2, 0) / 255
 img = torch.from_numpy(img).float()
-img = devices.mps_contiguous_to(img.unsqueeze(0), device)
+img = img.unsqueeze(0).to(device)
 with torch.no_grad():
 output = model(img)
@@ -66,7 +66,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
 return PIL.Image.fromarray(output, 'RGB')
 def load_model(self, path: str):
-device = devices.device_scunet
+device = devices.get_device_for('scunet')
 if "http" in path:
 filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
 progress=True)
 ...
import os
from modules import paths
def preload(parser):
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
@@ -7,15 +7,14 @@ from PIL import Image
 from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm
-from modules import modelloader, devices
+from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts
+from modules.shared import cmd_opts, opts, state
-from modules.swinir_model_arch import SwinIR as net
+from swinir_model_arch import SwinIR as net
-from modules.swinir_model_arch_v2 import Swin2SR as net2
+from swinir_model_arch_v2 import Swin2SR as net2
 from modules.upscaler import Upscaler, UpscalerData
-precision_scope = (
-torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
-)
+device_swinir = devices.get_device_for('swinir')
 class UpscalerSwinIR(Upscaler):
@@ -42,7 +41,7 @@ class UpscalerSwinIR(Upscaler):
 model = self.load_model(model_file)
 if model is None:
 return img
-model = model.to(devices.device_swinir)
+model = model.to(device_swinir, dtype=devices.dtype)
 img = upscale(img, model)
 try:
 torch.cuda.empty_cache()
@@ -94,25 +93,27 @@ class UpscalerSwinIR(Upscaler):
 model.load_state_dict(pretrained_model[params], strict=True)
 else:
 model.load_state_dict(pretrained_model, strict=True)
-if not cmd_opts.no_half:
-model = model.half()
 return model
 def upscale(
 img,
 model,
-tile=opts.SWIN_tile,
+tile=None,
-tile_overlap=opts.SWIN_tile_overlap,
+tile_overlap=None,
 window_size=8,
 scale=4,
 ):
+tile = tile or opts.SWIN_tile
+tile_overlap = tile_overlap or opts.SWIN_tile_overlap
 img = np.array(img)
 img = img[:, :, ::-1]
 img = np.moveaxis(img, 2, 0) / 255
 img = torch.from_numpy(img).float()
-img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir)
+img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
-with torch.no_grad(), precision_scope("cuda"):
+with torch.no_grad(), devices.autocast():
 _, _, h_old, w_old = img.size()
 h_pad = (h_old // window_size + 1) * window_size - h_old
 w_pad = (w_old // window_size + 1) * window_size - w_old
@@ -139,12 +140,18 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
 stride = tile - tile_overlap
 h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
 w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img)
+E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
-W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir)
+W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
 with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
 for h_idx in h_idx_list:
+if state.interrupted or state.skipped:
+break
 for w_idx in w_idx_list:
+if state.interrupted or state.skipped:
+break
 in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
 out_patch = model(in_patch)
 out_patch_mask = torch.ones_like(out_patch)
@@ -159,3 +166,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
 output = E.div_(W)
 return output
def on_ui_settings():
import gradio as gr
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
script_callbacks.on_ui_settings(on_ui_settings)
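To see how the tiling in `inference()` covers an image, here is the index computation with the default settings registered above (tile 192, overlap 8) on a hypothetical 512-pixel dimension:

```python
# Worked example of the tile grid produced by inference(): h=512, tile=192, overlap=8.
h, tile, tile_overlap = 512, 192, 8
stride = tile - tile_overlap                                # 184
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]  # [0, 184, 320]
print(h_idx_list)
# Tiles cover [0:192], [184:376], [320:512]; the overlapping predictions are
# accumulated in E and renormalized by the weight map W.
```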
// Stable Diffusion WebUI - Bracket checker
// Version 1.0
// By Hingashi no Florin/Bwin4L
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
function checkBrackets(evt, textArea, counterElt) {
errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
openBracketRegExp = /\(/g;
closeBracketRegExp = /\)/g;
openSquareBracketRegExp = /\[/g;
closeSquareBracketRegExp = /\]/g;
openCurlyBracketRegExp = /\{/g;
closeCurlyBracketRegExp = /\}/g;
totalOpenBracketMatches = 0;
totalCloseBracketMatches = 0;
totalOpenSquareBracketMatches = 0;
totalCloseSquareBracketMatches = 0;
totalOpenCurlyBracketMatches = 0;
totalCloseCurlyBracketMatches = 0;
openBracketMatches = textArea.value.match(openBracketRegExp);
if(openBracketMatches) {
totalOpenBracketMatches = openBracketMatches.length;
}
closeBracketMatches = textArea.value.match(closeBracketRegExp);
if(closeBracketMatches) {
totalCloseBracketMatches = closeBracketMatches.length;
}
openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
if(openSquareBracketMatches) {
totalOpenSquareBracketMatches = openSquareBracketMatches.length;
}
closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
if(closeSquareBracketMatches) {
totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
}
openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
if(openCurlyBracketMatches) {
totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
}
closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
if(closeCurlyBracketMatches) {
totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
}
if(totalOpenBracketMatches != totalCloseBracketMatches) {
if(!counterElt.title.includes(errorStringParen)) {
counterElt.title += errorStringParen;
}
} else {
counterElt.title = counterElt.title.replace(errorStringParen, '');
}
if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
if(!counterElt.title.includes(errorStringSquare)) {
counterElt.title += errorStringSquare;
}
} else {
counterElt.title = counterElt.title.replace(errorStringSquare, '');
}
if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
if(!counterElt.title.includes(errorStringCurly)) {
counterElt.title += errorStringCurly;
}
} else {
counterElt.title = counterElt.title.replace(errorStringCurly, '');
}
if(counterElt.title != '') {
counterElt.classList.add('error');
} else {
counterElt.classList.remove('error');
}
}
function setupBracketChecking(id_prompt, id_counter){
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
var counter = gradioApp().getElementById(id_counter)
textarea.addEventListener("input", function(evt){
checkBrackets(evt, textarea, counter)
});
}
var shadowRootLoaded = setInterval(function() {
var shadowRoot = document.querySelector('gradio-app').shadowRoot;
if(! shadowRoot) return false;
var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
if(shadowTextArea.length < 1) return false;
clearInterval(shadowRootLoaded);
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
setupBracketChecking('img2img_prompt', 'img2img_token_counter')
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
}, 1000);
<div class='card' {preview_html} onclick={card_clicked}>
<div class='actions'>
<div class='additional'>
<ul>
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
</ul>
<span style="display:none" class='search_term'>{search_term}</span>
</div>
<span class='name'>{name}</span>
</div>
</div>
<div class='nocards'>
<h1>Nothing here. Add some content to the following directories:</h1>
<ul>
{dirs}
</ul>
</div>
<div>
<a href="/docs">API</a>
 • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
 • 
<a href="https://gradio.app">Gradio</a>
 • 
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
</div>
<br />
<div class="versions">
{versions}
</div>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
<filter id='shadow' color-interpolation-filters="sRGB">
<feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
<feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
</filter>
<path style="filter:url(#shadow);" fill="#FFFFFF" d="M13.18 19C13.35 19.72 13.64 20.39 14.03 21H5C3.9 21 3 20.11 3 19V5C3 3.9 3.9 3 5 3H19C20.11 3 21 3.9 21 5V11.18C20.5 11.07 20 11 19.5 11C19.33 11 19.17 11 19 11.03V5H5V19H13.18M11.21 15.83L9.25 13.47L6.5 17H13.03C13.14 15.54 13.73 14.22 14.64 13.19L13.96 12.29L11.21 15.83M19 13.5V12L16.75 14.25L19 16.5V15C20.38 15 21.5 16.12 21.5 17.5C21.5 17.9 21.41 18.28 21.24 18.62L22.33 19.71C22.75 19.08 23 18.32 23 17.5C23 15.29 21.21 13.5 19 13.5M19 20C17.62 20 16.5 18.88 16.5 17.5C16.5 17.1 16.59 16.72 16.76 16.38L15.67 15.29C15.25 15.92 15 16.68 15 17.5C15 19.71 16.79 21.5 19 21.5V23L21.25 20.75L19 18.5V20Z" />
</svg>
@@ -21,11 +21,16 @@ function dimensionChange(e, is_width, is_height){
 var targetElement = null;
 var tabIndex = get_tab_index('mode_img2img')
-if(tabIndex == 0){
+if(tabIndex == 0){ // img2img
 targetElement = gradioApp().querySelector('div[data-testid=image] img');
-} else if(tabIndex == 1){
+} else if(tabIndex == 1){ //Sketch
+targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
+} else if(tabIndex == 2){ // Inpaint
 targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
+} else if(tabIndex == 3){ // Inpaint sketch
+targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
 }
 if(targetElement){
 ...
@@ -9,7 +9,7 @@ contextMenuInit = function(){
 function showContextMenu(event,element,menuEntries){
 let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
 let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
 let oldMenu = gradioApp().querySelector('#context-menu')
 if(oldMenu){
@@ -61,15 +61,15 @@ contextMenuInit = function(){
 }
-function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){
+function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
-currentItems = menuSpecs.get(targetEmementSelector)
+currentItems = menuSpecs.get(targetElementSelector)
 if(!currentItems){
 currentItems = []
-menuSpecs.set(targetEmementSelector,currentItems);
+menuSpecs.set(targetElementSelector,currentItems);
 }
-let newItem = {'id':targetEmementSelector+'_'+uid(),
+let newItem = {'id':targetElementSelector+'_'+uid(),
 'name':entryName,
 'func':entryFunction,
 'isNew':true}
@@ -97,7 +97,7 @@ contextMenuInit = function(){
 if(source.id && source.id.indexOf('check_progress')>-1){
 return
 }
 let oldMenu = gradioApp().querySelector('#context-menu')
 if(oldMenu){
 oldMenu.remove()
@@ -117,7 +117,7 @@ contextMenuInit = function(){
 })
 });
 eventListenerApplied=true
 }
 return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
@@ -152,8 +152,8 @@ addContextMenuEventListener = initResponse[2];
 generateOnRepeat('#img2img_generate','#img2img_interrupt');
 })
 let cancelGenerateForever = function(){
 clearInterval(window.generateOnRepeatInterval)
 }
 appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
@@ -162,7 +162,7 @@ addContextMenuEventListener = initResponse[2];
 appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
 appendContextMenuOption('#roll','Roll three',
 function(){
 let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
 setTimeout(function(){rollbutton.click()},100)
 setTimeout(function(){rollbutton.click()},200)
 ...
@@ -9,11 +9,19 @@ function dropReplaceImage( imgWrap, files ) {
        return;
    }

+   const tmpFile = files[0];
+
    imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
    const callback = () => {
        const fileInput = imgWrap.querySelector('input[type="file"]');
        if ( fileInput ) {
-           fileInput.files = files;
+           if ( files.length === 0 ) {
+               files = new DataTransfer();
+               files.items.add(tmpFile);
+               fileInput.files = files.files;
+           } else {
+               fileInput.files = files;
+           }
            fileInput.dispatchEvent(new Event('change'));
        }
    };
...
-addEventListener('keydown', (event) => {
-    let target = event.originalTarget || event.composedPath()[0];
-    if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
-    if (! (event.metaKey || event.ctrlKey)) return;
-
-    let plus = "ArrowUp"
-    let minus = "ArrowDown"
-    if (event.key != plus && event.key != minus) return;
-
-    let selectionStart = target.selectionStart;
-    let selectionEnd = target.selectionEnd;
-    // If the user hasn't selected anything, let's select their current parenthesis block
-    if (selectionStart === selectionEnd) {
-        // Find opening parenthesis around current cursor
-        const before = target.value.substring(0, selectionStart);
-        let beforeParen = before.lastIndexOf("(");
-        if (beforeParen == -1) return;
-        let beforeParenClose = before.lastIndexOf(")");
-        while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
-            beforeParen = before.lastIndexOf("(", beforeParen - 1);
-            beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
-        }
-
-        // Find closing parenthesis around current cursor
-        const after = target.value.substring(selectionStart);
-        let afterParen = after.indexOf(")");
-        if (afterParen == -1) return;
-        let afterParenOpen = after.indexOf("(");
-        while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
-            afterParen = after.indexOf(")", afterParen + 1);
-            afterParenOpen = after.indexOf("(", afterParenOpen + 1);
-        }
-        if (beforeParen === -1 || afterParen === -1) return;
-
-        // Set the selection to the text between the parenthesis
-        const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
-        const lastColon = parenContent.lastIndexOf(":");
-        selectionStart = beforeParen + 1;
-        selectionEnd = selectionStart + lastColon;
-        target.setSelectionRange(selectionStart, selectionEnd);
-    }
-
-    event.preventDefault();
-
-    if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
-        target.value = target.value.slice(0, selectionStart) +
-            "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
-            target.value.slice(selectionEnd);
-
-        target.focus();
-        target.selectionStart = selectionStart + 1;
-        target.selectionEnd = selectionEnd + 1;
-    } else {
-        end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
-        weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
-        if (isNaN(weight)) return;
-        if (event.key == minus) weight -= 0.1;
-        if (event.key == plus) weight += 0.1;
-
-        weight = parseFloat(weight.toPrecision(12));
-
-        target.value = target.value.slice(0, selectionEnd + 1) +
-            weight +
-            target.value.slice(selectionEnd + 1 + end - 1);
-
-        target.focus();
-        target.selectionStart = selectionStart;
-        target.selectionEnd = selectionEnd;
-    }
-    // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
-    // internal Svelte data binding remains in sync.
-    target.dispatchEvent(new Event("input", { bubbles: true }));
-});
+function keyupEditAttention(event){
+    let target = event.originalTarget || event.composedPath()[0];
+    if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
+    if (! (event.metaKey || event.ctrlKey)) return;
+
+    let isPlus = event.key == "ArrowUp"
+    let isMinus = event.key == "ArrowDown"
+    if (!isPlus && !isMinus) return;
+
+    let selectionStart = target.selectionStart;
+    let selectionEnd = target.selectionEnd;
+    let text = target.value;
+
+    function selectCurrentParenthesisBlock(OPEN, CLOSE){
+        if (selectionStart !== selectionEnd) return false;
+
+        // Find opening parenthesis around current cursor
+        const before = text.substring(0, selectionStart);
+        let beforeParen = before.lastIndexOf(OPEN);
+        if (beforeParen == -1) return false;
+        let beforeParenClose = before.lastIndexOf(CLOSE);
+        while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
+            beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
+            beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
+        }
+
+        // Find closing parenthesis around current cursor
+        const after = text.substring(selectionStart);
+        let afterParen = after.indexOf(CLOSE);
+        if (afterParen == -1) return false;
+        let afterParenOpen = after.indexOf(OPEN);
+        while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
+            afterParen = after.indexOf(CLOSE, afterParen + 1);
+            afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
+        }
+        if (beforeParen === -1 || afterParen === -1) return false;
+
+        // Set the selection to the text between the parenthesis
+        const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
+        const lastColon = parenContent.lastIndexOf(":");
+        selectionStart = beforeParen + 1;
+        selectionEnd = selectionStart + lastColon;
+        target.setSelectionRange(selectionStart, selectionEnd);
+        return true;
+    }
+
+    // If the user hasn't selected anything, let's select their current parenthesis block
+    if(! selectCurrentParenthesisBlock('<', '>')){
+        selectCurrentParenthesisBlock('(', ')')
+    }
+
+    event.preventDefault();
+
+    closeCharacter = ')'
+    delta = opts.keyedit_precision_attention
+
+    if (selectionStart > 0 && text[selectionStart - 1] == '<'){
+        closeCharacter = '>'
+        delta = opts.keyedit_precision_extra
+    } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
+
+        // do not include spaces at the end
+        while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
+            selectionEnd -= 1;
+        }
+        if(selectionStart == selectionEnd){
+            return
+        }
+
+        text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
+
+        selectionStart += 1;
+        selectionEnd += 1;
+    }
+
+    end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
+    weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
+    if (isNaN(weight)) return;
+
+    weight += isPlus ? delta : -delta;
+    weight = parseFloat(weight.toPrecision(12));
+    if(String(weight).length == 1) weight += ".0"
+
+    text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+
+    target.focus();
+    target.value = text;
+    target.selectionStart = selectionStart;
+    target.selectionEnd = selectionEnd;
+
+    updateInput(target)
+}
+
+addEventListener('keydown', (event) => {
+    keyupEditAttention(event);
+});
\ No newline at end of file
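For reference, the weight arithmetic that keyupEditAttention applies to a selected "(text:weight)" or "<name:weight>" block can be summarised in a few lines. The following is a standalone Python sketch for illustration only; it is not part of the web UI code, and the step sizes 0.1 and 0.05 are assumed defaults for the keyedit_precision_attention and keyedit_precision_extra options.

# Standalone sketch of the weight-step rule used above (illustration only).
# Assumed defaults: 0.1 for "(...)" attention blocks, 0.05 for "<...>" extra networks.
def step_weight(block: str, up: bool, delta_paren: float = 0.1, delta_angle: float = 0.05) -> str:
    """Take a block like '(red hair:1.0)' or '<lora:style:0.8>' and step its weight."""
    open_ch, close_ch = block[0], block[-1]
    delta = delta_angle if open_ch == '<' else delta_paren
    body, _, weight = block[1:-1].rpartition(':')
    new_weight = round(float(weight) + (delta if up else -delta), 12)
    # keep at least one decimal place, mirroring the `weight += ".0"` fix-up in the JS
    text = f"{new_weight:g}"
    if '.' not in text:
        text += ".0"
    return f"{open_ch}{body}:{text}{close_ch}"

print(step_weight("(red hair:1.0)", up=True))    # (red hair:1.1)
print(step_weight("<lora:style:0.8>", up=False)) # <lora:style:0.75>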
 function extensions_apply(_, _){
-    disable = []
-    update = []
+    var disable = []
+    var update = []
     gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
         if(x.name.startsWith("enable_") && ! x.checked)
             disable.push(x.name.substr(7))
@@ -16,11 +17,24 @@ function extensions_apply(_, _){
 }

 function extensions_check(){
+    var disable = []
+
+    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
+        if(x.name.startsWith("enable_") && ! x.checked)
+            disable.push(x.name.substr(7))
+    })
+
     gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
         x.innerHTML = "Loading..."
     })

-    return []
+    var id = randomId()
+    requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
+    })
+
+    return [id, JSON.stringify(disable)]
 }

 function install_extension_from_index(button, url){
@@ -29,7 +43,7 @@ function install_extension_from_index(button, url){
     textarea = gradioApp().querySelector('#extension_to_install textarea')
     textarea.value = url
-    textarea.dispatchEvent(new Event("input", { bubbles: true }))
+    updateInput(textarea)

     gradioApp().querySelector('#install_extension_button').click()
 }
function setupExtraNetworksForTab(tabname){
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
var close = gradioApp().getElementById(tabname+'_extra_close')
search.classList.add('search')
tabs.appendChild(search)
tabs.appendChild(refresh)
tabs.appendChild(close)
search.addEventListener("input", function(evt){
searchTerm = search.value.toLowerCase()
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
})
});
}
var activePromptTextarea = {};
function setupExtraNetworks(){
setupExtraNetworksForTab('txt2img')
setupExtraNetworksForTab('img2img')
function registerPrompt(tabname, id){
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
if (! activePromptTextarea[tabname]){
activePromptTextarea[tabname] = textarea
}
textarea.addEventListener("focus", function(){
activePromptTextarea[tabname] = textarea;
});
}
registerPrompt('txt2img', 'txt2img_prompt')
registerPrompt('txt2img', 'txt2img_neg_prompt')
registerPrompt('img2img', 'img2img_prompt')
registerPrompt('img2img', 'img2img_neg_prompt')
}
onUiLoaded(setupExtraNetworks)
var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
function tryToRemoveExtraNetworkFromPrompt(textarea, text){
var m = text.match(re_extranet)
if(! m) return false
var partToSearch = m[1]
var replaced = false
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
m = found.match(re_extranet);
if(m[1] == partToSearch){
replaced = true;
return ""
}
return found;
})
if(replaced){
textarea.value = newTextareaText
return true;
}
return false
}
function cardClicked(tabname, textToAdd, allowNegativePrompt){
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
textarea.value = textarea.value + " " + textToAdd
}
updateInput(textarea)
}
function saveCardPreview(event, tabname, filename){
var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
var button = gradioApp().getElementById(tabname + '_save_preview')
textarea.value = filename
updateInput(textarea)
button.click()
event.stopPropagation()
event.preventDefault()
}
function extraNetworksSearchButton(tabs_id, event){
searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
button = event.target
text = button.classList.contains("search-all") ? "" : button.textContent.trim()
searchTextarea.value = text
updateInput(searchTextarea)
}
\ No newline at end of file
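The re_extranet patterns above treat extra-network tokens of the form "<kind:name:weight>": clicking a card either removes an existing token for that network or appends a new one. A minimal Python sketch of the same toggle logic follows, for illustration only; the web UI implements this in JavaScript, and the example prompt and token values are made up.

# Sketch of the toggle behaviour of tryToRemoveExtraNetworkFromPrompt / cardClicked (illustration only).
import re

re_extranet = re.compile(r"<([^:]+:[^:]+):[\d.]+>")
re_extranet_g = re.compile(r"\s+<([^:]+:[^:]+):[\d.]+>")

def toggle_extra_network(prompt, token):
    m = re_extranet.search(token)
    if not m:
        return prompt + " " + token          # token has no recognisable form, just append it
    part = m.group(1)                        # e.g. "lora:style1"
    removed = False

    def repl(found):
        nonlocal removed
        if found.group(1) == part:
            removed = True
            return ""                        # drop the matching token
        return found.group(0)                # keep unrelated tokens

    new_prompt = re_extranet_g.sub(repl, prompt)
    return new_prompt if removed else prompt + " " + token

print(toggle_extra_network("a castle <lora:style1:0.8>", "<lora:style1:1.0>"))  # -> "a castle"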
function setInactive(elem, inactive){
if(inactive){
elem.classList.add('inactive')
} else{
elem.classList.remove('inactive')
}
}
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
}
@@ -148,9 +148,18 @@ function showGalleryImage() {
        if(e && e.parentElement.tagName == 'DIV'){
            e.style.cursor='pointer'
            e.style.userSelect='none'
-           e.addEventListener('click', function (evt) {
-               if(!opts.js_modal_lightbox) return;
+
+           var isFirefox = isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
+
+           // For Firefox, listening on click first switched to next image then shows the lightbox.
+           // If you know how to fix this without switching to mousedown event, please.
+           // For other browsers the event is click to make it possible to drag picture.
+           var event = isFirefox ? 'mousedown' : 'click'
+
+           e.addEventListener(event, function (evt) {
+               if(!opts.js_modal_lightbox || evt.button != 0) return;
                modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
+               evt.preventDefault()
                showModal(evt)
            }, true);
        }
...
@@ -10,10 +10,8 @@ ignore_ids_for_localization={
    modelmerger_tertiary_model_name: 'OPTION',
    train_embedding: 'OPTION',
    train_hypernetwork: 'OPTION',
-   txt2img_style_index: 'OPTION',
-   txt2img_style2_index: 'OPTION',
-   img2img_style_index: 'OPTION',
-   img2img_style2_index: 'OPTION',
+   txt2img_styles: 'OPTION',
+   img2img_styles: 'OPTION',
    setting_random_artist_categories: 'SPAN',
    setting_face_restoration_model: 'SPAN',
    setting_realesrgan_enabled_models: 'SPAN',
...
@@ -15,7 +15,7 @@ onUiUpdate(function(){
        }
    }

-   const galleryPreviews = gradioApp().querySelectorAll('img.h-full.w-full.overflow-hidden');
+   const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');

    if (galleryPreviews == null) return;
...
 function start_training_textual_inversion(){
-    requestProgress('ti')
     gradioApp().querySelector('#ti_error').innerHTML=''

-    return args_to_array(arguments)
+    var id = randomId()
+    requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
+        gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
+    })
+
+    var res = args_to_array(arguments)
+
+    res[0] = id
+
+    return res
 }
@@ -100,13 +100,13 @@ class PydanticModelGenerator:
 StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
     "StableDiffusionProcessingTxt2Img",
     StableDiffusionProcessingTxt2Img,
-    [{"key": "sampler_index", "type": str, "default": "Euler"}]
+    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
 ).generate_model()

 StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
     "StableDiffusionProcessingImg2Img",
     StableDiffusionProcessingImg2Img,
-    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
+    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
 ).generate_model()

 class TextToImageResponse(BaseModel):
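The new script_name and script_args keys let API callers run a script as part of a txt2img or img2img request. A hedged sketch of such a request follows; the endpoint path, default address, and the script title "X/Y plot" are assumptions for illustration and are not shown in this diff, so check the live /docs page of your instance.

# Hedged sketch of an API call using the new script_name/script_args keys (illustration only).
import requests

payload = {
    "prompt": "a photo of a cat",
    "steps": 20,
    "sampler_index": "Euler",
    "script_name": "X/Y plot",   # assumed script title, must match an installed script
    "script_args": [],           # positional arguments expected by that script
}

# assumed default local address of the web UI API
response = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
print(response.json().keys())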
@@ -125,10 +125,10 @@ class ExtrasBaseRequest(BaseModel):
     gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
     codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
     codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
-    upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
+    upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
     upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
     upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
-    upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
+    upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
     upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
     upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
     extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
@@ -157,7 +157,8 @@ class PNGInfoRequest(BaseModel):
     image: str = Field(title="Image", description="The base64 encoded PNG image")

 class PNGInfoResponse(BaseModel):
-    info: str = Field(title="Image info", description="A string with all the info the image had")
+    info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
+    items: dict = Field(title="Items", description="An object containing all the info the image had")

 class ProgressRequest(BaseModel):
     skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
@@ -167,6 +168,7 @@ class ProgressResponse(BaseModel):
     eta_relative: float = Field(title="ETA in secs")
     state: dict = Field(title="State", description="The current state snapshot")
     current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+    textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")

 class InterrogateRequest(BaseModel):
     image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
@@ -175,6 +177,15 @@ class InterrogateRequest(BaseModel):
 class InterrogateResponse(BaseModel):
     caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")

+class TrainResponse(BaseModel):
+    info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
+
+class CreateResponse(BaseModel):
+    info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
+
+class PreprocessResponse(BaseModel):
+    info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
+
 fields = {}
 for key, metadata in opts.data_labels.items():
     value = opts.data.get(key)
@@ -209,13 +220,15 @@ class UpscalerItem(BaseModel):
     model_name: Optional[str] = Field(title="Model Name")
     model_path: Optional[str] = Field(title="Path")
     model_url: Optional[str] = Field(title="URL")
+    scale: Optional[float] = Field(title="Scale")

 class SDModelItem(BaseModel):
     title: str = Field(title="Title")
     model_name: str = Field(title="Model Name")
-    hash: str = Field(title="Hash")
+    hash: Optional[str] = Field(title="Short hash")
+    sha256: Optional[str] = Field(title="sha256 hash")
     filename: str = Field(title="Filename")
-    config: str = Field(title="Config file")
+    config: Optional[str] = Field(title="Config file")

 class HypernetworkItem(BaseModel):
     name: str = Field(title="Name")
@@ -240,3 +253,17 @@ class ArtistItem(BaseModel):
     score: float = Field(title="Score")
     category: str = Field(title="Category")
+
+class EmbeddingItem(BaseModel):
+    step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+    sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+    sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+    shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+    vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
+class EmbeddingsResponse(BaseModel):
+    loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+    skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
+
+class MemoryResponse(BaseModel):
+    ram: dict = Field(title="RAM", description="System memory stats")
+    cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
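EmbeddingsResponse and MemoryResponse back read-only API endpoints. A hedged example of consuming them follows; the paths /sdapi/v1/embeddings and /sdapi/v1/memory are assumptions based on the existing naming scheme and do not appear in this excerpt.

# Hedged sketch: endpoint paths and address are assumed, not taken from this diff.
import requests

base = "http://127.0.0.1:7860"

embeddings = requests.get(f"{base}/sdapi/v1/embeddings").json()
for name, item in embeddings.get("loaded", {}).items():
    print(name, item["vectors"], "vectors, shape", item["shape"])

memory = requests.get(f"{base}/sdapi/v1/memory").json()
print(memory["ram"], memory["cuda"])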
import os.path
import csv
from collections import namedtuple

Artist = namedtuple("Artist", ['name', 'weight', 'category'])


class ArtistsDatabase:
    def __init__(self, filename):
        self.cats = set()
        self.artists = []

        if not os.path.exists(filename):
            return

        with open(filename, "r", newline='', encoding="utf8") as file:
            reader = csv.DictReader(file)

            for row in reader:
                artist = Artist(row["artist"], float(row["score"]), row["category"])
                self.artists.append(artist)
                self.cats.add(artist.category)

    def categories(self):
        return sorted(self.cats)
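ArtistsDatabase expects a CSV with artist, score and category columns. A small sketch of building and loading such a file follows; the file name, the rows, and the modules.artists import path are assumptions for illustration, since the diff view does not show the file name.

# Illustration only: file name, rows and import path are assumed.
import csv
from modules.artists import ArtistsDatabase

with open("artists.csv", "w", newline='', encoding="utf8") as f:
    writer = csv.DictWriter(f, fieldnames=["artist", "score", "category"])
    writer.writeheader()
    writer.writerow({"artist": "Example Painter", "score": "0.7", "category": "digital"})

db = ArtistsDatabase("artists.csv")
print(db.categories())        # ['digital']
print(db.artists[0].weight)   # 0.7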
import html
import sys
import threading
import traceback
import time

from modules import shared, progress

queue_lock = threading.Lock()


def wrap_queued_call(func):
    def f(*args, **kwargs):
        with queue_lock:
            res = func(*args, **kwargs)

        return res

    return f


def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):

        # if the first argument is a string that says "task(...)", it is treated as a job id
        if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
            id_task = args[0]
            progress.add_task_to_queue(id_task)
        else:
            id_task = None

        with queue_lock:
            shared.state.begin()
            progress.start_task(id_task)

            try:
                res = func(*args, **kwargs)
            finally:
                progress.finish_task(id_task)

            shared.state.end()

        return res

    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)


def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
        if run_memmon:
            shared.mem_mon.monitor()
        t = time.perf_counter()

        try:
            res = list(func(*args, **kwargs))
        except Exception as e:
            # When printing out our debug argument list, do not print out more than a MB of text
            max_debug_str_len = 131072 # (1024*1024)/8

            print("Error completing request", file=sys.stderr)
            argStr = f"Arguments: {str(args)} {str(kwargs)}"
            print(argStr[:max_debug_str_len], file=sys.stderr)
            if len(argStr) > max_debug_str_len:
                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)

            print(traceback.format_exc(), file=sys.stderr)

            shared.state.job = ""
            shared.state.job_count = 0

            if extra_outputs_array is None:
                extra_outputs_array = [None, '']

            res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]

        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.job_count = 0

        if not add_stats:
            return tuple(res)

        elapsed = time.perf_counter() - t
        elapsed_m = int(elapsed // 60)
        elapsed_s = elapsed % 60
        elapsed_text = f"{elapsed_s:.2f}s"
        if elapsed_m > 0:
            elapsed_text = f"{elapsed_m}m "+elapsed_text

        if run_memmon:
            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
            active_peak = mem_stats['active_peak']
            reserved_peak = mem_stats['reserved_peak']
            sys_peak = mem_stats['system_peak']
            sys_total = mem_stats['total']
            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)

            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
        else:
            vram_html = ''

        # last item is always HTML
        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"

        return tuple(res)

    return f
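A minimal sketch of how a caller might use wrap_gradio_gpu_call follows. It assumes a running web UI environment (shared.opts and shared.mem_mon initialised); my_generate is a hypothetical job function, the module path is assumed because the diff view does not show the file name, and the "task(...)" id would normally come from the frontend's randomId().

# Hedged usage sketch (my_generate and the task id value are made up for illustration).
from modules.call_queue import wrap_gradio_gpu_call  # assumed module path

def my_generate(id_task, prompt):
    # ... the actual GPU work would happen here ...
    return [f"result for: {prompt}", ""]   # last item is HTML, as wrap_gradio_call expects

wrapped = wrap_gradio_gpu_call(my_generate, extra_outputs=[None, ''])

# Because the first positional argument matches "task(...)", it is registered with
# modules.progress so the UI's progress polling can track this job.
outputs = wrapped("task(1234abcd)", "a photo of a cat")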
@@ -382,7 +382,7 @@ class VQAutoEncoder(nn.Module):
                 self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
                 logger.info(f'vqgan is loaded from: {model_path} [params]')
             else:
-                raise ValueError(f'Wrong params!')
+                raise ValueError('Wrong params!')

     def forward(self, x):
@@ -431,7 +431,7 @@ class VQGANDiscriminator(nn.Module):
             elif 'params' in chkpt:
                 self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
             else:
-                raise ValueError(f'Wrong params!')
+                raise ValueError('Wrong params!')

     def forward(self, x):
         return self.main(x)
\ No newline at end of file
@@ -8,7 +8,7 @@ import torch
 import modules.face_restoration
 import modules.shared
 from modules import shared, devices, modelloader
-from modules.paths import script_path, models_path
+from modules.paths import models_path

 # codeformer people made a choice to include modified basicsr library to their project which makes
 # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
@@ -36,6 +36,7 @@ def setup_model(dirname):
             from basicsr.utils.download_util import load_file_from_url
             from basicsr.utils import imwrite, img2tensor, tensor2img
             from facelib.utils.face_restoration_helper import FaceRestoreHelper
+            from facelib.detection.retinaface import retinaface
             from modules.shared import cmd_opts

             net_class = CodeFormer
@@ -65,6 +66,8 @@ def setup_model(dirname):
             net.load_state_dict(checkpoint)
             net.eval()

+            if hasattr(retinaface, 'device'):
+                retinaface.device = devices.device_codeformer
             face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)

             self.net = net
...
@@ -21,7 +21,7 @@ class DeepDanbooru:
         files = modelloader.load_models(
             model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
             model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
-            ext_filter=".pt",
+            ext_filter=[".pt"],
             download_name='model-resnet_custom_v3.pt',
         )
@@ -58,7 +58,7 @@ class DeepDanbooru:
         a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255

         with torch.no_grad(), devices.autocast():
-            x = torch.from_numpy(a).cuda()
+            x = torch.from_numpy(a).to(devices.device)
             y = self.model(x)[0].detach().cpu().numpy()

         probability_dict = {}
@@ -79,7 +79,9 @@ class DeepDanbooru:
         res = []

-        for tag in tags:
+        filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
+
+        for tag in [x for x in tags if x not in filtertags]:
             probability = probability_dict[tag]
             tag_outformat = tag
             if use_spaces:
...
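The new deepbooru_filter_tags option is read as a comma-separated list; entries are stripped and spaces are replaced with underscores before filtering. A tiny Python sketch of that normalisation follows; the option value and the tag list are made up for illustration.

# Sketch of the comma-separated option normalisation used above (values are made up).
deepbooru_filter_tags = "blurry, text, signature "

filtertags = set(x.strip().replace(' ', '_') for x in deepbooru_filter_tags.split(","))
tags = ["1girl", "text", "blue_sky", "signature"]

kept = [t for t in tags if t not in filtertags]
print(kept)   # ['1girl', 'blue_sky']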
# this file is adapted from https://github.com/victorca25/iNNfer

from collections import OrderedDict
import math
import functools
import torch
...