Commit 87535fcf authored by AUTOMATIC1111, committed by GitHub

Merge branch 'dev' into Branch_AddNewFilenameGen

parents 02e35188 1ffb44b0
@@ -18,7 +18,7 @@ jobs:
           cache-dependency-path: |
             **/requirements*txt
       - name: Run tests
-        run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
+        run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
       - name: Upload main app stdout-stderr
         uses: actions/upload-artifact@v3
         if: always()
......
@@ -32,4 +32,4 @@ notification.mp3
 /extensions
 /test/stdout.txt
 /test/stderr.txt
-/cache.json
+/cache.json*
@@ -13,9 +13,9 @@ A browser interface based on Gradio library for Stable Diffusion.
 - Prompt Matrix
 - Stable Diffusion Upscale
 - Attention, specify parts of text that the model should pay more attention to
-    - a man in a ((tuxedo)) - will pay more attention to tuxedo
-    - a man in a (tuxedo:1.21) - alternative syntax
-    - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
+    - a man in a `((tuxedo))` - will pay more attention to tuxedo
+    - a man in a `(tuxedo:1.21)` - alternative syntax
+    - select text and press `Ctrl+Up` or `Ctrl+Down` to automatically adjust attention to selected text (code contributed by anonymous user)
 - Loopback, run img2img processing multiple times
 - X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
 - Textual Inversion
@@ -28,7 +28,7 @@ A browser interface based on Gradio library for Stable Diffusion.
 - CodeFormer, face restoration tool as an alternative to GFPGAN
 - RealESRGAN, neural network upscaler
 - ESRGAN, neural network upscaler with a lot of third party models
-- SwinIR and Swin2SR([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
+- SwinIR and Swin2SR ([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
 - LDSR, Latent diffusion super resolution upscaling
 - Resizing aspect ratio options
 - Sampling method selection
@@ -46,7 +46,7 @@ A browser interface based on Gradio library for Stable Diffusion.
 - drag and drop an image/text-parameters to promptbox
 - Read Generation Parameters Button, loads parameters in promptbox to UI
 - Settings page
-- Running arbitrary python code from UI (must run with --allow-code to enable)
+- Running arbitrary python code from UI (must run with `--allow-code` to enable)
 - Mouseover hints for most UI elements
 - Possible to change defaults/min/max/step values for UI elements via text config
 - Tiling support, a checkbox to create images that can be tiled like textures
@@ -69,7 +69,7 @@ A browser interface based on Gradio library for Stable Diffusion.
 - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
 - No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
 - DeepDanbooru integration, creates danbooru style tags for anime prompts
-- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
+- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add `--xformers` to commandline args)
 - via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
 - Generate forever option
 - Training tab
@@ -78,11 +78,11 @@ A browser interface based on Gradio library for Stable Diffusion.
 - Clip skip
 - Hypernetworks
 - Loras (same as Hypernetworks but more pretty)
-- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt.
+- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
 - Can select to load a different VAE from settings screen
 - Estimated completion time in progress bar
 - API
-- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
+- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML
 - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
 - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
 - [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
@@ -91,7 +91,6 @@ A browser interface based on Gradio library for Stable Diffusion.
 - Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
 - Now with a license!
 - Reorder elements in the UI from settings screen
--
 
 ## Installation and Running
 Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
@@ -101,7 +100,7 @@ Alternatively, use online services (like Google Colab):
 - [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
 
 ### Automatic Installation on Windows
-1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
+1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH".
 2. Install [git](https://git-scm.com/download/win).
 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
 4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
@@ -121,6 +120,7 @@ sudo pacman -S wget git python3
 bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
 ```
 3. Run `webui.sh`.
+4. Check `webui-user.sh` for options.
 ### Installation on Apple Silicon
 Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
......
@@ -4,8 +4,8 @@ channels:
   - defaults
 dependencies:
   - python=3.10
-  - pip=22.2.2
-  - cudatoolkit=11.3
-  - pytorch=1.12.1
-  - torchvision=0.13.1
-  - numpy=1.23.1
+  - pip=23.0
+  - cudatoolkit=11.8
+  - pytorch=2.0
+  - torchvision=0.15
+  - numpy=1.23
\ No newline at end of file
@@ -8,7 +8,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
     def activate(self, p, params_list):
         additional = shared.opts.sd_lora
 
-        if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+        if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
             p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
......
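The `"" != additional` guard becomes `"None" != additional` because the `sd_lora` dropdown (see the options hunk in `lora_script.py` further down) now uses the literal string "None" as its empty choice. A self-contained sketch of this injection logic, with hypothetical stand-ins for the webui objects (`available_loras`, `params_list`, the option values):

```python
# Sketch only: plain dict/tuples stand in for lora.available_loras and
# extra_networks.ExtraNetworkParams; values mimic shared.opts settings.
available_loras = {"add_detail": object()}
additional = "add_detail"   # shared.opts.sd_lora; "None" now means "disabled"
multiplier = 0.6            # shared.opts.extra_networks_default_multiplier

params_list = []            # networks already requested by the prompt, as (name, mult)
all_prompts = ["a photo of a cat"]

# Append the default LoRA only if it is enabled, actually exists on disk,
# and the prompt does not already reference it.
if additional != "None" and additional in available_loras and len([x for x in params_list if x[0] == additional]) == 0:
    all_prompts = [f"{x}<lora:{additional}:{multiplier}>" for x in all_prompts]
    params_list.append((additional, multiplier))

print(all_prompts)  # ['a photo of a cat<lora:add_detail:0.6>']
```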
This diff is collapsed.
@@ -9,7 +9,11 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared
 
 def unload():
     torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
     torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
+    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
+    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
 
 
 def before_ui():
@@ -20,11 +24,27 @@ def before_ui():
 if not hasattr(torch.nn, 'Linear_forward_before_lora'):
     torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
 
+if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
+    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
+
 if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
     torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
 
+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
+    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+
+if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
+    torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
+
+if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
+    torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
+
 torch.nn.Linear.forward = lora.lora_Linear_forward
+torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
 torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
+torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
+torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
 
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
@@ -32,7 +52,5 @@ script_callbacks.on_before_ui(before_ui)
 
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
-    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
-    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
+    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
 }))
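The `hasattr` guards matter because this script can execute more than once per process (for example on a UI restart): saving `torch.nn.Linear.forward` unconditionally on a second run would capture the already-patched function, and `unload()` could then never restore the original. A minimal sketch of the same save-patch-restore pattern (the `_before_patch` names are hypothetical, not the actual Lora hooks):

```python
import torch

# Remember the unpatched method exactly once, even if this module runs again.
if not hasattr(torch.nn, 'Linear_forward_before_patch'):
    torch.nn.Linear_forward_before_patch = torch.nn.Linear.forward


def patched_linear_forward(self, input):
    # Delegate to the saved original; a real patch would adjust the result here.
    return torch.nn.Linear_forward_before_patch(self, input)


torch.nn.Linear.forward = patched_linear_forward


def unload():
    # Put the original method back, mirroring unload() in the hunk above.
    torch.nn.Linear.forward = torch.nn.Linear_forward_before_patch
```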
@@ -5,11 +5,15 @@ import traceback
 
 import PIL.Image
 import numpy as np
 import torch
+from tqdm import tqdm
 
 from basicsr.utils.download_util import load_file_from_url
 
 import modules.upscaler
 from modules import devices, modelloader
 from scunet_model_arch import SCUNet as net
+from modules.shared import opts
+from modules import images
 
 
 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -42,28 +46,78 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
             scalers.append(scaler_data2)
         self.scalers = scalers
 
-    def do_upscale(self, img: PIL.Image, selected_file):
+    @staticmethod
+    @torch.no_grad()
+    def tiled_inference(img, model):
+        # test the image tile by tile
+        h, w = img.shape[2:]
+        tile = opts.SCUNET_tile
+        tile_overlap = opts.SCUNET_tile_overlap
+        if tile == 0:
+            return model(img)
+
+        device = devices.get_device_for('scunet')
+        assert tile % 8 == 0, "tile size should be a multiple of window_size"
+        sf = 1
+
+        stride = tile - tile_overlap
+        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
+        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
+        E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
+        W = torch.zeros_like(E, dtype=devices.dtype, device=device)
+
+        with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
+            for h_idx in h_idx_list:
+                for w_idx in w_idx_list:
+                    in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
+                    out_patch = model(in_patch)
+                    out_patch_mask = torch.ones_like(out_patch)
+                    E[
+                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+                    ].add_(out_patch)
+                    W[
+                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+                    ].add_(out_patch_mask)
+                    pbar.update(1)
+        output = E.div_(W)
+
+        return output
+
+    def do_upscale(self, img: PIL.Image.Image, selected_file):
+
         torch.cuda.empty_cache()
 
         model = self.load_model(selected_file)
         if model is None:
+            print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
             return img
 
         device = devices.get_device_for('scunet')
-        img = np.array(img)
-        img = img[:, :, ::-1]
-        img = np.moveaxis(img, 2, 0) / 255
-        img = torch.from_numpy(img).float()
-        img = img.unsqueeze(0).to(device)
-
-        with torch.no_grad():
-            output = model(img)
-        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
-        output = 255. * np.moveaxis(output, 0, 2)
-        output = output.astype(np.uint8)
-        output = output[:, :, ::-1]
+        tile = opts.SCUNET_tile
+        h, w = img.height, img.width
+        np_img = np.array(img)
+        np_img = np_img[:, :, ::-1]  # RGB to BGR
+        np_img = np_img.transpose((2, 0, 1)) / 255  # HWC to CHW
+        torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)  # type: ignore
+
+        if tile > h or tile > w:
+            _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
+            _img[:, :, :h, :w] = torch_img  # pad image
+            torch_img = _img
+
+        torch_output = self.tiled_inference(torch_img, model).squeeze(0)
+        torch_output = torch_output[:, :h * 1, :w * 1]  # remove padding, if any
+        np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
+        del torch_img, torch_output
         torch.cuda.empty_cache()
-        return PIL.Image.fromarray(output, 'RGB')
+
+        output = np_output.transpose((1, 2, 0))  # CHW to HWC
+        output = output[:, :, ::-1]  # BGR to RGB
+        return PIL.Image.fromarray((output * 255).astype(np.uint8))
 
     def load_model(self, path: str):
         device = devices.get_device_for('scunet')
@@ -84,4 +138,3 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
 
         model = model.to(device)
         return model
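In `tiled_inference`, `E` accumulates model outputs and `W` accumulates a mask of ones per tile, so `E.div_(W)` averages predictions wherever tiles overlap by `tile_overlap` pixels. A toy numpy version of that accounting, using an identity "model" so the round trip is checkable (sizes here are arbitrary assumptions):

```python
import numpy as np

h = w = 10
tile, stride = 6, 4                              # overlap of 2 pixels
img = np.random.rand(h, w)

E = np.zeros((h, w))                             # accumulated outputs
W = np.zeros((h, w))                             # how many tiles covered each pixel
idx_list = list(range(0, h - tile, stride)) + [h - tile]

for hi in idx_list:
    for wi in idx_list:
        out_patch = img[hi:hi + tile, wi:wi + tile]   # identity stand-in for model(in_patch)
        E[hi:hi + tile, wi:wi + tile] += out_patch
        W[hi:hi + tile, wi:wi + tile] += 1            # the out_patch_mask of ones

output = E / W                                   # mean over overlapping tiles
assert np.allclose(output, img)                  # identity model reproduces the input
```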
 // Stable Diffusion WebUI - Bracket checker
-// Version 1.0
-// By Hingashi no Florin/Bwin4L
+// By Hingashi no Florin/Bwin4L & @akx
 // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
 // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
 
-function checkBrackets(evt, textArea, counterElt) {
-  errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
-  errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
-  errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
-
-  openBracketRegExp = /\(/g;
-  closeBracketRegExp = /\)/g;
-  openSquareBracketRegExp = /\[/g;
-  closeSquareBracketRegExp = /\]/g;
-  openCurlyBracketRegExp = /\{/g;
-  closeCurlyBracketRegExp = /\}/g;
-
-  totalOpenBracketMatches = 0;
-  totalCloseBracketMatches = 0;
-  totalOpenSquareBracketMatches = 0;
-  totalCloseSquareBracketMatches = 0;
-  totalOpenCurlyBracketMatches = 0;
-  totalCloseCurlyBracketMatches = 0;
-
-  openBracketMatches = textArea.value.match(openBracketRegExp);
-  if(openBracketMatches) {
-    totalOpenBracketMatches = openBracketMatches.length;
-  }
-
-  closeBracketMatches = textArea.value.match(closeBracketRegExp);
-  if(closeBracketMatches) {
-    totalCloseBracketMatches = closeBracketMatches.length;
-  }
-
-  openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
-  if(openSquareBracketMatches) {
-    totalOpenSquareBracketMatches = openSquareBracketMatches.length;
-  }
-
-  closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
-  if(closeSquareBracketMatches) {
-    totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
-  }
-
-  openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
-  if(openCurlyBracketMatches) {
-    totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
-  }
-
-  closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
-  if(closeCurlyBracketMatches) {
-    totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
-  }
-
-  if(totalOpenBracketMatches != totalCloseBracketMatches) {
-    if(!counterElt.title.includes(errorStringParen)) {
-      counterElt.title += errorStringParen;
-    }
-  } else {
-    counterElt.title = counterElt.title.replace(errorStringParen, '');
-  }
-
-  if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
-    if(!counterElt.title.includes(errorStringSquare)) {
-      counterElt.title += errorStringSquare;
-    }
-  } else {
-    counterElt.title = counterElt.title.replace(errorStringSquare, '');
-  }
-
-  if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
-    if(!counterElt.title.includes(errorStringCurly)) {
-      counterElt.title += errorStringCurly;
-    }
-  } else {
-    counterElt.title = counterElt.title.replace(errorStringCurly, '');
-  }
-
-  if(counterElt.title != '') {
-    counterElt.classList.add('error');
-  } else {
-    counterElt.classList.remove('error');
-  }
-}
+function checkBrackets(textArea, counterElt) {
+  var counts = {};
+  (textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
+    counts[bracket] = (counts[bracket] || 0) + 1;
+  });
+  var errors = [];
+
+  function checkPair(open, close, kind) {
+    if (counts[open] !== counts[close]) {
+      errors.push(
+        `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
+      );
+    }
+  }
+
+  checkPair('(', ')', 'round brackets');
+  checkPair('[', ']', 'square brackets');
+  checkPair('{', '}', 'curly brackets');
+
+  counterElt.title = errors.join('\n');
+  counterElt.classList.toggle('error', errors.length !== 0);
+}
 
-function setupBracketChecking(id_prompt, id_counter){
+function setupBracketChecking(id_prompt, id_counter) {
   var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
   var counter = gradioApp().getElementById(id_counter)
-  textarea.addEventListener("input", function(evt){
-    checkBrackets(evt, textarea, counter)
-  });
-}
 
-var shadowRootLoaded = setInterval(function() {
-  var shadowRoot = document.querySelector('gradio-app').shadowRoot;
-  if(! shadowRoot) return false;
-
-  var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
-  if(shadowTextArea.length < 1) return false;
-
-  clearInterval(shadowRootLoaded);
-
-  setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
-  setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
-  setupBracketChecking('img2img_prompt', 'imgimg_token_counter')
-  setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
-}, 1000);
+  if (textarea && counter) {
+    textarea.addEventListener("input", () => checkBrackets(textarea, counter));
+  }
+}
+
+onUiLoaded(function () {
+  setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
+  setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
+  setupBracketChecking('img2img_prompt', 'img2img_token_counter');
+  setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
+});
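The rewrite collapses six regexes and twelve counters into a single pass that tallies every bracket character and then compares the three pairs. The same logic in Python rather than the file's JavaScript, purely as an illustration (not code from this commit):

```python
import re

def bracket_errors(prompt: str) -> list:
    # One pass: count every bracket character in the prompt.
    counts = {}
    for bracket in re.findall(r'[(){}\[\]]', prompt):
        counts[bracket] = counts.get(bracket, 0) + 1

    # Compare opening/closing totals per pair, like checkPair above.
    errors = []
    for open_ch, close_ch, kind in (('(', ')', 'round brackets'),
                                    ('[', ']', 'square brackets'),
                                    ('{', '}', 'curly brackets')):
        if counts.get(open_ch, 0) != counts.get(close_ch, 0):
            errors.append(f"{open_ch}...{close_ch} - Detected {counts.get(open_ch, 0)} "
                          f"opening and {counts.get(close_ch, 0)} closing {kind}.")
    return errors

print(bracket_errors("a man in a ((tuxedo)"))  # reports one round-bracket mismatch
```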
-<div class='card' {preview_html} onclick={card_clicked}>
+<div class='card' style={style} onclick={card_clicked}>
     {metadata_button}
     <div class='actions'>
......
@@ -636,3 +636,29 @@ SOFTWARE.
 See the License for the specific language governing permissions and
 limitations under the License.
 </pre>
+
+<h2><a href="https://github.com/explosion/curated-transformers/blob/main/LICENSE">Curated transformers</a></h2>
+<small>The MPS workaround for nn.Linear on macOS 13.2.X is based on the MPS workaround for nn.Linear created by danieldk for Curated transformers</small>
+<pre>
+The MIT License (MIT)
+
+Copyright (C) 2021 ExplosionAI GmbH
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+</pre>
\ No newline at end of file
@@ -12,7 +12,7 @@ function dimensionChange(e, is_width, is_height){
         currentHeight = e.target.value*1.0
     }
 
-    var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
+    var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
 
     if(!inImg2img){
         return;
@@ -22,7 +22,7 @@ function dimensionChange(e, is_width, is_height){
     var tabIndex = get_tab_index('mode_img2img')
     if(tabIndex == 0){ // img2img
-        targetElement = gradioApp().querySelector('div[data-testid=image] img');
+        targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
     } else if(tabIndex == 1){ //Sketch
         targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
     } else if(tabIndex == 2){ // Inpaint
@@ -38,7 +38,7 @@ function dimensionChange(e, is_width, is_height){
     if(!arPreviewRect){
         arPreviewRect = document.createElement('div')
         arPreviewRect.id = "imageARPreview";
-        gradioApp().getRootNode().appendChild(arPreviewRect)
+        gradioApp().appendChild(arPreviewRect)
     }
@@ -91,7 +91,9 @@ onUiUpdate(function(){
     if(arPreviewRect){
         arPreviewRect.style.display = 'none';
     }
-    var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
+    var tabImg2img = gradioApp().querySelector("#tab_img2img");
+    if (tabImg2img) {
+    var inImg2img = tabImg2img.style.display == "block";
     if(inImg2img){
         let inputs = gradioApp().querySelectorAll('input');
         inputs.forEach(function(e){
@@ -110,4 +112,5 @@ onUiUpdate(function(){
             }
         })
     }
+    }
 });
@@ -43,7 +43,7 @@ contextMenuInit = function(){
         })
 
-        gradioApp().getRootNode().appendChild(contextMenu)
+        gradioApp().appendChild(contextMenu)
 
         let menuWidth = contextMenu.offsetWidth + 4;
         let menuHeight = contextMenu.offsetHeight + 4;
@@ -161,14 +161,6 @@ addContextMenuEventListener = initResponse[2];
   appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
   appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
 
-  appendContextMenuOption('#roll','Roll three',
-    function(){
-      let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
-      setTimeout(function(){rollbutton.click()},100)
-      setTimeout(function(){rollbutton.click()},200)
-      setTimeout(function(){rollbutton.click()},300)
-    }
-  )
 })();
 //End example Context Menu Items
......
 function keyupEditAttention(event){
     let target = event.originalTarget || event.composedPath()[0];
-    if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
+    if (! target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
     if (! (event.metaKey || event.ctrlKey)) return;
 
     let isPlus = event.key == "ArrowUp"
@@ -44,9 +44,27 @@ function keyupEditAttention(event){
         return true;
     }
 
-    // If the user hasn't selected anything, let's select their current parenthesis block
-    if(! selectCurrentParenthesisBlock('<', '>')){
-        selectCurrentParenthesisBlock('(', ')')
+    function selectCurrentWord(){
+        if (selectionStart !== selectionEnd) return false;
+        const delimiters = opts.keyedit_delimiters + " \r\n\t";
+
+        // seek backward to find the beginning
+        while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
+            selectionStart--;
+        }
+
+        // seek forward to find the end
+        while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
+            selectionEnd++;
+        }
+
+        target.setSelectionRange(selectionStart, selectionEnd);
+        return true;
+    }
+
+    // If the user hasn't selected anything, let's select their current parenthesis block or word
+    if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
+        selectCurrentWord();
     }
 
     event.preventDefault();
@@ -81,7 +99,13 @@ function keyupEditAttention(event){
     weight = parseFloat(weight.toPrecision(12));
     if(String(weight).length == 1) weight += ".0"
 
-    text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+    if (closeCharacter == ')' && weight == 1) {
+        text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
+        selectionStart--;
+        selectionEnd--;
+    } else {
+        text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+    }
 
     target.focus();
     target.value = text;
......
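`selectCurrentWord` grows an empty selection outward from the caret until each side hits a delimiter, which is what lets Ctrl+Up/Down re-weight a bare word without selecting it first. The same two-pointer scan in Python, as a sketch (the delimiter string merely approximates `opts.keyedit_delimiters` plus whitespace):

```python
def select_current_word(text, pos, delimiters=".,\\/!?%^*;:{}=`~() \r\n\t"):
    start = end = pos
    # seek backward to find the beginning of the word
    while start > 0 and text[start - 1] not in delimiters:
        start -= 1
    # seek forward to find the end
    while end < len(text) and text[end] not in delimiters:
        end += 1
    return start, end, text[start:end]

print(select_current_word("masterpiece, red tuxedo, 4k", 20))  # (17, 23, 'tuxedo')
```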
-function extensions_apply(_, _){
+function extensions_apply(_, _, disable_all){
     var disable = []
     var update = []
 
@@ -13,10 +13,10 @@ function extensions_apply(_, _){
     restart_reload()
 
-    return [JSON.stringify(disable), JSON.stringify(update)]
+    return [JSON.stringify(disable), JSON.stringify(update), disable_all]
 }
 
-function extensions_check(){
+function extensions_check(_, _){
     var disable = []
 
     gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
......
@@ -139,3 +139,41 @@ function extraNetworksShowMetadata(text){
 
     popup(elem);
 }
+
+function requestGet(url, data, handler, errorHandler){
+    var xhr = new XMLHttpRequest();
+    var args = Object.keys(data).map(function(k){ return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]) }).join('&')
+    xhr.open("GET", url + "?" + args, true);
+
+    xhr.onreadystatechange = function () {
+        if (xhr.readyState === 4) {
+            if (xhr.status === 200) {
+                try {
+                    var js = JSON.parse(xhr.responseText);
+                    handler(js)
+                } catch (error) {
+                    console.error(error);
+                    errorHandler()
+                }
+            } else{
+                errorHandler()
+            }
+        }
+    };
+    var js = JSON.stringify(data);
+    xhr.send(js);
+}
+
+function extraNetworksRequestMetadata(event, extraPage, cardName){
+    showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
+
+    requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
+        if(data && data.metadata){
+            extraNetworksShowMetadata(data.metadata)
+        } else{
+            showError()
+        }
+    }, showError)
+
+    event.stopPropagation()
+}
@@ -16,7 +16,7 @@ onUiUpdate(function(){
 
     let modalObserver = new MutationObserver(function(mutations) {
         mutations.forEach(function(mutationRecord) {
-            let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
+            let selectedTab = gradioApp().querySelector('#tabs div button')?.innerText
             if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
                 gradioApp().getElementById(selectedTab+"_generation_info_button").click()
         });
......
@@ -18,11 +18,10 @@ titles = {
     "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
     "\u{1f4c2}": "Open images output directory",
     "\u{1f4be}": "Save style",
-    "\u{1f5d1}": "Clear prompt",
+    "\u{1f5d1}\ufe0f": "Clear prompt",
     "\u{1f4cb}": "Apply selected styles to current prompt",
     "\u{1f4d2}": "Paste available values into the field",
-    "\u{1f3b4}": "Show extra networks",
+    "\u{1f3b4}": "Show/hide extra networks",
 
     "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
     "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
@@ -40,7 +39,6 @@ titles = {
     "Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
 
     "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
-    "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",
 
     "Skip": "Stop processing current image and continue processing.",
     "Interrupt": "Stop processing images and return any results accumulated so far.",
@@ -71,8 +69,10 @@ titles = {
     "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
     "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
 
-    "Loopback": "Process an image, use it as an input, repeat.",
-    "Loops": "How many times to repeat processing an image and using it as input for the next iteration",
+    "Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
+    "Loops": "How many times to process an image. Each output is used as the input of the next loop. If set to 1, behavior will be as if this script were not used.",
+    "Final denoising strength": "The denoising strength for the final loop of each image in the batch.",
+    "Denoising strength curve": "The denoising curve controls the rate of denoising strength change each loop. Aggressive: Most of the change will happen towards the start of the loops. Linear: Change will be constant through all loops. Lazy: Most of the change will happen towards the end of the loops.",
 
     "Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
     "Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
......
@@ -32,13 +32,7 @@ function negmod(n, m) {
 function updateOnBackgroundChange() {
     const modalImage = gradioApp().getElementById("modalImage")
     if (modalImage && modalImage.offsetParent) {
-        let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
-        let currentButton = null
-        allcurrentButtons.forEach(function(elem) {
-            if (elem.parentElement.offsetParent) {
-                currentButton = elem;
-            }
-        })
+        let currentButton = selected_gallery_button();
 
         if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
             modalImage.src = currentButton.children[0].src;
@@ -50,22 +44,10 @@ function updateOnBackgroundChange() {
 }
 
 function modalImageSwitch(offset) {
-    var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
-    var galleryButtons = []
-    allgalleryButtons.forEach(function(elem) {
-        if (elem.parentElement.offsetParent) {
-            galleryButtons.push(elem);
-        }
-    })
+    var galleryButtons = all_gallery_buttons();
 
     if (galleryButtons.length > 1) {
-        var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
-        var currentButton = null
-        allcurrentButtons.forEach(function(elem) {
-            if (elem.parentElement.offsetParent) {
-                currentButton = elem;
-            }
-        })
+        var currentButton = selected_gallery_button();
 
         var result = -1
         galleryButtons.forEach(function(v, i) {
@@ -136,20 +118,15 @@ function modalKeyHandler(event) {
     }
 }
 
-function showGalleryImage() {
-    setTimeout(function() {
-        fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
-
-        if (fullImg_preview != null) {
-            fullImg_preview.forEach(function function_name(e) {
+function setupImageForLightbox(e) {
     if (e.dataset.modded)
         return;
+
     e.dataset.modded = true;
-                if(e && e.parentElement.tagName == 'DIV'){
     e.style.cursor='pointer'
     e.style.userSelect='none'
 
-                    var isFirefox = isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
+    var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
 
     // For Firefox, listening on click first switched to next image then shows the lightbox.
     // If you know how to fix this without switching to mousedown event, please.
@@ -158,15 +135,12 @@ function showGalleryImage() {
     e.addEventListener(event, function (evt) {
         if(!opts.js_modal_lightbox || evt.button != 0) return;
 
         modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
         evt.preventDefault()
         showModal(evt)
     }, true);
-                }
-            });
-        }
-    }, 100);
 }
 
 function modalZoomSet(modalImage, enable) {
@@ -199,21 +173,21 @@ function modalTileImageToggle(event) {
 }
 
 function galleryImageHandler(e) {
-    if (e && e.parentElement.tagName == 'BUTTON') {
+    //if (e && e.parentElement.tagName == 'BUTTON') {
         e.onclick = showGalleryImage;
-    }
+    //}
 }
 
 onUiUpdate(function() {
-    fullImg_preview = gradioApp().querySelectorAll('img.w-full')
+    fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
     if (fullImg_preview != null) {
-        fullImg_preview.forEach(galleryImageHandler);
+        fullImg_preview.forEach(setupImageForLightbox);
     }
     updateOnBackgroundChange();
 })
 
 document.addEventListener("DOMContentLoaded", function() {
-    const modalFragment = document.createDocumentFragment();
+    //const modalFragment = document.createDocumentFragment();
     const modal = document.createElement('div')
     modal.onclick = closeModal;
     modal.id = "lightboxModal";
@@ -277,9 +251,12 @@ document.addEventListener("DOMContentLoaded", function() {
     modal.appendChild(modalNext)
 
+    try {
+        gradioApp().appendChild(modal);
+    } catch (e) {
+        gradioApp().body.appendChild(modal);
+    }
 
-    gradioApp().getRootNode().appendChild(modal)
+    document.body.appendChild(modal);
 
-    document.body.appendChild(modalFragment);
 });
@@ -15,7 +15,7 @@ onUiUpdate(function(){
         }
     }
 
-    const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] img.h-full.w-full.overflow-hidden');
+    const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img');
 
     if (galleryPreviews == null) return;
......
 // code related to showing and updating progressbar shown as the image is being made
 
-galleries = {}
-storedGallerySelections = {}
-galleryObservers = {}
-
 function rememberGallerySelection(id_gallery){
-    storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery)
-}
-
-function getGallerySelectedIndex(id_gallery){
-    let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
-    let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
-    let currentlySelectedIndex = -1
-    galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } })
 
-    return currentlySelectedIndex
-}
-
-// this is a workaround for https://github.com/gradio-app/gradio/issues/2984
-function check_gallery(id_gallery){
-    let gallery = gradioApp().getElementById(id_gallery)
-    // if gallery has no change, no need to setting up observer again.
-    if (gallery && galleries[id_gallery] !== gallery){
-        galleries[id_gallery] = gallery;
-        if(galleryObservers[id_gallery]){
-            galleryObservers[id_gallery].disconnect();
-        }
-
-        storedGallerySelections[id_gallery] = -1
-
-        galleryObservers[id_gallery] = new MutationObserver(function (){
-            let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
-            let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
-            let currentlySelectedIndex = getGallerySelectedIndex(id_gallery)
-            prevSelectedIndex = storedGallerySelections[id_gallery]
-            storedGallerySelections[id_gallery] = -1
-
-            if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
-                // automatically re-open previously selected index (if exists)
-                activeElement = gradioApp().activeElement;
-                let scrollX = window.scrollX;
-                let scrollY = window.scrollY;
-
-                galleryButtons[prevSelectedIndex].click();
-                showGalleryImage();
-
-                // When the gallery button is clicked, it gains focus and scrolls itself into view
-                // We need to scroll back to the previous position
-                setTimeout(function (){
-                    window.scrollTo(scrollX, scrollY);
-                }, 50);
-
-                if(activeElement){
-                    // i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
-                    // if someone has a better solution please by all means
-                    setTimeout(function (){
-                        activeElement.focus({
-                            preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it
-                        })
-                    }, 1);
-                }
-            }
-        })
-
-        galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
-    }
 }
-
-onUiUpdate(function(){
-    check_gallery('txt2img_gallery')
-    check_gallery('img2img_gallery')
-})
 
 function request(url, data, handler, errorHandler){
     var xhr = new XMLHttpRequest();
     var url = url;
@@ -203,7 +138,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
             return
         }
 
-        if(elapsedFromStart > 5 && !res.queued && !res.active){
+        if(elapsedFromStart > 40 && !res.queued && !res.active){
             removeProgressBar()
             return
         }
......
@@ -7,9 +7,31 @@ function set_theme(theme){
     }
 }
 
+function all_gallery_buttons() {
+    var allGalleryButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnails > .thumbnail-item.thumbnail-small');
+    var visibleGalleryButtons = [];
+    allGalleryButtons.forEach(function(elem) {
+        if (elem.parentElement.offsetParent) {
+            visibleGalleryButtons.push(elem);
+        }
+    })
+    return visibleGalleryButtons;
+}
+
+function selected_gallery_button() {
+    var allCurrentButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnail-item.thumbnail-small.selected');
+    var visibleCurrentButton = null;
+    allCurrentButtons.forEach(function(elem) {
+        if (elem.parentElement.offsetParent) {
+            visibleCurrentButton = elem;
+        }
+    })
+    return visibleCurrentButton;
+}
+
 function selected_gallery_index(){
-    var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item')
-    var button = gradioApp().querySelector('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item.\\!ring-2')
+    var buttons = all_gallery_buttons();
+    var button = selected_gallery_button();
 
     var result = -1
     buttons.forEach(function(v, i){ if(v==button) { result = i } })
@@ -18,14 +40,18 @@ function selected_gallery_index(){
 }
 
 function extract_image_from_gallery(gallery){
-    if(gallery.length == 1){
-        return [gallery[0]]
+    if (gallery.length == 0){
+        return [null];
+    }
+    if (gallery.length == 1){
+        return [gallery[0]];
     }
 
     index = selected_gallery_index()
 
     if (index < 0 || index >= gallery.length){
-        return [null]
+        // Use the first image in the gallery as the default
+        index = 0;
     }
 
     return [gallery[index]];
@@ -86,7 +112,7 @@ function get_tab_index(tabId){
     var res = 0
 
     gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){
-        if(button.className.indexOf('bg-white') != -1)
+        if(button.className.indexOf('selected') != -1)
             res = i
     })
 
@@ -255,7 +281,6 @@ onUiUpdate(function(){
         }
 
         prompt.parentElement.insertBefore(counter, prompt)
-        counter.classList.add("token-counter")
         prompt.parentElement.style.position = "relative"
 
         promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
......
@@ -5,24 +5,25 @@ import sys
 import importlib.util
 import shlex
 import platform
-import argparse
 import json

-parser = argparse.ArgumentParser(add_help=False)
-parser.add_argument("--ui-settings-file", type=str, default='config.json')
-parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.realpath(__file__)))
-args, _ = parser.parse_known_args(sys.argv)
-
-script_path = os.path.dirname(__file__)
-data_path = os.getcwd()
+from modules import cmd_args
+from modules.paths_internal import script_path, extensions_dir

-dir_repos = "repositories"
-dir_extensions = "extensions"
+commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
+sys.argv += shlex.split(commandline_args)
+args, _ = cmd_args.parser.parse_known_args()

 python = sys.executable
 git = os.environ.get('GIT', "git")
 index_url = os.environ.get('INDEX_URL', "")
 stored_commit_hash = None
 skip_install = False
+dir_repos = "repositories"
+
+if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
+    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'


 def check_python_version():
@@ -70,23 +71,6 @@ def commit_hash():
     return stored_commit_hash


-def extract_arg(args, name):
-    return [x for x in args if x != name], name in args
-
-
-def extract_opt(args, name):
-    opt = None
-    is_present = False
-    if name in args:
-        is_present = True
-        idx = args.index(name)
-        del args[idx]
-        if idx < len(args) and args[idx][0] != "-":
-            opt = args[idx]
-            del args[idx]
-    return args, is_present, opt
-
-
 def run(command, desc=None, errdesc=None, custom_env=None, live=False):
     if desc is not None:
         print(desc)
@@ -137,12 +121,12 @@ def run_python(code, desc=None, errdesc=None):
     return run(f'"{python}" -c "{code}"', desc, errdesc)


-def run_pip(args, desc=None):
+def run_pip(args, desc=None, live=False):
     if skip_install:
         return

     index_url_line = f' --index-url {index_url}' if index_url != '' else ''
-    return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
+    return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)


 def check_run_python(code):
@@ -222,26 +206,29 @@ def list_extensions(settings_file):
         print(e, file=sys.stderr)

     disabled_extensions = set(settings.get('disabled_extensions', []))
+    disable_all_extensions = settings.get('disable_all_extensions', 'none')
+
+    if disable_all_extensions != 'none':
+        return []

-    return [x for x in os.listdir(os.path.join(data_path, dir_extensions)) if x not in disabled_extensions]
+    return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]


 def run_extensions_installers(settings_file):
-    if not os.path.isdir(dir_extensions):
+    if not os.path.isdir(extensions_dir):
         return

     for dirname_extension in list_extensions(settings_file):
-        run_extension_installer(os.path.join(dir_extensions, dirname_extension))
+        run_extension_installer(os.path.join(extensions_dir, dirname_extension))


 def prepare_environment():
     global skip_install

-    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
+    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --index-url https://download.pytorch.org/whl/cu118")
     requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
-    commandline_args = os.environ.get('COMMANDLINE_ARGS', "")

-    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
+    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
     gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
     clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
     openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
@@ -252,27 +239,13 @@ def prepare_environment():
     codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

-    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
+    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

-    sys.argv += shlex.split(commandline_args)
-
-    sys.argv, _ = extract_arg(sys.argv, '-f')
-    sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
-    sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
-    sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
-    sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
-    sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
-    sys.argv, update_check = extract_arg(sys.argv, '--update-check')
-    sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
-    sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
-    xformers = '--xformers' in sys.argv
-    ngrok = '--ngrok' in sys.argv
-
-    if not skip_python_version_check:
+    if not args.skip_python_version_check:
         check_python_version()

     commit = commit_hash()
@@ -280,10 +253,10 @@ def prepare_environment():
     print(f"Python {sys.version}")
     print(f"Commit hash: {commit}")

-    if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
+    if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
         run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)

-    if not skip_torch_cuda_test:
+    if not args.skip_torch_cuda_test:
         run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")

     if not is_installed("gfpgan"):
@@ -295,10 +268,10 @@ def prepare_environment():
     if not is_installed("open_clip"):
         run_pip(f"install {openclip_package}", "open_clip")

-    if (not is_installed("xformers") or reinstall_xformers) and xformers:
+    if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
         if platform.system() == "Windows":
             if platform.python_version().startswith("3.10"):
-                run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
+                run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
             else:
                 print("Installation of xformers is not supported in this version of Python.")
                 print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
@@ -307,7 +280,7 @@ def prepare_environment():
         elif platform.system() == "Linux":
             run_pip(f"install {xformers_package}", "xformers")

-    if not is_installed("pyngrok") and ngrok:
+    if not is_installed("pyngrok") and args.ngrok:
         run_pip("install pyngrok", "ngrok")

     os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
@@ -323,22 +296,22 @@ def prepare_environment():
     if not os.path.isfile(requirements_file):
         requirements_file = os.path.join(script_path, requirements_file)

-    run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")
+    run_pip(f"install -r \"{requirements_file}\"", "requirements")

     run_extensions_installers(settings_file=args.ui_settings_file)

-    if update_check:
+    if args.update_check:
         version_check(commit)

-    if update_all_extensions:
-        git_pull_recursive(os.path.join(data_path, dir_extensions))
+    if args.update_all_extensions:
+        git_pull_recursive(extensions_dir)

     if "--exit" in sys.argv:
         print("Exiting because of --exit argument")
         exit(0)

-    if run_tests:
-        exitcode = tests(test_dir)
+    if args.tests and not args.no_tests:
+        exitcode = tests(args.tests)
         exit(exitcode)

@@ -352,6 +325,8 @@ def tests(test_dir):
         sys.argv.append("--skip-torch-cuda-test")
     if "--disable-nan-check" not in sys.argv:
         sys.argv.append("--disable-nan-check")
+    if "--no-tests" not in sys.argv:
+        sys.argv.append("--no-tests")

     print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
......
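
The launch.py rewrite above replaces the hand-rolled extract_arg/extract_opt sys.argv munging with a shared argparse parser (modules.cmd_args) read through parse_known_args, which simply ignores flags it does not recognize. A minimal sketch of the pattern, with illustrative flag names rather than the webui's actual cmd_args:

    import argparse
    import sys

    # A parser in the style of modules/cmd_args.py; these flags are examples.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--skip-torch-cuda-test", action="store_true")
    parser.add_argument("--tests", type=str, default=None)

    # parse_known_args() returns (namespace, leftovers) instead of erroring
    # out on unrecognized flags, so launcher and app can share one sys.argv.
    args, unknown = parser.parse_known_args(sys.argv[1:])
    print(args.skip_torch_cuda_test, args.tests, unknown)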
@@ -92,14 +92,18 @@ def cond_cast_float(input):


 def randn(seed, shape):
+    from modules.shared import opts
+
     torch.manual_seed(seed)
-    if device.type == 'mps':
+    if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)


 def randn_without_seed(shape):
-    if device.type == 'mps':
+    from modules.shared import opts
+
+    if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)
......
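
Routing randn through the CPU generator (opts.randn_source == "CPU") trades a copy for reproducibility: torch's CPU RNG produces the same sequence for a given seed on any machine, so seeds become portable across GPU backends. The technique in isolation:

    import torch

    def seeded_noise(seed: int, shape, device: torch.device) -> torch.Tensor:
        # Seed and sample on the CPU, then transfer: the same seed now
        # yields identical noise regardless of the GPU (or MPS) backend.
        torch.manual_seed(seed)
        return torch.randn(shape, device=torch.device("cpu")).to(device)

    x = seeded_noise(42, (1, 4, 64, 64), torch.device("cpu"))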
@@ -5,16 +5,21 @@ import traceback
 import time

 import git

-from modules import paths, shared
+from modules import shared
+from modules.paths_internal import extensions_dir, extensions_builtin_dir

 extensions = []
-extensions_dir = os.path.join(paths.data_path, "extensions")
-extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")

 if not os.path.exists(extensions_dir):
     os.makedirs(extensions_dir)


 def active():
-    return [x for x in extensions if x.enabled]
+    if shared.opts.disable_all_extensions == "all":
+        return []
+    elif shared.opts.disable_all_extensions == "extra":
+        return [x for x in extensions if x.enabled and x.is_builtin]
+    else:
+        return [x for x in extensions if x.enabled]
@@ -27,21 +32,29 @@ class Extension:
         self.can_update = False
         self.is_builtin = is_builtin
         self.version = ''
+        self.remote = None
+        self.have_info_from_repo = False
+
+    def read_info_from_repo(self):
+        if self.have_info_from_repo:
+            return
+
+        self.have_info_from_repo = True

         repo = None
         try:
-            if os.path.exists(os.path.join(path, ".git")):
-                repo = git.Repo(path)
+            if os.path.exists(os.path.join(self.path, ".git")):
+                repo = git.Repo(self.path)
         except Exception:
-            print(f"Error reading github repository info from {path}:", file=sys.stderr)
+            print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)

         if repo is None or repo.bare:
             self.remote = None
         else:
             try:
+                self.remote = next(repo.remote().urls, None)
                 self.status = 'unknown'
-                self.remote = next(repo.remote().urls, None)
                 head = repo.head.commit
                 ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
                 self.version = f'{head.hexsha[:8]} ({ts})'
@@ -89,7 +102,12 @@ def list_extensions():
     if not os.path.isdir(extensions_dir):
         return

-    paths = []
+    if shared.opts.disable_all_extensions == "all":
+        print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
+    elif shared.opts.disable_all_extensions == "extra":
+        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
+
+    extension_paths = []
     for dirname in [extensions_dir, extensions_builtin_dir]:
         if not os.path.isdir(dirname):
             return
@@ -99,9 +117,8 @@ def list_extensions():
             if not os.path.isdir(path):
                 continue

-            paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
+            extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))

-    for dirname, path, is_builtin in paths:
+    for dirname, path, is_builtin in extension_paths:
         extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
         extensions.append(extension)
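
read_info_from_repo above makes the git queries lazy and memoized: the have_info_from_repo flag is set before the expensive work, so a slow or failing lookup runs at most once per extension instead of on every listing. The shape of that pattern, independent of GitPython:

    class LazyRepoInfo:
        def __init__(self, path):
            self.path = path
            self.remote = None
            self.have_info = False  # memoization flag

        def read_info(self):
            if self.have_info:
                return
            self.have_info = True  # set first, so failures are not retried
            # The expensive lookup (e.g. opening a git repo) would go here;
            # this placeholder value stands in for the discovered remote URL.
            self.remote = f"remote-of-{self.path}"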
@@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
     def activate(self, p, params_list):
         additional = shared.opts.sd_hypernetwork

-        if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+        if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
             p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
......
@@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model

     restore_old_hires_fix_params(res)

+    # Missing RNG means the default was set, which is GPU RNG
+    if "RNG" not in res:
+        res["RNG"] = "GPU"
+
     return res

@@ -304,6 +308,7 @@ infotext_to_setting_name_mapping = [
     ('UniPC skip type', 'uni_pc_skip_type'),
     ('UniPC order', 'uni_pc_order'),
     ('UniPC lower order final', 'uni_pc_lower_order_final'),
+    ('RNG', 'randn_source'),
 ]

@@ -401,9 +406,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
     button.click(
         fn=paste_func,
-        _js=f"recalculate_prompts_{tabname}",
         inputs=[input_comp],
         outputs=[x[0] for x in paste_fields],
     )
+    button.click(
+        fn=None,
+        _js=f"recalculate_prompts_{tabname}",
+        inputs=[],
+        outputs=[],
+    )
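
Splitting the paste button into two click handlers separates the Python-side paste_func from the client-side recalculate_prompts_* call: with fn=None, Gradio runs only the _js snippet. A sketch of that idiom under Gradio 3.x (component names here are illustrative):

    import gradio as gr

    with gr.Blocks() as demo:
        box = gr.Textbox()
        btn = gr.Button("Paste")
        # Python-side handler.
        btn.click(fn=lambda s: s.upper(), inputs=[box], outputs=[box])
        # Client-side-only handler: no Python function, just JavaScript.
        btn.click(fn=None, _js="() => console.log('recalculate prompts')",
                  inputs=[], outputs=[])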
@@ -312,7 +312,7 @@ class Hypernetwork:

 def list_hypernetworks(path):
     res = {}
-    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
+    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True), key=str.lower):
         name = os.path.splitext(os.path.basename(filename))[0]
         # Prevent a hypothetical "None.pt" from being listed.
         if name != "None":
......
@@ -261,9 +261,12 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
             if scale > 1.0:
                 upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
-                assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
+                if len(upscalers) == 0:
+                    upscaler = shared.sd_upscalers[0]
+                    print(f"could not find upscaler named {upscaler_name or '<empty string>'}, using {upscaler.name} as a fallback")
+                else:
                     upscaler = upscalers[0]

                 im = upscaler.scaler.upscale(im, scale, upscaler.data_path)

             if im.width != w or im.height != h:
@@ -350,6 +353,7 @@ class FilenameGenerator:
         'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
         'prompt_words': lambda self: self.prompt_words(),
         'hasprompt': lambda self, *args: self.hasprompt(*args), #accept formats:[hasprompt<prompt1|default><prompt2>..]
+        'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
     }
     default_time_format = '%Y%m%d%H%M%S'
@@ -662,6 +666,8 @@ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}

 def image_data(data):
+    import gradio as gr
+
     try:
         image = Image.open(io.BytesIO(data))
         textinfo, _ = read_info_from_image(image)
@@ -677,7 +683,7 @@ def image_data(data):
     except Exception:
         pass

-    return '', None
+    return gr.update(), None


 def flatten(img, bgcolor):
......
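
Returning gr.update() instead of '' matters on the failure path of image_data: an empty string would overwrite whatever is in the prompt box, while a bare update object tells Gradio to leave the component as it is. A small sketch (the parsing body is a stand-in):

    import gradio as gr

    def load_params(data: str):
        try:
            if not data.strip():          # stand-in for the real parser
                raise ValueError("nothing to parse")
            return data.strip()
        except Exception:
            # No-op update: Gradio keeps the textbox's current value,
            # instead of clearing it the way returning '' would.
            return gr.update()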
@@ -151,12 +151,13 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         override_settings=override_settings,
     )

-    p.scripts = modules.scripts.scripts_txt2img
+    p.scripts = modules.scripts.scripts_img2img
     p.script_args = args

     if shared.cmd_opts.enable_console_prompts:
         print(f"\nimg2img: {prompt}", file=shared.progress_print_out)

+    if mask:
         p.extra_generation_params["Mask blur"] = mask_blur

     if is_batch:
......
@@ -32,7 +32,7 @@ def download_default_clip_interrogate_categories(content_dir):
     category_types = ["artists", "flavors", "mediums", "movements"]

     try:
-        os.makedirs(tmpdir)
+        os.makedirs(tmpdir, exist_ok=True)
         for category_type in category_types:
             torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
         os.rename(tmpdir, content_dir)
@@ -41,7 +41,7 @@ def download_default_clip_interrogate_categories(content_dir):
         errors.display(e, "downloading default CLIP interrogate categories")
     finally:
         if os.path.exists(tmpdir):
-            os.remove(tmpdir)
+            os.removedirs(tmpdir)


 class InterrogateModels:
......
@@ -55,12 +55,12 @@ def setup_for_low_vram(sd_model, use_medvram):
     if hasattr(sd_model.cond_stage_model, 'model'):
         sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model

-    # remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
+    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
     # send the model to GPU. Then put modules back. the modules will be in CPU.
-    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
+    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
+    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
     sd_model.to(devices.device)
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
+    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored

     # register hooks for those the first three models
     sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
@@ -69,6 +69,8 @@ def setup_for_low_vram(sd_model, use_medvram):
         sd_model.first_stage_model.decode = first_stage_model_decode_wrap
     if sd_model.depth_model:
         sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
+    if sd_model.embedder:
+        sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
     parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

     if hasattr(sd_model.cond_stage_model, 'model'):
......
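
The lowvram scheme keeps the big submodules on the CPU and uses forward pre-hooks to pull the one about to run onto the GPU; registering the new embedder module is just one more hook. A self-contained sketch of the mechanism (the real send_me_to_gpu also evicts the previously loaded module):

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def send_me_to_gpu(module, _inputs):
        # Runs just before module.forward: move the weights in on demand.
        module.to(device)

    block = nn.Linear(8, 8)      # stand-in for a large submodule
    block.to("cpu")
    block.register_forward_pre_hook(send_me_to_gpu)
    out = block(torch.randn(1, 8).to(device))  # hook relocates the module first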
 import torch
+import platform
 from modules import paths
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
@@ -32,6 +33,10 @@ if has_mps:
         # MPS fix for randn in torchsde
         CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')

+        if platform.mac_ver()[0].startswith("13.2."):
+            # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
+            CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
+
     if version.parse(torch.__version__) < version.parse("1.13"):
         # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
@@ -49,4 +54,6 @@ if has_mps:
         CondFunc('torch.cumsum', cumsum_fix_func, None)
         CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
         CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
+
+    if version.parse(torch.__version__) == version.parse("2.0"):
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
+        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
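
Both MPS workarounds above lean on CondFunc from modules.sd_hijack_utils: patch a callable, addressed by dotted path, with a wrapper that receives the original and only fires when a condition holds. A simplified re-implementation showing the shape of the mechanism (not the webui's exact code; this version handles module-level attributes only):

    import importlib

    def cond_func(path, func, cond):
        """Replace the attribute at `path` with a conditional wrapper."""
        module_path, attr = path.rsplit('.', 1)
        module = importlib.import_module(module_path)
        orig = getattr(module, attr)

        def wrapper(*args, **kwargs):
            if cond is None or cond(orig, *args, **kwargs):
                return func(orig, *args, **kwargs)   # patched branch
            return orig(*args, **kwargs)             # untouched branch

        setattr(module, attr, wrapper)

    # e.g. route int arguments to math.pow through float:
    # cond_func('math.pow',
    #           lambda orig, x, y: orig(float(x), float(y)),
    #           lambda orig, x, y: isinstance(x, int))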
@@ -4,7 +4,6 @@ import shutil
 import importlib
 from urllib.parse import urlparse

-from basicsr.utils.download_util import load_file_from_url
 from modules import shared
 from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
 from modules.paths import script_path, models_path
@@ -59,6 +58,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
         if model_url is not None and len(output) == 0:
             if download_name is not None:
+                from basicsr.utils.download_util import load_file_from_url
                 dl = load_file_from_url(model_url, model_path, True, download_name)
                 output.append(dl)
             else:
......
-import argparse
 import os
 import sys

+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
+
 import modules.safe

-script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-
-# Parse the --data-dir flag first so we can use it as a base for our other argument default values
-parser = argparse.ArgumentParser(add_help=False)
-parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
-cmd_opts_pre = parser.parse_known_args()[0]
-data_path = cmd_opts_pre.data_dir
-models_path = os.path.join(data_path, "models")
+# data_path = cmd_opts_pre.data

 sys.path.insert(0, script_path)
......
"""this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
import argparse
import os
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser_pre = argparse.ArgumentParser(add_help=False)
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
cmd_opts_pre = parser_pre.parse_known_args()[0]
data_path = cmd_opts_pre.data_dir
models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
@@ -18,9 +18,15 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,

     if extras_mode == 1:
         for img in image_folder:
-            image = Image.open(img)
+            if isinstance(img, Image.Image):
+                image = img
+                fn = ''
+            else:
+                image = Image.open(os.path.abspath(img.name))
+                fn = os.path.splitext(img.orig_name)[0]
             image_data.append(image)
-            image_names.append(os.path.splitext(img.orig_name)[0])
+            image_names.append(fn)
     elif extras_mode == 2:
         assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
         assert input_dir, 'input directory not selected'
......
@@ -3,6 +3,7 @@ import math
 import os
 import sys
 import warnings
+import hashlib

 import torch
 import numpy as np
@@ -78,11 +79,7 @@ def apply_overlay(image, paste_loc, index, overlays):


 def txt2img_image_conditioning(sd_model, x, width, height):
-    if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
-        # Dummy zero conditioning if we're not using inpainting model.
-        # Still takes up a bit of memory, but no encoder call.
-        # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
-        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+    if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models

         # The "masked-image" in this case will just be all zeros since the entire image is masked.
         image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
@@ -94,6 +91,16 @@ def txt2img_image_conditioning(sd_model, x, width, height):

         return image_conditioning

+    elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
+        return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
+
+    else:
+        # Dummy zero conditioning if we're not using inpainting or unclip models.
+        # Still takes up a bit of memory, but no encoder call.
+        # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+

 class StableDiffusionProcessing:
     """
@@ -190,6 +197,14 @@ class StableDiffusionProcessing:

         return conditioning_image

+    def unclip_image_conditioning(self, source_image):
+        c_adm = self.sd_model.embedder(source_image)
+        if self.sd_model.noise_augmentor is not None:
+            noise_level = 0 # TODO: Allow other noise levels?
+            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
+            c_adm = torch.cat((c_adm, noise_level_emb), 1)
+        return c_adm
+
     def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
         self.is_using_inpainting_conditioning = True
@@ -241,6 +256,9 @@ class StableDiffusionProcessing:
         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
             return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

+        if self.sampler.conditioning_key == "crossattn-adm":
+            return self.unclip_image_conditioning(source_image)
+
         # Dummy zero conditioning if we're not using inpainting or depth model.
         return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@@ -459,6 +477,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
         "Clip skip": None if clip_skip <= 1 else clip_skip,
         "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+        "Init image hash": getattr(p, 'init_img_hash', None),
+        "RNG": (opts.randn_source if opts.randn_source != "GPU" else None)
     }

     generation_params.update(p.extra_generation_params)
@@ -622,8 +642,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     processed = Processed(p, [], p.seed, "")
                     file.write(processed.infotext(p, 0))

-            uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
-            c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
+            step_multiplier = 1
+            if not shared.opts.dont_fix_second_order_samplers_schedule:
+                try:
+                    step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
+                except:
+                    pass
+            uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
+            c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)

             if len(model_hijack.comments) > 0:
                 for comment in model_hijack.comments:
@@ -689,6 +715,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     image.info["parameters"] = text
                 output_images.append(image)

+                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+                    image_mask = p.mask_for_overlay.convert('RGB')
+                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
+
+                    if opts.save_mask:
+                        images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
+
+                    if opts.save_mask_composite:
+                        images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
+
+                    if opts.return_mask:
+                        output_images.append(image_mask)
+
+                    if opts.return_mask_composite:
+                        output_images.append(image_mask_composite)
+
             del x_samples_ddim

             devices.torch_gc()
@@ -974,6 +1016,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             self.color_corrections = []
         imgs = []
         for img in self.init_images:
+
+            # Save init image
+            if opts.save_init_img:
+                self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
+                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
+
             image = images.flatten(img, opts.img2img_background_color)

             if crop_region is None and self.resize_mode != 3:
......
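
The new "Init image hash" infotext entry is an MD5 of the raw PIL pixel buffer, which doubles as a stable, content-derived filename for the saved init image. Reduced to its core:

    import hashlib
    from PIL import Image

    img = Image.new("RGB", (64, 64), "white")   # stand-in for the init image
    digest = hashlib.md5(img.tobytes()).hexdigest()
    print(digest)  # same pixels -> same hash, usable as a forced filename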
 # this code is adapted from the script contributed by anon from /h/
-import io
 import pickle
 import collections
 import sys
@@ -12,11 +11,9 @@ import _codecs
 import zipfile
 import re

 # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
 TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage

 def encode(*args):
     out = _codecs.encode(*args)
     return out
@@ -27,7 +24,7 @@ class RestrictedUnpickler(pickle.Unpickler):

     def persistent_load(self, saved_id):
         assert saved_id[0] == 'storage'
-        return TypedStorage()
+        return TypedStorage(_internal=True)

     def find_class(self, module, name):
         if self.extra_handler is not None:
......
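
safe.py's RestrictedUnpickler follows the allow-list pattern from the Python pickle documentation: override find_class so only vetted globals can be reconstructed, and reject everything else. A minimal version of the same pattern:

    import io
    import pickle
    import collections

    class RestrictedUnpickler(pickle.Unpickler):
        ALLOWED = {("collections", "OrderedDict")}

        def find_class(self, module, name):
            if (module, name) in self.ALLOWED:
                return super().find_class(module, name)
            # Anything outside the allow-list is treated as hostile.
            raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")

    def restricted_loads(data: bytes):
        return RestrictedUnpickler(io.BytesIO(data)).load()

    print(restricted_loads(pickle.dumps(collections.OrderedDict(a=1))))
    # restricted_loads(pickle.dumps(open))  # would raise UnpicklingError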
@@ -239,7 +239,15 @@ def load_scripts():
             elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                 postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))

-    for scriptfile in sorted(scripts_list):
+    def orderby(basedir):
+        # 1st webui, 2nd extensions-builtin, 3rd extensions
+        priority = {os.path.join(paths.script_path, "extensions-builtin"): 1, paths.script_path: 0}
+        for key in priority:
+            if basedir.startswith(key):
+                return priority[key]
+        return 9999
+
+    for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
         try:
             if scriptfile.basedir != paths.script_path:
                 sys.path = [scriptfile.basedir] + sys.path
@@ -513,6 +521,18 @@ def reload_scripts():

 scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()

+
+def add_classes_to_gradio_component(comp):
+    """
+    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
+    """
+
+    comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
+
+    if getattr(comp, 'multiselect', False):
+        comp.elem_classes.append('multiselect')
+

 def IOComponent_init(self, *args, **kwargs):
     if scripts_current is not None:
         scripts_current.before_component(self, **kwargs)
@@ -521,6 +541,8 @@ def IOComponent_init(self, *args, **kwargs):

     res = original_IOComponent_init(self, *args, **kwargs)

+    add_classes_to_gradio_component(self)
+
     script_callbacks.after_component_callback(self, **kwargs)

     if scripts_current is not None:
@@ -531,3 +553,15 @@ def IOComponent_init(self, *args, **kwargs):

 original_IOComponent_init = gr.components.IOComponent.__init__
 gr.components.IOComponent.__init__ = IOComponent_init
+
+
+def BlockContext_init(self, *args, **kwargs):
+    res = original_BlockContext_init(self, *args, **kwargs)
+
+    add_classes_to_gradio_component(self)
+
+    return res
+
+
+original_BlockContext_init = gr.blocks.BlockContext.__init__
+gr.blocks.BlockContext.__init__ = BlockContext_init
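
add_classes_to_gradio_component is wired in by monkey-patching constructors: save the original __init__, call it from a wrapper, mutate the freshly built instance, and write the wrapper back onto the class; the BlockContext patch above repeats the IOComponent one. The idiom in isolation, on a toy class rather than gradio:

    class Widget:
        def __init__(self, name):
            self.name = name
            self.elem_classes = None

    original_init = Widget.__init__

    def patched_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # Post-construction hook: tag every instance for CSS-style targeting.
        self.elem_classes = ["gradio-" + type(self).__name__.lower(),
                             *(self.elem_classes or [])]

    Widget.__init__ = patched_init
    assert Widget("x").elem_classes == ["gradio-widget"]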
@@ -109,7 +109,7 @@ class ScriptPostprocessingRunner:
         inputs = []

         for script in self.scripts_in_preferred_order():
-            with gr.Box() as group:
+            with gr.Row() as group:
                 self.create_script_ui(script, inputs)

             script.group = group
......
@@ -337,7 +337,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):

     dtype = q.dtype
     if shared.opts.upcast_attn:
-        q, k = q.float(), k.float()
+        q, k, v = q.float(), k.float(), v.float()

     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
@@ -372,7 +372,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):

     dtype = q.dtype
     if shared.opts.upcast_attn:
-        q, k = q.float(), k.float()
+        q, k, v = q.float(), k.float(), v.float()

     # the output of sdp = (batch, num_heads, seq_len, head_dim)
     hidden_states = torch.nn.functional.scaled_dot_product_attention(
......
@@ -67,7 +67,7 @@ def hijack_ddpm_edit():
 unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast

 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
 CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
-if version.parse(torch.__version__) <= version.parse("1.13.1"):
+if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
     CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
     CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
     CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
......
@@ -122,7 +122,7 @@ def list_models():
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)

-    for filename in model_list:
+    for filename in sorted(model_list, key=str.lower):
         checkpoint_info = CheckpointInfo(filename)
         checkpoint_info.register()
@@ -178,7 +178,7 @@ def select_checkpoint():
     return checkpoint_info


-chckpoint_dict_replacements = {
+checkpoint_dict_replacements = {
     'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
     'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
     'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
@@ -186,7 +186,7 @@ checkpoint_dict_replacements = {

 def transform_checkpoint_dict_key(k):
-    for text, replacement in chckpoint_dict_replacements.items():
+    for text, replacement in checkpoint_dict_replacements.items():
         if k.startswith(text):
             k = replacement + k[len(text):]
@@ -383,6 +383,14 @@ def repair_config(sd_config):
     elif shared.cmd_opts.upcast_sampling:
         sd_config.model.params.unet_config.params.use_fp16 = True

+    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
+        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
+
+    # For UnCLIP-L, override the hardcoded karlo directory
+    if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
+        karlo_path = os.path.join(paths.models_path, 'karlo')
+        sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
+

 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
@@ -494,7 +502,7 @@ def reload_model_weights(sd_model=None, info=None):
         if sd_model is None or checkpoint_config != sd_model.used_config:
             del sd_model
             checkpoints_loaded.clear()
-            load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+            load_model(checkpoint_info, already_loaded_state_dict=state_dict)
             return shared.sd_model

     try:
@@ -517,3 +525,23 @@ def reload_model_weights(sd_model=None, info=None):
     print(f"Weights loaded in {timer.summary()}.")

     return sd_model
+
+
+def unload_model_weights(sd_model=None, info=None):
+    from modules import lowvram, devices, sd_hijack
+    timer = Timer()
+
+    if shared.sd_model:
+
+        # shared.sd_model.cond_stage_model.to(devices.cpu)
+        # shared.sd_model.first_stage_model.to(devices.cpu)
+        shared.sd_model.to(devices.cpu)
+        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+        shared.sd_model = None
+        sd_model = None
+        gc.collect()
+        devices.torch_gc()
+        torch.cuda.empty_cache()
+
+    print(f"Unloaded weights {timer.summary()}.")
+
+    return sd_model
\ No newline at end of file
@@ -14,6 +14,8 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
+config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
+config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
 config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
 config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
@@ -65,9 +67,14 @@ def is_using_v_parameterization_for_sd2(state_dict):

 def guess_model_config_from_state_dict(sd, filename):
     sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
+    sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

     if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
         return config_depth_model
+    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
+        return config_unclip
+    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
+        return config_unopenclip

     if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
         if diffusion_model_input.shape[1] == 9:
......
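
guess_model_config_from_state_dict identifies the new UnCLIP variants purely from the shape of one marker tensor: embedder.model.ln_final.weight is 768-wide for the CLIP-L flavor and 1024-wide for the OpenCLIP-H one. The technique reduced to its core:

    import torch

    def guess_variant(state_dict: dict) -> str:
        w = state_dict.get('embedder.model.ln_final.weight')
        if w is not None and w.shape[0] == 768:
            return "unclip-l"
        if w is not None and w.shape[0] == 1024:
            return "unclip-h"
        return "not-unclip"

    sd = {'embedder.model.ln_final.weight': torch.zeros(768)}
    assert guess_variant(sd) == "unclip-l"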
@@ -60,3 +60,13 @@ def store_latent(decoded):

 class InterruptedException(BaseException):
     pass
+
+
+if opts.randn_source == "CPU":
+    import torchsde._brownian.brownian_interval
+
+    def torchsde_randn(size, dtype, device, seed):
+        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
+        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+
+    torchsde._brownian.brownian_interval._randn = torchsde_randn
@@ -70,7 +70,12 @@ class VanillaStableDiffusionSampler:
         # Have to unwrap the inpainting conditioning here to perform pre-processing
         image_conditioning = None
+        uc_image_conditioning = None
         if isinstance(cond, dict):
+            if self.conditioning_key == "crossattn-adm":
+                image_conditioning = cond["c_adm"]
+                uc_image_conditioning = unconditional_conditioning["c_adm"]
+            else:
                 image_conditioning = cond["c_concat"][0]
             cond = cond["c_crossattn"][0]
             unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
@@ -98,6 +103,10 @@ class VanillaStableDiffusionSampler:
         # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
         # Note that they need to be lists because it just concatenates them later.
         if image_conditioning is not None:
+            if self.conditioning_key == "crossattn-adm":
+                cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
+                unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
+            else:
                 cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
                 unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
@@ -176,6 +185,10 @@ class VanillaStableDiffusionSampler:
         # Wrap the conditioning models with additional image conditioning for inpainting model
         if image_conditioning is not None:
+            if self.conditioning_key == "crossattn-adm":
+                conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
+                unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
+            else:
                 conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
                 unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
@@ -195,6 +208,10 @@ class VanillaStableDiffusionSampler:
         # Wrap the conditioning models with additional image conditioning for inpainting model
         # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
         if image_conditioning is not None:
+            if self.conditioning_key == "crossattn-adm":
+                conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
+                unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
+            else:
                 conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
                 unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
......
...@@ -92,14 +92,21 @@ class CFGDenoiser(torch.nn.Module): ...@@ -92,14 +92,21 @@ class CFGDenoiser(torch.nn.Module):
batch_size = len(conds_list) batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)] repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
else:
image_uncond = image_cond
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
if not is_edit_model: if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
else: else:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
cfg_denoiser_callback(denoiser_params) cfg_denoiser_callback(denoiser_params)
...@@ -116,13 +123,13 @@ class CFGDenoiser(torch.nn.Module): ...@@ -116,13 +123,13 @@ class CFGDenoiser(torch.nn.Module):
cond_in = torch.cat([tensor, uncond, uncond]) cond_in = torch.cat([tensor, uncond, uncond])
if shared.batch_cond_uncond: if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
else: else:
x_out = torch.zeros_like(x_in) x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size): for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset a = batch_offset
b = a + batch_size b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
else: else:
x_out = torch.zeros_like(x_in) x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
...@@ -135,9 +142,9 @@ class CFGDenoiser(torch.nn.Module): ...@@ -135,9 +142,9 @@ class CFGDenoiser(torch.nn.Module):
else: else:
c_crossattn = torch.cat([tensor[a:b]], uncond) c_crossattn = torch.cat([tensor[a:b]], uncond)
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]}) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
cfg_denoised_callback(denoised_params) cfg_denoised_callback(denoised_params)
...@@ -183,7 +190,7 @@ class TorchHijack: ...@@ -183,7 +190,7 @@ class TorchHijack:
if noise.shape == x.shape: if noise.shape == x.shape:
return noise return noise
if x.device.type == 'mps': if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device) return torch.randn_like(x, device=devices.cpu).to(x.device)
else: else:
return torch.randn_like(x) return torch.randn_like(x)
......
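The `make_condition_dict` dispatch above is the core of this file's change: unCLIP checkpoints (conditioning key "crossattn-adm") take image conditioning under "c_adm", with a zero tensor for the unconditional pass, while inpainting-style models keep concatenating it under "c_concat". A standalone sketch of that dispatch:

```python
# Hedged sketch of the conditioning-dict dispatch used by CFGDenoiser above.
import torch

def make_condition_dict_for(conditioning_key: str):
    if conditioning_key == "crossattn-adm":
        # unCLIP: image conditioning rides along as "c_adm"
        return lambda c_crossattn, image_cond: {"c_crossattn": c_crossattn, "c_adm": image_cond}
    # inpainting and similar: image conditioning is concatenated via "c_concat"
    return lambda c_crossattn, image_cond: {"c_crossattn": c_crossattn, "c_concat": [image_cond]}

cond = make_condition_dict_for("crossattn-adm")([torch.zeros(1, 77, 1024)], torch.zeros(1, 1024))
print(sorted(cond))  # ['c_adm', 'c_crossattn']
```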
...@@ -152,7 +152,11 @@ class EmbeddingDatabase: ...@@ -152,7 +152,11 @@ class EmbeddingDatabase:
name = data.get('name', name) name = data.get('name', name)
else: else:
data = extract_image_data_embed(embed_image) data = extract_image_data_embed(embed_image)
if data:
name = data.get('name', name) name = data.get('name', name)
else:
# if data is None, it means this is not an embedding, just a preview image
return
elif ext in ['.BIN', '.PT']: elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu") data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']: elif ext in ['.SAFETENSORS']:
......
...@@ -125,12 +125,12 @@ Requested path was: {f} ...@@ -125,12 +125,12 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results"): with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"): with gr.Group(elem_id=f"{tabname}_gallery_container"):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
generation_info = None generation_info = None
with gr.Column(): with gr.Column():
with gr.Row(elem_id=f"image_buttons_{tabname}"): with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
if tabname != "extras": if tabname != "extras":
save = gr.Button('Save', elem_id=f'save_{tabname}') save = gr.Button('Save', elem_id=f'save_{tabname}')
...@@ -145,11 +145,10 @@ Requested path was: {f} ...@@ -145,11 +145,10 @@ Requested path was: {f}
) )
if tabname != "extras": if tabname != "extras":
with gr.Row():
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
with gr.Group(): with gr.Group():
html_info = gr.HTML(elem_id=f'html_info_{tabname}') html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
html_log = gr.HTML(elem_id=f'html_log_{tabname}') html_log = gr.HTML(elem_id=f'html_log_{tabname}')
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
...@@ -160,6 +159,7 @@ Requested path was: {f} ...@@ -160,6 +159,7 @@ Requested path was: {f}
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }", _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
inputs=[generation_info, html_info, html_info], inputs=[generation_info, html_info, html_info],
outputs=[html_info, html_info], outputs=[html_info, html_info],
show_progress=False,
) )
save.click( save.click(
...@@ -195,7 +195,7 @@ Requested path was: {f} ...@@ -195,7 +195,7 @@ Requested path was: {f}
else: else:
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
html_info = gr.HTML(elem_id=f'html_info_{tabname}') html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
html_log = gr.HTML(elem_id=f'html_log_{tabname}') html_log = gr.HTML(elem_id=f'html_log_{tabname}')
paste_field_names = [] paste_field_names = []
......
import gradio as gr import gradio as gr
class ToolButton(gr.Button, gr.components.FormComponent): class FormComponent:
"""Small button with single emoji as text, fits inside gradio forms""" def get_expected_parent(self):
return gr.components.Form
def __init__(self, **kwargs):
super().__init__(variant="tool", **kwargs)
def get_block_name(self): gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
return "button"
class ToolButtonTop(gr.Button, gr.components.FormComponent): class ToolButton(FormComponent, gr.Button):
"""Small button with single emoji as text, with extra margin at top, fits inside gradio forms""" """Small button with single emoji as text, fits inside gradio forms"""
def __init__(self, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(variant="tool-top", **kwargs) classes = kwargs.pop("elem_classes", [])
super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
def get_block_name(self): def get_block_name(self):
return "button" return "button"
class FormRow(gr.Row, gr.components.FormComponent): class FormRow(FormComponent, gr.Row):
"""Same as gr.Row but fits inside gradio forms""" """Same as gr.Row but fits inside gradio forms"""
def get_block_name(self): def get_block_name(self):
return "row" return "row"
class FormGroup(gr.Group, gr.components.FormComponent): class FormColumn(FormComponent, gr.Column):
"""Same as gr.Column but fits inside gradio forms"""
def get_block_name(self):
return "column"
class FormGroup(FormComponent, gr.Group):
"""Same as gr.Row but fits inside gradio forms""" """Same as gr.Row but fits inside gradio forms"""
def get_block_name(self): def get_block_name(self):
return "group" return "group"
class FormHTML(gr.HTML, gr.components.FormComponent): class FormHTML(FormComponent, gr.HTML):
"""Same as gr.HTML but fits inside gradio forms""" """Same as gr.HTML but fits inside gradio forms"""
def get_block_name(self): def get_block_name(self):
return "html" return "html"
class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): class FormColorPicker(FormComponent, gr.ColorPicker):
"""Same as gr.ColorPicker but fits inside gradio forms""" """Same as gr.ColorPicker but fits inside gradio forms"""
def get_block_name(self): def get_block_name(self):
return "colorpicker" return "colorpicker"
class DropdownMulti(gr.Dropdown): class DropdownMulti(FormComponent, gr.Dropdown):
"""Same as gr.Dropdown but always multiselect""" """Same as gr.Dropdown but always multiselect"""
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(multiselect=True, **kwargs) super().__init__(multiselect=True, **kwargs)
def get_block_name(self): def get_block_name(self):
return "dropdown" return "dropdown"
class DropdownEditable(FormComponent, gr.Dropdown):
"""Same as gr.Dropdown but allows editing value"""
def __init__(self, **kwargs):
super().__init__(allow_custom_value=True, **kwargs)
def get_block_name(self):
return "dropdown"
...@@ -2,8 +2,10 @@ import glob ...@@ -2,8 +2,10 @@ import glob
import os.path import os.path
import urllib.parse import urllib.parse
from pathlib import Path from pathlib import Path
from PIL import PngImagePlugin
from modules import shared from modules import shared
from modules.images import read_info_from_image
import gradio as gr import gradio as gr
import json import json
import html import html
...@@ -22,8 +24,7 @@ def register_page(page): ...@@ -22,8 +24,7 @@ def register_page(page):
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], []))) allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
def add_pages_to_demo(app): def fetch_file(filename: str = ""):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse from starlette.responses import FileResponse
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]): if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
...@@ -36,7 +37,24 @@ def add_pages_to_demo(app): ...@@ -36,7 +37,24 @@ def add_pages_to_demo(app):
# would profit from returning 304 # would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
def get_metadata(page: str = "", item: str = ""):
from starlette.responses import JSONResponse
page = next(iter([x for x in extra_pages if x.name == page]), None)
if page is None:
return JSONResponse({})
metadata = page.metadata.get(item)
if metadata is None:
return JSONResponse({})
return JSONResponse({"metadata": metadata})
def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
class ExtraNetworksPage: class ExtraNetworksPage:
...@@ -45,6 +63,7 @@ class ExtraNetworksPage: ...@@ -45,6 +63,7 @@ class ExtraNetworksPage:
self.name = title.lower() self.name = title.lower()
self.card_page = shared.html("extra-networks-card.html") self.card_page = shared.html("extra-networks-card.html")
self.allow_negative_prompt = False self.allow_negative_prompt = False
self.metadata = {}
def refresh(self): def refresh(self):
pass pass
...@@ -66,6 +85,8 @@ class ExtraNetworksPage: ...@@ -66,6 +85,8 @@ class ExtraNetworksPage:
view = shared.opts.extra_networks_default_view view = shared.opts.extra_networks_default_view
items_html = '' items_html = ''
self.metadata = {}
subdirs = {} subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True): for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
...@@ -86,12 +107,16 @@ class ExtraNetworksPage: ...@@ -86,12 +107,16 @@ class ExtraNetworksPage:
subdirs = {"": 1, **subdirs} subdirs = {"": 1, **subdirs}
subdirs_html = "".join([f""" subdirs_html = "".join([f"""
<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'> <button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
{html.escape(subdir if subdir!="" else "all")} {html.escape(subdir if subdir!="" else "all")}
</button> </button>
""" for subdir in subdirs]) """ for subdir in subdirs])
for item in self.list_items(): for item in self.list_items():
metadata = item.get("metadata")
if metadata:
self.metadata[item["name"]] = metadata
items_html += self.create_html_for_item(item, tabname) items_html += self.create_html_for_item(item, tabname)
if items_html == '': if items_html == '':
...@@ -124,14 +149,16 @@ class ExtraNetworksPage: ...@@ -124,14 +149,16 @@ class ExtraNetworksPage:
if onclick is None: if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
metadata_button = "" metadata_button = ""
metadata = item.get("metadata") metadata = item.get("metadata")
if metadata: if metadata:
metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"' metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick={metadata_onclick}></div>"
args = { args = {
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', "style": f"'{height}{width}{background_image}'",
"prompt": item.get("prompt", None), "prompt": item.get("prompt", None),
"tabname": json.dumps(tabname), "tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]), "local_preview": json.dumps(item["local_preview"]),
...@@ -215,6 +242,7 @@ def create_ui(container, button, tabname): ...@@ -215,6 +242,7 @@ def create_ui(container, button, tabname):
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
for page in ui.stored_extra_pages: for page in ui.stored_extra_pages:
with gr.Tab(page.title): with gr.Tab(page.title):
page_elem = gr.HTML(page.create_html(ui.tabname)) page_elem = gr.HTML(page.create_html(ui.tabname))
ui.pages.append(page_elem) ui.pages.append(page_elem)
...@@ -226,10 +254,10 @@ def create_ui(container, button, tabname): ...@@ -226,10 +254,10 @@ def create_ui(container, button, tabname):
def toggle_visibility(is_visible): def toggle_visibility(is_visible):
is_visible = not is_visible is_visible = not is_visible
return is_visible, gr.update(visible=is_visible) return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
state_visible = gr.State(value=False) state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])
def refresh(): def refresh():
res = [] res = []
...@@ -264,6 +292,7 @@ def setup_ui(ui, gallery): ...@@ -264,6 +292,7 @@ def setup_ui(ui, gallery):
img_info = images[index if index >= 0 else 0] img_info = images[index if index >= 0 else 0]
image = image_from_url_text(img_info) image = image_from_url_text(img_info)
geninfo, items = read_info_from_image(image)
is_allowed = False is_allowed = False
for extra_page in ui.stored_extra_pages: for extra_page in ui.stored_extra_pages:
...@@ -273,6 +302,11 @@ def setup_ui(ui, gallery): ...@@ -273,6 +302,11 @@ def setup_ui(ui, gallery):
assert is_allowed, f'writing to {filename} is not allowed' assert is_allowed, f'writing to {filename} is not allowed'
if geninfo:
pnginfo_data = PngImagePlugin.PngInfo()
pnginfo_data.add_text('parameters', geninfo)
image.save(filename, pnginfo=pnginfo_data)
else:
image.save(filename) image.save(filename)
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
......
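The save path above now writes generation parameters into the preview via Pillow's PNG text chunks. A minimal standalone sketch of that round trip:

```python
# Save and re-read the "parameters" PNG text chunk, as the code above does.
from PIL import Image, PngImagePlugin

img = Image.new("RGB", (64, 64))
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("parameters", "a prompt, Steps: 20, Sampler: Euler a")
img.save("preview.png", pnginfo=pnginfo)

print(Image.open("preview.png").text["parameters"])  # round-trips intact
```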
...@@ -13,7 +13,7 @@ def create_ui(): ...@@ -13,7 +13,7 @@ def create_ui():
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch: with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir: with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
......
blendmodes==2022 blendmodes==2022
transformers==4.25.1 transformers==4.25.1
accelerate==0.12.0 accelerate==0.18.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.16.2 gradio==3.27
numpy==1.23.3 numpy==1.23.5
Pillow==9.4.0 Pillow==9.4.0
realesrgan==0.3.0 realesrgan==0.3.0
torch torch
omegaconf==2.2.3 omegaconf==2.2.3
pytorch_lightning==1.7.6 pytorch_lightning==1.9.4
scikit-image==0.19.2 scikit-image==0.19.2
fonts fonts
font-roboto font-roboto
...@@ -25,6 +25,6 @@ lark==1.1.2 ...@@ -25,6 +25,6 @@ lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.30 GitPython==3.1.30
torchsde==0.2.5 torchsde==0.2.5
safetensors==0.2.7 safetensors==0.3.1
httpcore<=0.15 httpcore<=0.15
fastapi==0.94.0 fastapi==0.94.0
function gradioApp() { function gradioApp() {
const elems = document.getElementsByTagName('gradio-app') const elems = document.getElementsByTagName('gradio-app')
const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot const elem = elems.length == 0 ? document : elems[0]
return !!gradioShadowRoot ? gradioShadowRoot : document;
if (elem !== document) elem.getElementById = function(id){ return document.getElementById(id) }
return elem.shadowRoot ? elem.shadowRoot : elem
} }
function get_uiCurrentTab() { function get_uiCurrentTab() {
......
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
import ast
import copy
from modules.processing import Processed from modules.processing import Processed
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
def convertExpr2Expression(expr):
expr.lineno = 0
expr.col_offset = 0
result = ast.Expression(expr.value, lineno=0, col_offset = 0)
return result
def exec_with_return(code, module):
"""
like exec() but can return values
https://stackoverflow.com/a/52361938/5862977
"""
code_ast = ast.parse(code)
init_ast = copy.deepcopy(code_ast)
init_ast.body = code_ast.body[:-1]
last_ast = copy.deepcopy(code_ast)
last_ast.body = code_ast.body[-1:]
exec(compile(init_ast, "<ast>", "exec"), module.__dict__)
if type(last_ast.body[0]) == ast.Expr:
return eval(compile(convertExpr2Expression(last_ast.body[0]), "<ast>", "eval"), module.__dict__)
else:
exec(compile(last_ast, "<ast>", "exec"), module.__dict__)
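A quick usage example for `exec_with_return` above: earlier statements run for their side effects, and the value of a trailing expression is returned, which is what later lets the script hand back a `Processed` object:

```python
# Usage sketch for exec_with_return defined above.
from types import ModuleType

module = ModuleType("demo")
result = exec_with_return("x = 2\nx * 21", module)
assert result == 42 and module.x == 2  # trailing expression's value is returned
```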
class Script(scripts.Script): class Script(scripts.Script):
def title(self): def title(self):
...@@ -13,12 +44,23 @@ class Script(scripts.Script): ...@@ -13,12 +44,23 @@ class Script(scripts.Script):
return cmd_opts.allow_code return cmd_opts.allow_code
def ui(self, is_img2img): def ui(self, is_img2img):
code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code")) example = """from modules.processing import process_images
p.width = 768
p.height = 768
p.batch_size = 2
p.steps = 10
return process_images(p)
"""
return [code] code = gr.Code(value=example, language="python", label="Python code", elem_id=self.elem_id("code"))
indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id("indent_level"))
return [code, indent_level]
def run(self, p, code): def run(self, p, code, indent_level):
assert cmd_opts.allow_code, '--allow-code option must be enabled' assert cmd_opts.allow_code, '--allow-code option must be enabled'
display_result_data = [[], -1, ""] display_result_data = [[], -1, ""]
...@@ -29,13 +71,20 @@ class Script(scripts.Script): ...@@ -29,13 +71,20 @@ class Script(scripts.Script):
display_result_data[2] = i display_result_data[2] = i
from types import ModuleType from types import ModuleType
compiled = compile(code, '', 'exec')
module = ModuleType("testmodule") module = ModuleType("testmodule")
module.__dict__.update(globals()) module.__dict__.update(globals())
module.p = p module.p = p
module.display = display module.display = display
exec(compiled, module.__dict__)
return Processed(p, *display_result_data) indent = " " * indent_level
indented = code.replace('\n', '\n' + indent)
body = f"""def __webuitemp__():
{indent}{indented}
__webuitemp__()"""
result = exec_with_return(body, module)
if isinstance(result, Processed):
return result
return Processed(p, *display_result_data)
...@@ -6,23 +6,21 @@ from tqdm import trange ...@@ -6,23 +6,21 @@ from tqdm import trange
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common from modules import processing, shared, sd_samplers, sd_samplers_common
from modules.processing import Processed
from modules.shared import opts, cmd_opts, state
import torch import torch
import k_diffusion as K import k_diffusion as K
from PIL import Image
from torch import autocast
from einops import rearrange, repeat
def find_noise_for_image(p, cond, uncond, cfg_scale, steps): def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
x = p.init_latent x = p.init_latent
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
if shared.sd_model.parameterization == "v":
dnw = K.external.CompVisVDenoiser(shared.sd_model)
skip = 1
else:
dnw = K.external.CompVisDenoiser(shared.sd_model) dnw = K.external.CompVisDenoiser(shared.sd_model)
skip = 0
sigmas = dnw.get_sigmas(steps).flip(0) sigmas = dnw.get_sigmas(steps).flip(0)
shared.state.sampling_steps = steps shared.state.sampling_steps = steps
...@@ -37,7 +35,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps): ...@@ -37,7 +35,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
image_conditioning = torch.cat([p.image_conditioning] * 2) image_conditioning = torch.cat([p.image_conditioning] * 2)
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]} cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)] c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
t = dnw.sigma_to_t(sigma_in) t = dnw.sigma_to_t(sigma_in)
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in) eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
...@@ -69,7 +67,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps): ...@@ -69,7 +67,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
x = p.init_latent x = p.init_latent
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
if shared.sd_model.parameterization == "v":
dnw = K.external.CompVisVDenoiser(shared.sd_model)
skip = 1
else:
dnw = K.external.CompVisDenoiser(shared.sd_model) dnw = K.external.CompVisDenoiser(shared.sd_model)
skip = 0
sigmas = dnw.get_sigmas(steps).flip(0) sigmas = dnw.get_sigmas(steps).flip(0)
shared.state.sampling_steps = steps shared.state.sampling_steps = steps
...@@ -84,7 +87,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps): ...@@ -84,7 +87,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
image_conditioning = torch.cat([p.image_conditioning] * 2) image_conditioning = torch.cat([p.image_conditioning] * 2)
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]} cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)] c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
if i == 1: if i == 1:
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2)) t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
...@@ -213,4 +216,3 @@ class Script(scripts.Script): ...@@ -213,4 +216,3 @@ class Script(scripts.Script):
processed = processing.process_images(p) processed = processing.process_images(p)
return processed return processed
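The `skip` variable threaded through both functions above exists because k-diffusion's wrappers return different scaling tuples: `CompVisDenoiser.get_scalings` yields `(c_out, c_in)`, while `CompVisVDenoiser` (for v-prediction checkpoints) yields `(c_skip, c_out, c_in)`. A self-contained sketch of the uniform unpacking, with dummy tensors:

```python
# Drop leading scalings so (c_out, c_in) unpacking works for both
# eps-prediction and v-prediction wrappers.
import torch

def unpack_scalings(scalings):
    skip = len(scalings) - 2  # 1 for CompVisVDenoiser, 0 for CompVisDenoiser
    c_out, c_in = scalings[skip:]
    return c_out, c_in

print(unpack_scalings((torch.tensor(-1.0), torch.tensor(0.5))))                     # eps model
print(unpack_scalings((torch.tensor(1.0), torch.tensor(-1.0), torch.tensor(0.5))))  # v model
```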
import numpy as np import math
from tqdm import trange
import modules.scripts as scripts
import gradio as gr import gradio as gr
import modules.scripts as scripts
from modules import processing, shared, sd_samplers, images from modules import deepbooru, images, processing, shared
from modules.processing import Processed from modules.processing import Processed
from modules.sd_samplers import samplers from modules.shared import opts, state
from modules.shared import opts, cmd_opts, state
from modules import deepbooru
class Script(scripts.Script): class Script(scripts.Script):
...@@ -20,39 +16,65 @@ class Script(scripts.Script): ...@@ -20,39 +16,65 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops")) loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor")) final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength"))
denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear")
append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None") append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None")
return [loops, denoising_strength_change_factor, append_interrogation] return [loops, final_denoising_strength, denoising_curve, append_interrogation]
def run(self, p, loops, denoising_strength_change_factor, append_interrogation): def run(self, p, loops, final_denoising_strength, denoising_curve, append_interrogation):
processing.fix_seed(p) processing.fix_seed(p)
batch_count = p.n_iter batch_count = p.n_iter
p.extra_generation_params = { p.extra_generation_params = {
"Denoising strength change factor": denoising_strength_change_factor, "Final denoising strength": final_denoising_strength,
"Denoising curve": denoising_curve
} }
p.batch_size = 1 p.batch_size = 1
p.n_iter = 1 p.n_iter = 1
output_images, info = None, None info = None
initial_seed = None initial_seed = None
initial_info = None initial_info = None
initial_denoising_strength = p.denoising_strength
grids = [] grids = []
all_images = [] all_images = []
original_init_image = p.init_images original_init_image = p.init_images
original_prompt = p.prompt original_prompt = p.prompt
original_inpainting_fill = p.inpainting_fill
state.job_count = loops * batch_count state.job_count = loops * batch_count
initial_color_corrections = [processing.setup_color_correction(p.init_images[0])] initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]
for n in range(batch_count): def calculate_denoising_strength(loop):
strength = initial_denoising_strength
if loops == 1:
return strength
progress = loop / (loops - 1)
if denoising_curve == "Aggressive":
strength = math.sin((progress) * math.pi * 0.5)
elif denoising_curve == "Lazy":
strength = 1 - math.cos((progress) * math.pi * 0.5)
else:
strength = progress
change = (final_denoising_strength - initial_denoising_strength) * strength
return initial_denoising_strength + change
history = [] history = []
for n in range(batch_count):
# Reset to original init image at the start of each batch # Reset to original init image at the start of each batch
p.init_images = original_init_image p.init_images = original_init_image
# Reset to original denoising strength
p.denoising_strength = initial_denoising_strength
last_image = None
for i in range(loops): for i in range(loops):
p.n_iter = 1 p.n_iter = 1
p.batch_size = 1 p.batch_size = 1
...@@ -72,25 +94,45 @@ class Script(scripts.Script): ...@@ -72,25 +94,45 @@ class Script(scripts.Script):
processed = processing.process_images(p) processed = processing.process_images(p)
# Generation cancelled.
if state.interrupted:
break
if initial_seed is None: if initial_seed is None:
initial_seed = processed.seed initial_seed = processed.seed
initial_info = processed.info initial_info = processed.info
init_img = processed.images[0]
p.init_images = [init_img]
p.seed = processed.seed + 1 p.seed = processed.seed + 1
p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1) p.denoising_strength = calculate_denoising_strength(i + 1)
history.append(processed.images[0])
if state.skipped:
break
last_image = processed.images[0]
p.init_images = [last_image]
p.inpainting_fill = 1 # Set "masked content" to "original" for next loop.
if batch_count == 1:
history.append(last_image)
all_images.append(last_image)
if batch_count > 1 and not state.skipped and not state.interrupted:
history.append(last_image)
all_images.append(last_image)
p.inpainting_fill = original_inpainting_fill
if state.interrupted:
break
if len(history) > 1:
grid = images.image_grid(history, rows=1) grid = images.image_grid(history, rows=1)
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p) images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
if opts.return_grid:
grids.append(grid) grids.append(grid)
all_images += history
if opts.return_grid:
all_images = grids + all_images all_images = grids + all_images
processed = Processed(p, all_images, initial_seed, initial_info) processed = Processed(p, all_images, initial_seed, initial_info)
......
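For intuition about `calculate_denoising_strength` above: all three curves interpolate from the initial strength to the final one, but "Aggressive" (sine) moves early and "Lazy" (1 - cosine) moves late. A standalone rendition printing sample values:

```python
# Standalone sketch of the loopback script's three denoising curves.
import math

def strength_at(loop, loops, initial, final, curve="Linear"):
    if loops == 1:
        return initial
    progress = loop / (loops - 1)
    if curve == "Aggressive":
        shaped = math.sin(progress * math.pi * 0.5)      # front-loaded change
    elif curve == "Lazy":
        shaped = 1 - math.cos(progress * math.pi * 0.5)  # back-loaded change
    else:
        shaped = progress                                # linear
    return initial + (final - initial) * shaped

for curve in ("Aggressive", "Linear", "Lazy"):
    print(curve, [round(strength_at(i, 4, 0.8, 0.5), 3) for i in range(4)])
```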
...@@ -4,8 +4,8 @@ import numpy as np ...@@ -4,8 +4,8 @@ import numpy as np
from modules import scripts_postprocessing, shared from modules import scripts_postprocessing, shared
import gradio as gr import gradio as gr
from modules.ui_components import FormRow from modules.ui_components import FormRow, ToolButton
from modules.ui import switch_values_symbol
upscale_cache = {} upscale_cache = {}
...@@ -17,14 +17,19 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): ...@@ -17,14 +17,19 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
def ui(self): def ui(self):
selected_tab = gr.State(value=0) selected_tab = gr.State(value=0)
with gr.Column():
with FormRow():
with gr.Tabs(elem_id="extras_resize_mode"): with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
with FormRow(): with FormRow():
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") with gr.Column(elem_id="upscaling_column_size", scale=4):
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with FormRow(): with FormRow():
...@@ -34,6 +39,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): ...@@ -34,6 +39,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
......
...@@ -11,7 +11,7 @@ fi ...@@ -11,7 +11,7 @@ fi
export install_dir="$HOME" export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1" export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git" export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71" export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
export PYTORCH_ENABLE_MPS_FALLBACK=1 export PYTORCH_ENABLE_MPS_FALLBACK=1
......
...@@ -43,4 +43,7 @@ ...@@ -43,4 +43,7 @@
# Uncomment to enable accelerated launch # Uncomment to enable accelerated launch
#export ACCELERATE="True" #export ACCELERATE="True"
# Uncomment to disable TCMalloc
#export NO_TCMALLOC="True"
########################################### ###########################################
...@@ -4,6 +4,7 @@ import time ...@@ -4,6 +4,7 @@ import time
import importlib import importlib
import signal import signal
import re import re
import warnings
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
...@@ -17,6 +18,11 @@ from modules import paths, timer, import_hook, errors ...@@ -17,6 +18,11 @@ from modules import paths, timer, import_hook, errors
startup_timer = timer.Timer() startup_timer = timer.Timer()
import torch import torch
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
startup_timer.record("import torch") startup_timer.record("import torch")
import gradio import gradio
...@@ -64,11 +70,51 @@ else: ...@@ -64,11 +70,51 @@ else:
server_name = "0.0.0.0" if cmd_opts.listen else None server_name = "0.0.0.0" if cmd_opts.listen else None
def fix_asyncio_event_loop_policy():
"""
The default `asyncio` event loop policy only automatically creates
event loops in the main threads. Other threads must create event
loops explicitly or `asyncio.get_event_loop` (and therefore
`.IOLoop.current`) will fail. Installing this policy allows event
loops to be created automatically on any thread, matching the
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
"""
import asyncio
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
# "Any thread" and "selector" should be orthogonal, but there's not a clean
# interface for composing policies so pick the right base.
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
else:
_BasePolicy = asyncio.DefaultEventLoopPolicy
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
"""Event loop policy that allows loop creation on any thread.
Usage::
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
"""
def get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
return super().get_event_loop()
except (RuntimeError, AssertionError):
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
# and changed to a RuntimeError in 3.4.3.
# "There is no current event loop in thread %r"
loop = self.new_event_loop()
self.set_event_loop(loop)
return loop
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
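A small check of what the policy above changes: with `AnyThreadEventLoopPolicy` installed, `asyncio.get_event_loop()` succeeds from a worker thread instead of raising. A sketch assuming the snippet above has executed:

```python
# Assumes AnyThreadEventLoopPolicy has been installed as above.
import asyncio
import threading

def worker():
    loop = asyncio.get_event_loop()  # raises RuntimeError under the default policy
    loop.close()

t = threading.Thread(target=worker)
t.start()
t.join()
print("worker thread obtained an event loop")
```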
def check_versions(): def check_versions():
if shared.cmd_opts.skip_version_check: if shared.cmd_opts.skip_version_check:
return return
expected_torch_version = "1.13.1" expected_torch_version = "2.0.0"
if version.parse(torch.__version__) < version.parse(expected_torch_version): if version.parse(torch.__version__) < version.parse(expected_torch_version):
errors.print_error_explanation(f""" errors.print_error_explanation(f"""
...@@ -81,7 +127,7 @@ there are reports of issues with training tab on the latest version. ...@@ -81,7 +127,7 @@ there are reports of issues with training tab on the latest version.
Use --skip-version-check commandline argument to disable this check. Use --skip-version-check commandline argument to disable this check.
""".strip()) """.strip())
expected_xformers_version = "0.0.16rc425" expected_xformers_version = "0.0.17"
if shared.xformers_available: if shared.xformers_available:
import xformers import xformers
...@@ -96,6 +142,8 @@ Use --skip-version-check commandline argument to disable this check. ...@@ -96,6 +142,8 @@ Use --skip-version-check commandline argument to disable this check.
def initialize(): def initialize():
fix_asyncio_event_loop_policy()
check_versions() check_versions()
extensions.list_extensions() extensions.list_extensions()
...@@ -123,9 +171,6 @@ def initialize(): ...@@ -123,9 +171,6 @@ def initialize():
modules.scripts.load_scripts() modules.scripts.load_scripts()
startup_timer.record("load scripts") startup_timer.record("load scripts")
modelloader.load_upscalers()
startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list() modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE") startup_timer.record("refresh VAE")
...@@ -147,6 +192,7 @@ def initialize(): ...@@ -147,6 +192,7 @@ def initialize():
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
startup_timer.record("opts onchange") startup_timer.record("opts onchange")
shared.reload_hypernetworks() shared.reload_hypernetworks()
...@@ -240,7 +286,7 @@ def webui(): ...@@ -240,7 +286,7 @@ def webui():
shared.demo = modules.ui.create_ui() shared.demo = modules.ui.create_ui()
startup_timer.record("create ui") startup_timer.record("create ui")
if cmd_opts.gradio_queue: if not cmd_opts.no_gradio_queue:
shared.demo.queue(64) shared.demo.queue(64)
gradio_auth_creds = [] gradio_auth_creds = []
......
...@@ -23,7 +23,7 @@ fi ...@@ -23,7 +23,7 @@ fi
# Install directory without trailing slash # Install directory without trailing slash
if [[ -z "${install_dir}" ]] if [[ -z "${install_dir}" ]]
then then
install_dir="/home/$(whoami)" install_dir="${HOME}"
fi fi
# Name of the subdirectory (defaults to stable-diffusion-webui) # Name of the subdirectory (defaults to stable-diffusion-webui)
...@@ -172,15 +172,30 @@ else ...@@ -172,15 +172,30 @@ else
exit 1 exit 1
fi fi
# Try using TCMalloc on Linux
prepare_tcmalloc() {
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)"
if [[ ! -z "${TCMALLOC}" ]]; then
echo "Using TCMalloc: ${TCMALLOC}"
export LD_PRELOAD="${TCMALLOC}"
else
printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n"
fi
fi
}
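One way to confirm `prepare_tcmalloc` above took effect is to look for the library among the process's loaded mappings once Python is up. A hedged sketch (Linux-only, relies on `/proc/self/maps`):

```python
# Linux-only check that the launcher above actually preloaded TCMalloc.
def tcmalloc_loaded() -> bool:
    with open("/proc/self/maps") as maps:
        return "libtcmalloc" in maps.read()

print("TCMalloc active:", tcmalloc_loaded())
```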
if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ] if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
then then
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
printf "Accelerating launch.py..." printf "Accelerating launch.py..."
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
prepare_tcmalloc
exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@" exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
else else
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..." printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
prepare_tcmalloc
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi fi