Commit a9d7eb72 authored by AUTOMATIC1111, committed by GitHub

Merge branch 'master' into saving

parents f28ce3e3 4e72a1aa
.gitignore
__pycache__
*.ckpt
*.pth
/ESRGAN/*
/SwinIR/*
/repositories
/venv
/tmp
/model.ckpt
/models/**/*
/GFPGANv1.3.pth
/gfpgan/weights/*.pth
/ui-config.json
@@ -22,3 +25,4 @@ __pycache__
/.idea
notification.mp3
/SwinIR
/textual_inversion
README.md
@@ -3,50 +3,64 @@ A browser interface based on Gradio library for Stable Diffusion.
![](txt2img_Screenshot.png)

Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.

## Features
[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
- Original txt2img and img2img modes
- One click install and run script (but you still must install python and git)
- Outpainting
- Inpainting
- Prompt Matrix
- Stable Diffusion Upscale
- Attention, specify parts of text that the model should pay more attention to
  - a man in a ((tuxedo)) - will pay more attention to tuxedo
  - a man in a (tuxedo:1.21) - alternative syntax
- Loopback, run img2img processing multiple times
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
- Textual Inversion
  - have as many embeddings as you want and use any names you like for them
  - use multiple embeddings with different numbers of vectors per token
  - works with half precision floating point numbers
- Extras tab with:
  - GFPGAN, neural network that fixes faces
  - CodeFormer, face restoration tool as an alternative to GFPGAN
  - RealESRGAN, neural network upscaler
  - ESRGAN, neural network upscaler with a lot of third party models
  - SwinIR, neural network upscaler
  - LDSR, Latent diffusion super resolution upscaling
- Resizing aspect ratio options
- Sampling method selection
- Interrupt processing at any time
- 4GB video card support (also reports of 2GB working)
- Correct seeds for batches
- Prompt length validation
  - get length of prompt in tokens as you type
  - get a warning after generation if some text was truncated
- Generation parameters
  - parameters you used to generate images are saved with that image
  - in PNG chunks for PNG, in EXIF for JPEG
  - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
  - can be disabled in settings
- Settings page
- Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements
- Possible to change defaults/min/max/step values for UI elements via text config
- Random artist button
- Tiling support, a checkbox to create images that can be tiled like textures
- Progress bar and live image generation preview
- Negative prompt, an extra text field that allows you to list what you don't want to see in the generated image
- Styles, a way to save part of a prompt and easily apply it via dropdown later
- Variations, a way to generate the same image but with tiny differences
- Seed resizing, a way to generate the same image but at a slightly different resolution
- CLIP interrogator, a button that tries to guess the prompt from an image
- Prompt Editing, a way to change the prompt mid-generation, say to start making a watermelon and switch to anime girl midway
- Batch Processing, process a group of files using img2img
- Img2img Alternative
- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
- Reloading checkpoints on the fly
- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
@@ -83,6 +97,9 @@ bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusio
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).

## Contributing
Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)

## Documentation
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
artists.csv
@@ -359,7 +359,6 @@ Antanas Sutkus,0.7369492,black-white
Leonora Carrington,0.73726475,scribbles
Hieronymus Bosch,0.7369955,scribbles
A. J. Casson,0.73666203,scribbles
Chaim Soutine,0.73662066,scribbles
Artur Bordalo,0.7364549,weird
Thomas Allom,0.68792284,fineart
@@ -1907,7 +1906,6 @@ Alex Schomburg,0.46614102,digipa-low-impact
Bastien L. Deharme,0.583349,special
František Jakub Prokyš,0.58782333,fineart
Jesper Ejsing,0.58782053,fineart
Odd Nerdrum,0.53551745,digipa-high-impact
Tom Lovell,0.5877577,fineart
Ayami Kojima,0.5877416,fineart
javascript/hints.js
@@ -58,8 +58,8 @@ titles = {
    "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
    "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
    "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
    "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",

    "Loopback": "Process an image, use it as an input, repeat.",
javascript/progressbar.js
@@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){
onUiUpdate(function(){
    check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
    check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
    check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
})

function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
javascript/textualInversion.js
function start_training_textual_inversion(){
requestProgress('ti')
gradioApp().querySelector('#ti_error').innerHTML=''
return args_to_array(arguments)
}
javascript/ui.js
@@ -186,10 +186,12 @@ onUiUpdate(function(){
    if (!txt2img_textarea) {
        txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
        txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
        txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
    }
    if (!img2img_textarea) {
        img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
        img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
        img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
    }
})
@@ -197,6 +199,14 @@ let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
let token_timeout;
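// Alt+Enter (altKey + keyCode 13) inside a prompt textarea clicks the matching Generate button.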
function submit_prompt(event, generate_button_id) {
if (event.altKey && event.keyCode === 13) {
event.preventDefault();
gradioApp().getElementById(generate_button_id).click();
return;
}
}
function update_token_counter(button_id) {
    if (token_timeout)
        clearTimeout(token_timeout);
launch.py
# this script installs necessary requirements and launches main program in webui.py
import subprocess
import os
import sys
@@ -19,10 +18,9 @@ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/Tencen
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

args = shlex.split(commandline_args)

@@ -120,8 +118,6 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)

if not is_installed("lpips"):
    run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")

@@ -130,6 +126,9 @@ run_pip(f"install -r {requirements_file}", "requirements for Web UI")
sys.argv += args
if "--exit" in args:
print("Exiting because of --exit argument")
exit(0)
def start_webui():
    print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
modules/bsrgan_model.py
import os.path
import sys
import traceback
import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import shared, modelloader
from modules.bsrgan_model_arch import RRDBNet
from modules.paths import models_path
class UpscalerBSRGAN(modules.upscaler.Upscaler):
def __init__(self, dirname):
self.name = "BSRGAN"
self.model_path = os.path.join(models_path, self.name)
self.model_name = "BSRGAN 4x"
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
self.user_path = dirname
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
if len(model_paths) == 0:
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
try:
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
self.scalers = scalers
def do_upscale(self, img: PIL.Image, selected_file):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
return img
model.to(shared.device)
torch.cuda.empty_cache()
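        # PIL gives an HWC, RGB, uint8 array; the network expects a 1xCxHxW
        # float tensor in [0, 1] with channels reversed to BGR, so convert
        # before inference (and undo all of this on the way out).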
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
torch.cuda.empty_cache()
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
else:
filename = path
        if filename is None or not os.path.exists(filename):
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
return None
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
return model
modules/bsrgan_model_arch.py
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
def initialize_weights(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
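# Usage sketch (an illustrative assumption, not part of the original file):
# scale=0.1 shrinks the Kaiming-initialized weights, which is how the residual
# dense blocks below keep their initial output close to the identity mapping:
#
#   conv = nn.Conv2d(3, 8, 3, 1, 1)
#   initialize_weights(conv, scale=0.1)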
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
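        # scale the block's residual by 0.2 before adding the skip connection,
        # as in ESRGAN, to keep training of the deep trunk stable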
return x5 * 0.2 + x
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.sf = sf
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.sf == 4:
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        if self.sf == 4:
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
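
# A minimal usage sketch (an illustrative assumption, not part of the original
# file): instantiate the default 4x network and check the upscale factor.
if __name__ == "__main__":
    net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4)
    with torch.no_grad():
        out = net(torch.rand(1, 3, 32, 32))
    assert out.shape[-2:] == (128, 128)  # sf=4 -> 4x spatial upscaling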
modules/codeformer_model.py
@@ -5,31 +5,31 @@ import traceback
import cv2
import torch

import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader
from modules.paths import script_path, models_path

# codeformer people made a choice to include modified basicsr library to their project which makes
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.

model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'

have_codeformer = False
codeformer = None


def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    path = modules.paths.paths.get("CodeFormer", None)
    if path is None:
        return

    try:
        from torchvision.transforms.functional import normalize
        from modules.codeformer.codeformer_arch import CodeFormer

@@ -44,18 +44,23 @@ def setup_codeformer():
        def name(self):
            return "CodeFormer"

        def __init__(self, dirname):
            self.net = None
            self.face_helper = None
            self.cmd_dir = dirname

        def create_models(self):
            if self.net is not None and self.face_helper is not None:
                self.net.to(devices.device_codeformer)
                return self.net, self.face_helper

            model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
            if len(model_paths) != 0:
                ckpt_path = model_paths[0]
            else:
                print("Unable to load codeformer model.")
                return None, None

            net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
            checkpoint = torch.load(ckpt_path)['params_ema']
            net.load_state_dict(checkpoint)
            net.eval()

@@ -74,6 +79,9 @@ def setup_codeformer():
            original_resolution = np_image.shape[0:2]

            self.create_models()
            if self.net is None or self.face_helper is None:
                return np_image

            self.face_helper.clean_all()
            self.face_helper.read_image(np_image)
            self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)

@@ -114,7 +122,7 @@ def setup_codeformer():
        have_codeformer = True
        global codeformer
        codeformer = FaceRestorerCodeFormer(dirname)
        shared.face_restorers.append(codeformer)
    except Exception:
modules/devices.py
@@ -32,10 +32,9 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")

device = get_optimal_device()
device_codeformer = cpu if has_mps else device
dtype = torch.float16


def randn(seed, shape):
    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
modules/esrgan_model.py
import os

import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url

import modules.esrgam_model_arch as arch
from modules import shared, modelloader, images
from modules.devices import has_mps
from modules.paths import models_path
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
def fix_model_layers(crt_model, pretrained_net):
    # this code is adapted from https://github.com/xinntao/ESRGAN
    if 'conv_first.weight' in pretrained_net:
        return pretrained_net

    if 'model.0.weight' not in pretrained_net:
        is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]

@@ -72,9 +68,59 @@ def load_model(filename):
    crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
    crt_net['conv_last.bias'] = pretrained_net['model.10.bias']

    return crt_net


class UpscalerESRGAN(Upscaler):
    def __init__(self, dirname):
        self.name = "ESRGAN"
        self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
        self.model_name = "ESRGAN 4x"
        self.scalers = []
        self.user_path = dirname
        self.model_path = os.path.join(models_path, self.name)
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        if len(model_paths) == 0:
            scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
            self.scalers.append(scaler_data)
        for file in model_paths:
            if "http" in file:
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)

            scaler_data = UpscalerData(name, file, self, 4)
            self.scalers.append(scaler_data)

    def do_upscale(self, img, selected_model):
        model = self.load_model(selected_model)
        if model is None:
            return img
        model.to(shared.device)
        img = esrgan_upscale(model, img)
        return img

    def load_model(self, path: str):
        if "http" in path:
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
                                          file_name="%s.pth" % self.model_name,
                                          progress=True)
        else:
            filename = path
        if filename is None or not os.path.exists(filename):
            print("Unable to load %s from %s" % (self.model_path, filename))
            return None

        pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
        crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)

        pretrained_net = fix_model_layers(crt_model, pretrained_net)
        crt_model.load_state_dict(pretrained_net)
        crt_model.eval()
        return crt_model
def upscale_without_tiling(model, img):
    img = np.array(img)

@@ -95,7 +141,7 @@ def esrgan_upscale(model, img):
    if opts.ESRGAN_tile == 0:
        return upscale_without_tiling(model, img)

    grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
    newtiles = []
    scale_factor = 1

@@ -110,32 +156,6 @@ def esrgan_upscale(model, img):
        newrow.append([x * scale_factor, w * scale_factor, output])
        newtiles.append([y * scale_factor, h * scale_factor, newrow])

    newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
    output = images.combine_grid(newgrid)
    return output
modules/extras.py
@@ -40,6 +40,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
    outputs = []
    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''
        existing_pnginfo = image.info or {}

        image = image.convert("RGB")

@@ -74,7 +76,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
        c = cached_images.get(key)
        if c is None:
            upscaler = shared.sd_upscalers[scaler_index]
            c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
            cached_images[key] = c

        return c

@@ -189,9 +191,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
        if save_as_half:
            theta_0[key] = theta_0[key].half()

    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

    filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
    filename = filename if custom_name == '' else (custom_name + '.ckpt')
    output_modelname = os.path.join(ckpt_dir, filename)

    print(f"Saving to {output_modelname}...")
    torch.save(primary_model, output_modelname)
modules/gfpgan_model.py
import os
import sys
import traceback

import facexlib
import gfpgan

import modules.face_restoration
from modules import shared, devices, modelloader
from modules.paths import models_path

model_dir = "GFPGAN"
user_path = None
model_path = os.path.join(models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False
loaded_gfpgan_model = None


def gfpgann():
    global loaded_gfpgan_model
    global model_path
    if loaded_gfpgan_model is not None:
        loaded_gfpgan_model.gfpgan.to(shared.device)
        return loaded_gfpgan_model

@@ -41,7 +27,16 @@ def gfpgan():
    if gfpgan_constructor is None:
        return None

    models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
    if len(models) == 1 and "http" in models[0]:
        model_file = models[0]
    elif len(models) != 0:
        latest_file = max(models, key=os.path.getctime)
        model_file = latest_file
    else:
        print("Unable to load gfpgan model!")
        return None
    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
    model.gfpgan.to(shared.device)
    loaded_gfpgan_model = model

@@ -49,8 +44,9 @@ def gfpgan():
def gfpgan_fix_faces(np_image):
    model = gfpgann()
    if model is None:
        return np_image
    np_image_bgr = np_image[:, :, ::-1]
    cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
    np_image = gfpgan_output_bgr[:, :, ::-1]

@@ -61,21 +57,39 @@ def gfpgan_fix_faces(np_image):
    return np_image


gfpgan_constructor = None


def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    try:
        from gfpgan import GFPGANer
        from facexlib import detection, parsing
        global user_path
        global have_gfpgan
        global gfpgan_constructor
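        # Monkey-patch the download helpers in gfpgan and facexlib so that the
        # pretrained weights they fetch land in our models directory rather
        # than each library's default cache location.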
        load_file_from_url_orig = gfpgan.utils.load_file_from_url
        facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
        facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url

        def my_load_file_from_url(**kwargs):
            return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))

        def facex_load_file_from_url(**kwargs):
            return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))

        def facex_load_file_from_url2(**kwargs):
            return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))

        gfpgan.utils.load_file_from_url = my_load_file_from_url
        facexlib.detection.load_file_from_url = facex_load_file_from_url
        facexlib.parsing.load_file_from_url = facex_load_file_from_url2
        user_path = dirname
        have_gfpgan = True
        gfpgan_constructor = GFPGANer

        class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
@@ -84,7 +98,7 @@ def setup_gfpgan():
            def restore(self, np_image):
                np_image_bgr = np_image[:, :, ::-1]
                cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
                np_image = gfpgan_output_bgr[:, :, ::-1]

                return np_image

modules/images.py
@@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string

from modules import sd_samplers, shared
from modules.shared import opts, cmd_opts

@@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    cols = math.ceil((w - overlap) / non_overlap_width)
    rows = math.ceil((h - overlap) / non_overlap_height)

    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):

@@ -67,7 +66,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
        for col in range(cols):
            x = int(col * dx)

            if x + tile_w >= w:
                x = w - tile_w

            tile = image.crop((x, y, x + tile_w, y + tile_h))

@@ -132,7 +131,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")

            if not line.is_active:
                drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)

            draw_y += line.size[1] + line_spacing

@@ -171,7 +170,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])

    hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
                        ver_texts]

    pad_top = max(hor_text_heights) + line_spacing * 2

@@ -213,8 +213,19 @@ def resize_image(resize_mode, im, width, height):
        if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
            return im.resize((w, h), resample=LANCZOS)

        scale = max(w / im.width, h / im.height)

        if scale > 1.0:
            upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
            assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"

            upscaler = upscalers[0]
            im = upscaler.scaler.upscale(im, scale, upscaler.data_path)

        if im.width != w or im.height != h:
            im = im.resize((w, h), resample=LANCZOS)

        return im

    if resize_mode == 0:
        res = resize(im, width, height)

@@ -256,7 +267,7 @@ def resize_image(resize_mode, im, width, height):
invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
max_filename_part_length = 128

@@ -278,6 +289,16 @@ def apply_filename_pattern(x, p, seed, prompt):
    if prompt is not None:
        x = x.replace("[prompt]", sanitize_filename_part(prompt))
if "[prompt_no_styles]" in x:
prompt_no_style = prompt
for style in shared.prompt_styles.get_style_prompts(p.styles):
if len(style) > 0:
style_parts = [y for y in style.split("{prompt}")]
for part in style_parts:
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False)) x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
if "[prompt_words]" in x: if "[prompt_words]" in x:
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0] words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
...@@ -290,10 +311,12 @@ def apply_filename_pattern(x, p, seed, prompt): ...@@ -290,10 +311,12 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[cfg]", str(p.cfg_scale)) x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width)) x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height)) x = x.replace("[height]", str(p.height))
#currently disabled if using the save button, will work otherwise #currently disabled if using the save button, will work otherwise
# if enabled it will cause a bug because styles is not included in the save_files data dictionary # if enabled it will cause a bug because styles is not included in the save_files data dictionary
if hasattr(p, "styles"): if hasattr(p, "styles"):
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False)) x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
...@@ -306,6 +329,7 @@ def apply_filename_pattern(x, p, seed, prompt): ...@@ -306,6 +329,7 @@ def apply_filename_pattern(x, p, seed, prompt):
return x return x
def get_next_sequence_number(path, basename): def get_next_sequence_number(path, basename):
""" """
Determines and returns the next sequence number to use when saving an image in the specified directory. Determines and returns the next sequence number to use when saving an image in the specified directory.
...@@ -319,7 +343,7 @@ def get_next_sequence_number(path, basename): ...@@ -319,7 +343,7 @@ def get_next_sequence_number(path, basename):
prefix_length = len(basename) prefix_length = len(basename)
for p in os.listdir(path): for p in os.listdir(path):
if p.startswith(basename): if p.startswith(basename):
l = os.path.splitext(p[prefix_length:])[0].split('-') #splits the filename (removing the basename first if one is defined, so the sequence number is always the first element) l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try: try:
result = max(int(l[0]), result) result = max(int(l[0]), result)
except ValueError: except ValueError:
...@@ -327,6 +351,7 @@ def get_next_sequence_number(path, basename): ...@@ -327,6 +351,7 @@ def get_next_sequence_number(path, basename):
return result + 1 return result + 1
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""): def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
if short_filename or prompt is None or seed is None: if short_filename or prompt is None or seed is None:
file_decoration = "" file_decoration = ""
...@@ -364,7 +389,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -364,7 +389,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
fullfn = "a.png" fullfn = "a.png"
fullfn_without_extension = "a" fullfn_without_extension = "a"
for i in range(500): for i in range(500):
fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}" fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}") fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}") fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
if not os.path.exists(fullfn): if not os.path.exists(fullfn):
...@@ -406,31 +431,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -406,31 +431,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
file.write(info + "\n") file.write(info + "\n")
modules/ldsr_model.py
import os
import sys
import traceback

from basicsr.utils.download_util import load_file_from_url

from modules.upscaler import Upscaler, UpscalerData
from modules.ldsr_model_arch import LDSR
from modules import shared
from modules.paths import models_path


class UpscalerLDSR(Upscaler):
    def __init__(self, user_path):
        self.name = "LDSR"
        self.model_path = os.path.join(models_path, self.name)
        self.user_path = user_path
        self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
        self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
        super().__init__()
        scaler_data = UpscalerData("LDSR", None, self)
        self.scalers = [scaler_data]

    def load_model(self, path: str):
        # Remove incorrect project.yaml file if too big
        yaml_path = os.path.join(self.model_path, "project.yaml")
        old_model_path = os.path.join(self.model_path, "model.pth")
        new_model_path = os.path.join(self.model_path, "model.ckpt")
        if os.path.exists(yaml_path):
            statinfo = os.stat(yaml_path)
            if statinfo.st_size >= 10485760:
                print("Removing invalid LDSR YAML file.")
                os.remove(yaml_path)
        if os.path.exists(old_model_path):
            print("Renaming model from model.pth to model.ckpt")
            os.rename(old_model_path, new_model_path)
        model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
                                   file_name="model.ckpt", progress=True)
        yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
                                  file_name="project.yaml", progress=True)

        try:
            return LDSR(model, yaml)
        except Exception:
            print("Error importing LDSR:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
        return None

    def do_upscale(self, img, path):
        ldsr = self.load_model(path)
        if ldsr is None:
            print("NO LDSR!")
            return img
        ddim_steps = shared.opts.ldsr_steps
        return ldsr.super_resolution(img, ddim_steps, self.scale)
modules/ldsr_model_arch.py
import gc
import time
import warnings
import numpy as np
import torch
import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
warnings.filterwarnings("ignore", category=UserWarning)
# Create LDSR Class
class LDSR:
def load_model_from_config(self, half_attention):
print(f"Loading model from {self.modelPath}")
pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"]
config = OmegaConf.load(self.yamlPath)
model = instantiate_from_config(config.model)
model.load_state_dict(sd, strict=False)
model.cuda()
if half_attention:
model = model.half()
model.eval()
return {"model": model}
def __init__(self, model_path, yaml_path):
self.modelPath = model_path
self.yamlPath = yaml_path
@staticmethod
def run(model, selected_path, custom_steps, eta):
example = get_cond(selected_path)
n_runs = 1
guider = None
ckwargs = None
ddim_use_x0_pred = False
temperature = 1.
eta = eta
custom_shape = None
height, width = example["image"].shape[1:3]
split_input = height >= 128 and width >= 128
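        # Large inputs are processed as overlapping 128px patches with a 64px
        # stride and stitched back together, which keeps memory use bounded.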
if split_input:
ks = 128
stride = 64
vqf = 4 #
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01}
else:
if hasattr(model, "split_input_params"):
delattr(model, "split_input_params")
x_t = None
logs = None
for n in range(n_runs):
if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
logs = make_convolutional_sample(example, model,
custom_steps=custom_steps,
eta=eta, quantize_x0=False,
custom_shape=custom_shape,
temperature=temperature, noise_dropout=0.,
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
ddim_use_x0_pred=ddim_use_x0_pred
)
return logs
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
model = self.load_model_from_config(half_attention)
# Run settings
diffusion_steps = int(steps)
eta = 1.0
down_sample_method = 'Lanczos'
gc.collect()
torch.cuda.empty_cache()
im_og = image
width_og, height_og = im_og.size
# The model upscales by a fixed 4x; if the maximum upscale size ever becomes adjustable, the hard-coded 4 below should become a variable
down_sample_rate = target_scale / 4
wd = width_og * down_sample_rate
hd = height_og * down_sample_rate
width_downsampled_pre = int(wd)
height_downsampled_pre = int(hd)
if down_sample_rate != 1:
print(
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
logs = self.run(model["model"], im_og, diffusion_steps, eta)
sample = logs["sample"]
sample = sample.detach().cpu()
sample = torch.clamp(sample, -1., 1.)
sample = (sample + 1.) / 2. * 255
sample = sample.numpy().astype(np.uint8)
sample = np.transpose(sample, (0, 2, 3, 1))
a = Image.fromarray(sample[0])
del model
gc.collect()
torch.cuda.empty_cache()
return a
def get_cond(selected_path):
example = dict()
up_f = 4
c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
antialias=True)
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1.
c = c.to(torch.device("cuda"))
example["LR_image"] = c
example["image"] = c_up
return example
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
corrector_kwargs=None, x_t=None
):
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
print(f"Sampling with eta = {eta}; steps: {steps}")
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
mask=mask, x0=x0, temperature=temperature, verbose=False,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, x_t=x_t)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
log = dict()
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=not (hasattr(model, 'split_input_params')
and model.cond_stage_key == 'coordinates_bbox'),
return_original_cond=True)
if custom_shape is not None:
z = torch.randn(custom_shape)
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
z0 = None
log["input"] = x
log["reconstruction"] = xrec
if ismap(xc):
log["original_conditioning"] = model.to_rgb(xc)
if hasattr(model, 'cond_stage_key'):
log[model.cond_stage_key] = model.to_rgb(xc)
else:
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_model:
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_key == 'class_label':
log[model.cond_stage_key] = xc[model.cond_stage_key]
with model.ema_scope("Plotting"):
t0 = time.time()
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
eta=eta,
quantize_x0=quantize_x0, mask=None, x0=z0,
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
x_t=x_T)
t1 = time.time()
if ddim_use_x0_pred:
sample = intermediates['pred_x0'][-1]
x_sample = model.decode_first_stage(sample)
try:
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except:
pass
log["sample"] = x_sample
log["time"] = t1 - t0
return log
import glob
import os
import shutil
import importlib
from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url
from modules import shared
from modules.upscaler import Upscaler
from modules.paths import script_path, models_path
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
"""
A one-and-done loader that tries to find the desired models in the specified directories.
@param download_name: Specify to download from model_url immediately.
@param model_url: If no other models are found, this will be downloaded on upscale.
@param model_path: The location to store/find models in.
@param command_path: A command-line argument to search for models in first.
@param ext_filter: An optional list of filename extensions to filter by
@return: A list of paths containing the desired model(s)
"""
output = []
if ext_filter is None:
ext_filter = []
try:
places = []
if command_path is not None and command_path != model_path:
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
if os.path.exists(pretrained_path):
print(f"Appending path: {pretrained_path}")
places.append(pretrained_path)
elif os.path.exists(command_path):
places.append(command_path)
places.append(model_path)
for place in places:
if os.path.exists(place):
for file in glob.iglob(place + '**/**', recursive=True):
full_path = file
if os.path.isdir(full_path):
continue
if len(ext_filter) != 0:
model_name, extension = os.path.splitext(file)
if extension not in ext_filter:
continue
if file not in output:
output.append(full_path)
if model_url is not None and len(output) == 0:
if download_name is not None:
dl = load_file_from_url(model_url, model_path, True, download_name)
output.append(dl)
else:
output.append(model_url)
except Exception:
pass
return output
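# Hypothetical call: gather ESRGAN checkpoints from the shared models dir,
# falling back to downloading a single file. All values here are placeholders,
# not defaults from this repo.
found = load_models(
    model_path=os.path.join(models_path, "ESRGAN"),
    model_url="https://example.com/ESRGAN_x4.pth",
    command_path=None,
    ext_filter=[".pth"],
    download_name="ESRGAN_x4.pth",
)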
def friendly_name(file: str):
if "http" in file:
file = urlparse(file).path
file = os.path.basename(file)
model_name, extension = os.path.splitext(file)
return model_name
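# friendly_name reduces either a URL or a local path to a bare model name,
# e.g. (illustrative):
#   friendly_name("https://example.com/weights/ESRGAN_x4.pth")  -> "ESRGAN_x4"
#   friendly_name("models/ESRGAN/foo_model.pth")                -> "foo_model"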
def cleanup_models():
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
# somehow auto-register and just do these things...
root_path = script_path
src_path = models_path
dest_path = os.path.join(models_path, "Stable-diffusion")
move_files(src_path, dest_path, ".ckpt")
src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "gfpgan")
dest_path = os.path.join(models_path, "GFPGAN")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "SwinIR")
dest_path = os.path.join(models_path, "SwinIR")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
dest_path = os.path.join(models_path, "LDSR")
move_files(src_path, dest_path)
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
try:
if not os.path.exists(dest_path):
os.makedirs(dest_path)
if os.path.exists(src_path):
for file in os.listdir(src_path):
fullpath = os.path.join(src_path, file)
if os.path.isfile(fullpath):
if ext_filter is not None:
if ext_filter not in file:
continue
print(f"Moving {file} from {src_path} to {dest_path}.")
try:
shutil.move(fullpath, dest_path)
except:
pass
if len(os.listdir(src_path)) == 0:
print(f"Removing empty folder: {src_path}")
shutil.rmtree(src_path, True)
except:
pass
def load_upscalers():
datas = []
for cls in Upscaler.__subclasses__():
name = cls.__name__
module_name = cls.__module__
module = importlib.import_module(module_name)
class_ = getattr(module, name)
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
opt_string = None
try:
opt_string = shared.opts.__getattr__(cmd_name)
except:
pass
scaler = class_(opt_string)
for child in scaler.scalers:
datas.append(child)
shared.sd_upscalers = datas
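# load_upscalers() discovers every Upscaler subclass via __subclasses__(), so a
# new backend only needs to subclass Upscaler to get registered. A minimal
# hypothetical example, assuming the base-class attributes used elsewhere in
# this diff (name, user_path, scalers, do_upscale); not part of the codebase:
#
#     from modules.upscaler import Upscaler, UpscalerData
#
#     class UpscalerIdentity(Upscaler):
#         def __init__(self, dirname):
#             self.name = "Identity"
#             self.user_path = dirname
#             super().__init__()
#             self.scalers = [UpscalerData("Identity", None, self)]
#
#         def do_upscale(self, img, path):
#             return img  # no-op: hands the image back unchanged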
@@ -3,9 +3,10 @@ import os
import sys

script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models")
sys.path.insert(0, script_path)

# search for directory of stable diffusion in following places
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
@@ -15,21 +16,24 @@ for possible_sd_path in possible_sd_paths:
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)

path_dirs = [
    (sd_path, 'ldm', 'Stable Diffusion', []),
    (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
    (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
    (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
    (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
    (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]

paths = {}

for d, must_exist, what, options in path_dirs:
    must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
    if not os.path.exists(must_exist_path):
        print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
    else:
        d = os.path.abspath(d)
        if "atstart" in options:
            sys.path.insert(0, d)
        else:
            sys.path.append(d)
        paths[what] = d
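# Toy illustration (not repo code) of why the new "atstart" option matters:
#     import sys
#     sys.path.append("repositories/stable-diffusion")   # searched after site-packages
#     sys.path.insert(0, "repositories/k-diffusion")     # searched first, shadows duplicates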
@@ -56,7 +56,7 @@ class StableDiffusionProcessing:
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
@@ -79,7 +79,7 @@ class StableDiffusionProcessing:
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = 0
        self.sampler_noise_scheduler_override = None
        self.ddim_discretize = opts.ddim_discretize
        self.s_churn = opts.s_churn
        self.s_tmin = opts.s_tmin
@@ -130,7 +130,7 @@ class Processed:
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
        self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
        self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
        self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
@@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
    }

    generation_params.update(p.extra_generation_params)
@@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
    fix_seed(p)

    if p.outpath_samples is not None:
        os.makedirs(p.outpath_samples, exist_ok=True)

    if p.outpath_grids is not None:
        os.makedirs(p.outpath_grids, exist_ok=True)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
@@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
        return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    infotexts = []
    output_images = []
@@ -492,8 +495,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)
                image = images.resize_image(0, image, self.width, self.height)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)
...
import os
import sys
import traceback
from collections import namedtuple

import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer

from modules.upscaler import Upscaler, UpscalerData
from modules.paths import models_path
from modules.shared import cmd_opts, opts

RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
realesrgan_models = []
have_realesrgan = False
class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
self.name = "RealESRGAN"
self.model_path = os.path.join(models_path, self.name)
self.user_path = path
super().__init__()
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
self.enable = True
self.scalers = []
scalers = self.load_models(path)
for scaler in scalers:
if scaler.name in opts.realesrgan_enabled_models:
self.scalers.append(scaler)
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
self.enable = False
self.scalers = []
def do_upscale(self, img, path):
if not self.enable:
return img
info = self.load_model(path)
if not os.path.exists(info.data_path):
print("Unable to load RealESRGAN model: %s" % info.name)
return img
upsampler = RealESRGANer(
scale=info.scale,
model_path=info.data_path,
model=info.model(),
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
)
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
image = Image.fromarray(upsampled)
return image
def load_model(self, path):
try:
info = None
for scaler in self.scalers:
if scaler.data_path == path:
info = scaler
if info is None:
print(f"Unable to find model info: {path}")
return None
model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
info.data_path = model_file
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return None
    def load_models(self, _):
        return get_realesrgan_models(self)


def get_realesrgan_models(scaler):
    try:
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
        models = [
            UpscalerData(
                name="R-ESRGAN General 4xV3",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN General WDN 4xV3",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN AnimeVideo",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN 4x+",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
            ),
            UpscalerData(
                name="R-ESRGAN 4x+ Anime6B",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
            ),
            UpscalerData(
                name="R-ESRGAN 2x+",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
                scale=2,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
            ),
        ]
        return models
    except Exception as e:
        print("Error making Real-ESRGAN models list:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
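# Hypothetical usage of the upscaler defined above (the file name is a
# placeholder); do_upscale() downloads the weights on first use:
#
#     upscaler = UpscalerRealESRGAN(None)
#     if upscaler.enable and upscaler.scalers:
#         img = Image.open("input.png")
#         result = upscaler.do_upscale(img, upscaler.scalers[0].data_path)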
class UpscalerRealESRGAN(modules.images.Upscaler):
def __init__(self, upscaling, model_index):
self.upscaling = upscaling
self.model_index = model_index
self.name = realesrgan_models[model_index].name
def do_upscale(self, img):
return upscale_with_realesrgan(img, self.upscaling, self.model_index)
def setup_realesrgan():
global realesrgan_models
global have_realesrgan
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
realesrgan_models = get_realesrgan_models()
have_realesrgan = True
for i, model in enumerate(realesrgan_models):
if model.name in opts.realesrgan_enabled_models:
modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
have_realesrgan = False
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
if not have_realesrgan:
return image
info = realesrgan_models[RealESRGAN_model_index]
model = info.model()
upsampler = RealESRGANer(
scale=info.netscale,
model_path=info.location,
model=model,
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
)
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
image = Image.fromarray(upsampled)
return image
import math
import torch
from torch import einsum
from ldm.util import default
from einops import rearrange
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
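# Worked example of the sizing heuristic above (illustrative numbers): a
# 512x512 image gives a 64x64 latent, so with 8 heads at batch size 1, q has
# 8 batch*head rows of 4096 tokens at fp16 (element_size == 2):
#   tensor_size  = 8 * 4096 * 4096 * 2 bytes ~= 0.25 GB
#   mem_required = tensor_size * 3           ~= 0.75 GB
# With only 0.5 GB free, steps = 2 ** ceil(log2(0.75 / 0.5)) = 2, so the
# attention matrix is built in two slices of 2048 query rows each.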
def nonlinearity_hijack(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3
@@ -8,7 +8,14 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

from modules import shared, modelloader, devices
from modules.paths import models_path

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
model_name = "sd-v1-4.ckpt"
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
user_dir = None

CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
@@ -23,20 +30,30 @@ except Exception:
    pass


def setup_model(dirname):
    global user_dir
    user_dir = dirname
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    checkpoints_list.clear()
    list_models()


def checkpoint_tiles():
    return sorted([x.title for x in checkpoints_list.values()])


def list_models():
    checkpoints_list.clear()
    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)

    def modeltitle(path, shorthash):
        abspath = os.path.abspath(path)

        if user_dir is not None and abspath.startswith(user_dir):
            name = abspath.replace(user_dir, '')
        elif abspath.startswith(model_path):
            name = abspath.replace(model_path, '')
        else:
            name = os.path.basename(path)
@@ -45,21 +62,27 @@ def list_models():
        shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]

        return f'{name} [{shorthash}]', shortname

    cmd_ckpt = shared.cmd_opts.ckpt
    if os.path.exists(cmd_ckpt):
        h = model_hash(cmd_ckpt)
        title, short_model_name = modeltitle(cmd_ckpt, h)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
        shared.opts.sd_model_checkpoint = title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (possibly it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)
    for filename in model_list:
        h = model_hash(filename)
        title, short_model_name = modeltitle(filename, h)
        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)


def get_closet_checkpoint_match(searchString):
    applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key=lambda x: len(x.title))
    if len(applicable) > 0:
        return applicable[0]
    return None


def model_hash(filename):
@@ -111,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
    if not shared.cmd_opts.no_half:
        model.half()

    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpint = checkpoint_file
@@ -137,7 +162,7 @@ def load_model():
def reload_model_weights(sd_model, info=None):
    from modules import lowvram, devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

    if sd_model.sd_model_checkpint == checkpoint_info.filename:
@@ -148,8 +173,12 @@ def reload_model_weights(sd_model, info=None):
    else:
        sd_model.to(devices.cpu)

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)

    sd_hijack.model_hijack.hijack(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)
...
@@ -4,7 +4,6 @@ import torch
import tqdm
from PIL import Image
import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
@@ -23,6 +22,8 @@ samplers_k_diffusion = [
    ('Heun', 'sample_heun', ['k_heun']),
    ('DPM2', 'sample_dpm_2', ['k_dpm_2']),
    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']),
    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
]

samplers_data_k_diffusion = [
@@ -36,7 +37,7 @@ samplers = [
    SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
    SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
]
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]

sampler_extra_params = {
    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
@@ -289,7 +290,10 @@ class KDiffusionSampler:
    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
        steps, t_enc = setup_img2img_steps(p, steps)

        if p.sampler_noise_scheduler_override:
            sigmas = p.sampler_noise_scheduler_override(steps)
        else:
            sigmas = self.model_wrap.get_sigmas(steps)

        noise = noise * sigmas[steps - t_enc - 1]
        xi = x + noise
@@ -305,12 +309,20 @@ class KDiffusionSampler:
    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
        steps = steps or p.steps

        if p.sampler_noise_scheduler_override:
            sigmas = p.sampler_noise_scheduler_override(steps)
        else:
            sigmas = self.model_wrap.get_sigmas(steps)

        x = x * sigmas[0]

        extra_params_kwargs = self.initialize(p)
        if 'sigma_min' in inspect.signature(self.func).parameters:
            # samplers like dpm_fast/dpm_adaptive take sigma_min/sigma_max
            # (and, for dpm_fast, a step count n) instead of a sigma schedule
            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
            if 'n' in inspect.signature(self.func).parameters:
                extra_params_kwargs['n'] = steps
        else:
            extra_params_kwargs['sigmas'] = sigmas
        samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)

        return samples
import argparse
import datetime
import json
import os
import sys

import gradio as gr
import tqdm

import modules.artists
import modules.interrogate
import modules.memmon
import modules.sd_models
import modules.styles
from modules.devices import get_optimal_device
from modules.paths import script_path, sd_path

sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
model_path = os.path.join(script_path, 'models')
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
@@ -34,8 +35,13 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
@@ -53,7 +59,6 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)

cmd_opts = parser.parse_args()
device = get_optimal_device()

batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
@@ -61,6 +66,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
config_filename = cmd_opts.ui_settings_file


class State:
    interrupted = False
    job = ""
@@ -72,6 +78,7 @@ class State:
    current_latent = None
    current_image = None
    current_image_sampling_step = 0
    textinfo = None

    def interrupt(self):
        self.interrupted = True
@@ -82,7 +89,7 @@ class State:
        self.current_image_sampling_step = 0

    def get_job_timestamp(self):
        return datetime.datetime.now().strftime("%Y%m%d%H%M%S")  # shouldn't this return job_timestamp?


state = State()
@@ -95,13 +102,13 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate")

face_restorers = []
# This was moved to webui.py with the other model "setup" calls.
# modules.sd_models.list_models()


def realesrgan_models_names():
    import modules.realesrgan_model
    return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]


class OptionInfo:
@@ -167,13 +174,10 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
options_templates.update(options_section(('upscaling', "Upscaling"), {
    "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
    "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
    "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
    "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
    "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
    "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
@@ -190,9 +194,9 @@ options_templates.update(options_section(('system', "System"), {
}))

options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
    "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
    "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
...
@@ -53,6 +53,12 @@ class StyleDatabase:
            negative_prompt = row.get("negative_prompt", "")
            self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)

    def get_style_prompts(self, styles):
        return [self.styles.get(x, self.no_style).prompt for x in styles]

    def get_negative_style_prompts(self, styles):
        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

    def apply_styles_to_prompt(self, prompt, styles):
        return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
...
import sys import contextlib
import traceback import os
import cv2
import os import numpy as np
import contextlib import torch
import numpy as np from PIL import Image
from PIL import Image from basicsr.utils.download_util import load_file_from_url
import torch from tqdm import tqdm
import modules.images
from modules.shared import cmd_opts, opts, device from modules import modelloader
from modules.swinir_arch import SwinIR as net from modules.paths import models_path
from modules.shared import cmd_opts, opts, device
precision_scope = ( from modules.swinir_model_arch import SwinIR as net
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext from modules.upscaler import Upscaler, UpscalerData
)
precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
def load_model(filename, scale=4): )
model = net(
upscale=scale,
in_chans=3, class UpscalerSwinIR(Upscaler):
img_size=64, def __init__(self, dirname):
window_size=8, self.name = "SwinIR"
img_range=1.0, self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
embed_dim=240, "-L_x4_GAN.pth "
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], self.model_name = "SwinIR 4x"
mlp_ratio=2, self.model_path = os.path.join(models_path, self.name)
upsampler="nearest+conv", self.user_path = dirname
resi_connection="3conv", super().__init__()
) scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"])
pretrained_model = torch.load(filename) for model in model_files:
            if "http" in model:
                name = self.model_name
            else:
                name = modelloader.friendly_name(model)

            model_data = UpscalerData(name, model, self)
            scalers.append(model_data)
        self.scalers = scalers

    def do_upscale(self, img, model_file):
        model = self.load_model(model_file)
        if model is None:
            return img
        model = model.to(device)
        img = upscale(img, model)
        try:
            torch.cuda.empty_cache()
        except:
            pass
        return img

    def load_model(self, path, scale=4):
        if "http" in path:
            dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
            filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
        else:
            filename = path
        if filename is None or not os.path.exists(filename):
            return None

        model = net(
            upscale=scale,
            in_chans=3,
            img_size=64,
            window_size=8,
            img_range=1.0,
            depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
            embed_dim=240,
            num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
            mlp_ratio=2,
            upsampler="nearest+conv",
            resi_connection="3conv",
        )

        pretrained_model = torch.load(filename)
        model.load_state_dict(pretrained_model["params_ema"], strict=True)
        if not cmd_opts.no_half:
            model = model.half()
        return model


def upscale(
    img,
    model,
    tile=opts.SWIN_tile,
    tile_overlap=opts.SWIN_tile_overlap,
    window_size=8,
    scale=4,
):
    img = np.array(img)
    img = img[:, :, ::-1]
    img = np.moveaxis(img, 2, 0) / 255
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(device)
    with torch.no_grad(), precision_scope("cuda"):
        _, _, h_old, w_old = img.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
        img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
        output = inference(img, model, tile, tile_overlap, window_size, scale)
        output = output[..., : h_old * scale, : w_old * scale]
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(
                output[[2, 1, 0], :, :], (1, 2, 0)
            )  # CHW-RGB to HWC-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        return Image.fromarray(output, "RGB")


def inference(img, model, tile, tile_overlap, window_size, scale):
    # test the image tile by tile
    b, c, h, w = img.size()
    tile = min(tile, h, w)
    assert tile % window_size == 0, "tile size should be a multiple of window_size"
    sf = scale

    stride = tile - tile_overlap
    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
    E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
    W = torch.zeros_like(E, dtype=torch.half, device=device)

    with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(out_patch)
                E[
                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
                ].add_(out_patch)
                W[
                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
                ].add_(out_patch_mask)
                pbar.update(1)
    output = E.div_(W)
    return output
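The tiling above is an overlap-and-average scheme: every output pixel accumulates raw model outputs in E and hit counts in W, and the final E.div_(W) blends tile seams away. Below is a minimal self-contained sketch of the same accumulation, with nn.Upsample standing in for SwinIR so it runs with nothing but torch installed (the tile sizes and 4x factor here are arbitrary):

import torch
import torch.nn as nn


def tiled_upscale(img, model, tile=64, tile_overlap=16, scale=4):
    # img: (b, c, h, w); model maps (b, c, t, t) -> (b, c, t*scale, t*scale)
    b, c, h, w = img.size()
    tile = min(tile, h, w)
    stride = tile - tile_overlap
    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
    E = torch.zeros(b, c, h * scale, w * scale)  # summed outputs
    W = torch.zeros_like(E)                      # per-pixel overlap counts
    for h_idx in h_idx_list:
        for w_idx in w_idx_list:
            out = model(img[..., h_idx:h_idx + tile, w_idx:w_idx + tile])
            E[..., h_idx * scale:(h_idx + tile) * scale,
                   w_idx * scale:(w_idx + tile) * scale].add_(out)
            W[..., h_idx * scale:(h_idx + tile) * scale,
                   w_idx * scale:(w_idx + tile) * scale].add_(1.0)
    return E.div_(W)  # average wherever tiles overlapped


model = nn.Upsample(scale_factor=4, mode="nearest")  # dummy stand-in "upscaler"
print(tiled_upscale(torch.rand(1, 3, 100, 100), model).shape)  # (1, 3, 400, 400)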
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
import tqdm
class PersonalizedBase(Dataset):
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
self.size = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
with open(template_file, "r") as file:
lines = [x.strip() for x in file.readlines()]
self.lines = lines
assert data_root, 'dataset directory not specified'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
image = Image.open(path)
image = image.convert('RGB')
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
filename_tokens = [token for token in filename_tokens if token.isalpha()]
            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
            torchdata = torch.moveaxis(torchdata, 2, 0)

            # encode each image to its VAE latent once, up front, so the training
            # loop never has to run the first-stage encoder again
            init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()

            self.dataset.append((init_latent, filename_tokens))
self.length = len(self.dataset) * repeats
self.initial_indexes = np.arange(self.length) % len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
def __len__(self):
return self.length
def __getitem__(self, i):
if i % len(self.dataset) == 0:
self.shuffle()
index = self.indexes[i % len(self.indexes)]
x, filename_tokens = self.dataset[index]
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", ' '.join(filename_tokens))
return x, text
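Each __getitem__ builds its prompt by splicing the placeholder token and the filename-derived words into a randomly chosen template line. A small sketch isolating just that text path, with no model or latents involved (the template strings below are made up, not the ones shipped with the webui):

import os
import random

def filename_to_tokens(path):
    # "my_cat-photo 01.png" -> ["my", "cat", "photo"]; non-alphabetic tokens are dropped
    stem = os.path.splitext(os.path.basename(path))[0]
    return [t for t in stem.replace('_', '-').replace(' ', '-').split('-') if t.isalpha()]

templates = ["a photo of [name]", "a rendering of [name], [filewords]"]

text = random.choice(templates)
text = text.replace("[name]", "my-embedding")
text = text.replace("[filewords]", ' '.join(filename_to_tokens("my_cat-photo 01.png")))
print(text)  # e.g. "a rendering of my-embedding, my cat photo"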
import os
import sys
import traceback
import torch
import tqdm
import html
import datetime
from modules import shared, devices, sd_hijack, processing
import modules.textual_inversion.dataset
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
self.name = name
self.step = step
self.cached_checksum = None
def save(self, filename):
embedding_data = {
"string_to_token": {"*": 265},
"string_to_param": {"*": self.vec},
"name": self.name,
"step": self.step,
}
torch.save(embedding_data, filename)
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
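const_hash is an order-sensitive rolling hash over the vector scaled by 100 and truncated to integers, so the four-hex-digit checksum stays stable until the weights move appreciably. A runnable sketch of the same arithmetic on a dummy tensor:

import torch

def const_hash(a):
    r = 0
    for v in a:
        r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
    return r

vec = torch.tensor([[0.01, -0.02, 0.30]])
print(f"{const_hash(vec.reshape(-1) * 100) & 0xffff:04x}")  # deterministic 4-hex-digit digest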
class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, embedding))
return embedding
def load_textual_inversion_embeddings(self):
mt = os.path.getmtime(self.embeddings_dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
self.register_embedding(embedding, shared.sd_model)
for fn in os.listdir(self.embeddings_dir):
try:
fullfn = os.path.join(self.embeddings_dir, fn)
if os.stat(fullfn).st_size == 0:
continue
process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None:
return None
for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids:
return embedding
return None
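find_embedding_at_position keys its lookup on the first token id only and then confirms the full id sequence, which keeps the per-position check cheap while still supporting multi-token embedding names. A self-contained sketch of the same scheme with made-up token ids:

ids_lookup = {}

def register(ids, name):
    ids_lookup.setdefault(ids[0], []).append((ids, name))

register([320, 991], "my-style")  # two-token name
register([320], "my")             # one-token name sharing the same first id

def find_at(tokens, offset):
    for ids, name in ids_lookup.get(tokens[offset], []):
        if tokens[offset:offset + len(ids)] == ids:
            return name
    return None

print(find_at([11, 320, 991, 7], 1))  # my-style (candidates tried in insertion order)
print(find_at([11, 320, 5, 7], 1))    # my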
def create_embedding(name, num_vectors_per_token):
init_text = '*'
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
embedding.save(fn)
return fn
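create_embedding copies its starting vectors out of the model's own embedding of the init text, sampling source rows at evenly spaced offsets. A toy analogue with a random nn.Embedding standing in for CLIP's token embedding table (all sizes below are arbitrary):

import torch
import torch.nn as nn

embedding_layer = nn.Embedding(100, 8)               # stand-in for CLIP token embeddings
ids = torch.tensor([[5, 7, 9, 11]])                  # pretend tokenizer output
embedded = embedding_layer(ids).squeeze(0).detach()  # (4, 8); detach: only the values matter

num_vectors_per_token = 2
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]))
for i in range(num_vectors_per_token):
    vec[i] = embedded[i * embedded.shape[0] // num_vectors_per_token]  # rows 0 and 2
print(vec.shape)  # torch.Size([2, 8])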
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
os.makedirs(embedding_dir, exist_ok=True)
else:
embedding_dir = None
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
else:
images_dir = None
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
losses = torch.zeros((32,))
last_saved_file = "<none>"
last_saved_image = "<none>"
    initial_step = embedding.step or 0
    if initial_step > steps:
        return embedding, filename

    pbar = tqdm.tqdm(enumerate(ds), total=steps - initial_step)
    for i, (x, text) in pbar:
        embedding.step = i + initial_step
if embedding.step > steps:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
c = cond_model([text])
loss = shared.sd_model(x.unsqueeze(0), c)[0]
losses[embedding.step % losses.shape[0]] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description(f"loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
prompt=text,
steps=20,
do_not_save_grid=True,
do_not_save_samples=True,
)
processed = processing.process_images(p)
image = processed.images[0]
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {text}"
shared.state.job_no = embedding.step
shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
embedding.cached_checksum = None
embedding.save(filename)
return embedding, filename
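In the loop above only embedding.vec receives gradients; the Stable Diffusion weights stay frozen while AdamW descends on the single conditioning vector. A minimal runnable analogue of that setup, fitting one learnable vector against a frozen linear layer (purely illustrative, not the webui API):

import torch

frozen = torch.nn.Linear(4, 1)
for p in frozen.parameters():
    p.requires_grad = False               # the "model" never trains

vec = torch.zeros(4, requires_grad=True)  # the embedding being learned
optimizer = torch.optim.AdamW([vec], lr=5e-3)
target = torch.tensor([1.0])

for step in range(200):
    loss = torch.nn.functional.mse_loss(frozen(vec), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(loss.item())  # approaches the best achievable fit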
@@ -34,7 +34,7 @@ class Script(scripts.Script):
         seed = p.seed
         init_img = p.init_images[0]

-        img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
+        img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)

         devices.torch_gc()