Commit f5e44364 authored by InvincibleDude, committed by GitHub

Merge branch 'master' into improved-hr-conflict-test

parents f6e27378 a9fed7c3
@@ -157,5 +157,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
 - Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
 - Security advice - RyotaK
+- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
 - (You)
@@ -3,7 +3,9 @@ import os
 import re
 import torch
-from modules import shared, devices, sd_models
+from modules import shared, devices, sd_models, errors
+
+metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
 re_digits = re.compile(r"\d+")
 re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
@@ -43,6 +45,23 @@ class LoraOnDisk:
     def __init__(self, name, filename):
         self.name = name
         self.filename = filename
+        self.metadata = {}
+
+        _, ext = os.path.splitext(filename)
+        if ext.lower() == ".safetensors":
+            try:
+                self.metadata = sd_models.read_metadata_from_safetensors(filename)
+            except Exception as e:
+                errors.display(e, f"reading lora {filename}")
+
+        if self.metadata:
+            m = {}
+
+            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
+                m[k] = v
+
+            self.metadata = m
+
+        self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
 class LoraModule:
...
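Note on the metadata ordering above: Python dicts preserve insertion order, so rebuilding the dict from items sorted by metadata_tags_order moves the known training tags to the front, while unknown keys fall back to rank 999 and keep their relative order. A minimal sketch with made-up metadata values:

    metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
    metadata = {"zzz_custom": "x", "ss_tag_frequency": {"1girl": 120}, "ss_sd_model_name": "v1-5-pruned"}
    ordered = {k: v for k, v in sorted(metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999))}
    print(list(ordered))  # ['ss_sd_model_name', 'ss_tag_frequency', 'zzz_custom']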
@@ -15,21 +15,15 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
     def list_items(self):
         for name, lora_on_disk in lora.available_loras.items():
             path, ext = os.path.splitext(lora_on_disk.filename)
-            previews = [path + ".png", path + ".preview.png"]
-
-            preview = None
-            for file in previews:
-                if os.path.isfile(file):
-                    preview = self.link_preview(file)
-                    break
             yield {
                 "name": name,
                 "filename": path,
-                "preview": preview,
+                "preview": self.find_preview(path),
+                "description": self.find_description(path),
                 "search_term": self.search_terms_from_path(lora_on_disk.filename),
                 "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
-                "local_preview": path + ".png",
+                "local_preview": f"{path}.{shared.opts.samples_format}",
+                "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
             }
     def allowed_directories_for_previews(self):
...
 <div class='card' {preview_html} onclick={card_clicked}>
+    {metadata_button}
     <div class='actions'>
         <div class='additional'>
             <ul>
@@ -7,6 +9,7 @@
                 <span style="display:none" class='search_term'>{search_term}</span>
             </div>
             <span class='name'>{name}</span>
+            <span class='description'>{description}</span>
         </div>
     </div>
This diff is collapsed.
@@ -5,12 +5,10 @@ function setupExtraNetworksForTab(tabname){
     var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
     var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
    var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
-    var close = gradioApp().getElementById(tabname+'_extra_close')
     search.classList.add('search')
     tabs.appendChild(search)
     tabs.appendChild(refresh)
-    tabs.appendChild(close)
     search.addEventListener("input", function(evt){
         searchTerm = search.value.toLowerCase()
@@ -78,7 +76,7 @@ function cardClicked(tabname, textToAdd, allowNegativePrompt){
     var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
     if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
-        textarea.value = textarea.value + " " + textToAdd
+        textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd
     }
     updateInput(textarea)
@@ -104,4 +102,40 @@ function extraNetworksSearchButton(tabs_id, event){
     searchTextarea.value = text
     updateInput(searchTextarea)
 }
\ No newline at end of file
+
+var globalPopup = null;
+var globalPopupInner = null;
+
+function popup(contents){
+    if(! globalPopup){
+        globalPopup = document.createElement('div')
+        globalPopup.onclick = function(){ globalPopup.style.display = "none"; };
+        globalPopup.classList.add('global-popup');
+
+        var close = document.createElement('div')
+        close.classList.add('global-popup-close');
+        close.onclick = function(){ globalPopup.style.display = "none"; };
+        close.title = "Close";
+        globalPopup.appendChild(close)
+
+        globalPopupInner = document.createElement('div')
+        globalPopupInner.onclick = function(event){ event.stopPropagation(); return false; };
+        globalPopupInner.classList.add('global-popup-inner');
+        globalPopup.appendChild(globalPopupInner)
+
+        gradioApp().appendChild(globalPopup);
+    }
+
+    globalPopupInner.innerHTML = '';
+    globalPopupInner.appendChild(contents);
+
+    globalPopup.style.display = "flex";
+}
+
+function extraNetworksShowMetadata(text){
+    elem = document.createElement('pre')
+    elem.classList.add('popup-metadata');
+    elem.textContent = text;
+
+    popup(elem);
+}
@@ -6,6 +6,7 @@ titles = {
     "GFPGAN": "Restore low quality faces using GFPGAN neural network",
     "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
     "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
+    "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
     "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
     "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
...
@@ -11,7 +11,7 @@ function showModal(event) {
     if (modalImage.style.display === 'none') {
         lb.style.setProperty('background-image', 'url(' + source.src + ')');
     }
-    lb.style.display = "block";
+    lb.style.display = "flex";
     lb.focus()
     const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
...
@@ -15,7 +15,7 @@ onUiUpdate(function(){
         }
     }
-    const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
+    const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] img.h-full.w-full.overflow-hidden');
     if (galleryPreviews == null) return;
...
@@ -139,7 +139,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
     var divProgress = document.createElement('div')
     divProgress.className='progressDiv'
-    divProgress.style.display = opts.show_progressbar ? "" : "none"
+    divProgress.style.display = opts.show_progressbar ? "block" : "none"
     var divInner = document.createElement('div')
     divInner.className='progress'
...
@@ -8,6 +8,14 @@ import platform
 import argparse
 import json
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument("--ui-settings-file", type=str, default='config.json')
+parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.realpath(__file__)))
+args, _ = parser.parse_known_args(sys.argv)
+
+script_path = os.path.dirname(__file__)
+data_path = os.getcwd()
+
 dir_repos = "repositories"
 dir_extensions = "extensions"
 python = sys.executable
@@ -122,7 +130,7 @@ def is_installed(package):
 def repo_dir(name):
-    return os.path.join(dir_repos, name)
+    return os.path.join(script_path, dir_repos, name)
 def run_python(code, desc=None, errdesc=None):
@@ -161,7 +169,17 @@ def git_clone(url, dir, name, commithash=None):
     if commithash is not None:
         run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
+def git_pull_recursive(dir):
+    for subdir, _, _ in os.walk(dir):
+        if os.path.exists(os.path.join(subdir, '.git')):
+            try:
+                output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
+                print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
+            except subprocess.CalledProcessError as e:
+                print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
+
 def version_check(commit):
     try:
         import requests
@@ -205,7 +223,7 @@ def list_extensions(settings_file):
     disabled_extensions = set(settings.get('disabled_extensions', []))
-    return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
+    return [x for x in os.listdir(os.path.join(data_path, dir_extensions)) if x not in disabled_extensions]
 def run_extensions_installers(settings_file):
@@ -242,11 +260,8 @@ def prepare_environment():
     sys.argv += shlex.split(commandline_args)
-    parser = argparse.ArgumentParser(add_help=False)
-    parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
-    args, _ = parser.parse_known_args(sys.argv)
-
     sys.argv, _ = extract_arg(sys.argv, '-f')
+    sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
     sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
     sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
     sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
@@ -295,7 +310,7 @@ def prepare_environment():
     if not is_installed("pyngrok") and ngrok:
         run_pip("install pyngrok", "ngrok")
-    os.makedirs(dir_repos, exist_ok=True)
+    os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
     git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
     git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
@@ -304,14 +319,19 @@ def prepare_environment():
     git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
     if not is_installed("lpips"):
-        run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
+        run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
-    run_pip(f"install -r {requirements_file}", "requirements for Web UI")
+    if not os.path.isfile(requirements_file):
+        requirements_file = os.path.join(script_path, requirements_file)
+    run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")
     run_extensions_installers(settings_file=args.ui_settings_file)
     if update_check:
         version_check(commit)
+    if update_all_extensions:
+        git_pull_recursive(os.path.join(data_path, dir_extensions))
+
     if "--exit" in sys.argv:
         print("Exiting because of --exit argument")
@@ -327,7 +347,7 @@ def tests(test_dir):
         sys.argv.append("--api")
     if "--ckpt" not in sys.argv:
         sys.argv.append("--ckpt")
-        sys.argv.append("./test/test_files/empty.pt")
+        sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
     if "--skip-torch-cuda-test" not in sys.argv:
         sys.argv.append("--skip-torch-cuda-test")
     if "--disable-nan-check" not in sys.argv:
@@ -336,7 +356,7 @@ def tests(test_dir):
     print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
     os.environ['COMMANDLINE_ARGS'] = ""
-    with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
+    with open(os.path.join(script_path, 'test/stdout.txt'), "w", encoding="utf8") as stdout, open(os.path.join(script_path, 'test/stderr.txt'), "w", encoding="utf8") as stderr:
         proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
     import test.server_poll
...
@@ -150,6 +150,7 @@ class Api:
         self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
         self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
         self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
     def add_api_route(self, path: str, endpoint, **kwargs):
         if shared.cmd_opts.api_auth:
@@ -163,47 +164,98 @@ class Api:
         raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
-    def get_script(self, script_name, script_runner):
-        if script_name is None:
+    def get_selectable_script(self, script_name, script_runner):
+        if script_name is None or script_name == "":
             return None, None
-        if not script_runner.scripts:
-            script_runner.initialize_scripts(False)
-            ui.create_ui()
-
         script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
         script = script_runner.selectable_scripts[script_idx]
         return script, script_idx
+    def get_scripts_list(self):
+        t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
+        i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
+
+        return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
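The new /sdapi/v1/scripts route can be exercised like this (a sketch assuming a local server on the default port; the script titles in the comment are illustrative):

    import requests

    r = requests.get("http://127.0.0.1:7860/sdapi/v1/scripts")
    print(r.json())  # e.g. {"txt2img": ["prompt matrix", ...], "img2img": [...]}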
+    def get_script(self, script_name, script_runner):
+        if script_name is None or script_name == "":
+            return None, None
+
+        script_idx = script_name_to_index(script_name, script_runner.scripts)
+        return script_runner.scripts[script_idx]
+
+    def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner):
+        #find max idx from the scripts in runner and generate a none array to init script_args
+        last_arg_index = 1
+        for script in script_runner.scripts:
+            if last_arg_index < script.args_to:
+                last_arg_index = script.args_to
+        # None everywhere except position 0 to initialize script args
+        script_args = [None]*last_arg_index
+        # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
+        if selectable_scripts:
+            script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
+            script_args[0] = selectable_idx + 1
+        else:
+            # when [0] = 0 no selectable script to run
+            script_args[0] = 0
+
+        # Now check for always on scripts
+        if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
+            for alwayson_script_name in request.alwayson_scripts.keys():
+                alwayson_script = self.get_script(alwayson_script_name, script_runner)
+                if alwayson_script == None:
+                    raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
+                # Selectable script in always on script param check
+                if alwayson_script.alwayson == False:
+                    raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
+                # always on script with no arg should always run so you don't really need to add them to the requests
+                if "args" in request.alwayson_scripts[alwayson_script_name]:
+                    script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
+        return script_args
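A worked example of the layout init_script_args builds. Suppose the runner holds one selectable script occupying args_from=1, args_to=3 and one always-on script occupying args_from=3, args_to=5 (hypothetical positions), and the request selects the first script:

    # last_arg_index = 5, so five slots, all None initially
    script_args = [None] * 5
    script_args[1:3] = ["arg a", "arg b"]   # request.script_args for the selectable script
    script_args[0] = 1                      # 1-based index of the selected selectable script
    script_args[3:5] = [True, 0.5]          # request.alwayson_scripts[name]["args"]
    print(script_args)                      # [1, 'arg a', 'arg b', True, 0.5]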
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) script_runner = scripts.scripts_txt2img
if not script_runner.scripts:
script_runner.initialize_scripts(False)
ui.create_ui()
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
populate = txt2imgreq.copy(update={ # Override __init__ params populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True, "do_not_save_samples": not txt2imgreq.save_images,
"do_not_save_grid": True "do_not_save_grid": not txt2imgreq.save_images,
} })
)
if populate.sampler_name: if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on populate.sampler_index = None # prevent a warning later on
args = vars(populate) args = vars(populate)
args.pop('script_name', None) args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
with self.queue_lock: with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
p.scripts = script_runner
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
shared.state.begin() shared.state.begin()
if script is not None: if selectable_scripts != None:
p.outpath_grids = opts.outdir_txt2img_grids p.script_args = script_args
p.outpath_samples = opts.outdir_txt2img_samples processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args)
else: else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p) processed = process_images(p)
shared.state.end() shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
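Putting the new request fields together, a txt2img call might look like this (endpoint and field names come from this diff; the always-on script name and its argument list are placeholders):

    import requests

    payload = {
        "prompt": "a photo of a corgi",
        "steps": 20,
        "send_images": True,    # return base64 images in the response
        "save_images": False,   # don't write samples/grids to disk on the server
        "alwayson_scripts": {"some alwayson script": {"args": [1, "x"]}},
    }
    r = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
    print(len(r.json()["images"]))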
@@ -212,41 +264,53 @@ class Api:
         if init_images is None:
             raise HTTPException(status_code=404, detail="Init image not found")
-        script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)
-
         mask = img2imgreq.mask
         if mask:
             mask = decode_base64_to_image(mask)
-        populate = img2imgreq.copy(update={ # Override __init__ params
+        script_runner = scripts.scripts_img2img
+        if not script_runner.scripts:
+            script_runner.initialize_scripts(True)
+            ui.create_ui()
+        selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
+
+        populate = img2imgreq.copy(update={ # Override __init__ params
             "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
-            "do_not_save_samples": True,
-            "do_not_save_grid": True,
-            "mask": mask
-        }
-        )
+            "do_not_save_samples": not img2imgreq.save_images,
+            "do_not_save_grid": not img2imgreq.save_images,
+            "mask": mask,
+        })
         if populate.sampler_name:
             populate.sampler_index = None # prevent a warning later on
         args = vars(populate)
         args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
         args.pop('script_name', None)
+        args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
+        args.pop('alwayson_scripts', None)
+
+        script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)
+
+        send_images = args.pop('send_images', True)
+        args.pop('save_images', None)
         with self.queue_lock:
             p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
             p.init_images = [decode_base64_to_image(x) for x in init_images]
+            p.scripts = script_runner
+            p.outpath_grids = opts.outdir_img2img_grids
+            p.outpath_samples = opts.outdir_img2img_samples
             shared.state.begin()
-            if script is not None:
-                p.outpath_grids = opts.outdir_img2img_grids
-                p.outpath_samples = opts.outdir_img2img_samples
-                p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
-                processed = scripts.scripts_img2img.run(p, *p.script_args)
+            if selectable_scripts != None:
+                p.script_args = script_args
+                processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
             else:
+                p.script_args = tuple(script_args) # Need to pass args as tuple here
                 processed = process_images(p)
             shared.state.end()
-        b64images = list(map(encode_pil_to_base64, processed.images))
+        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
         if not img2imgreq.include_init_images:
             img2imgreq.init_images = None
...
@@ -14,8 +14,8 @@ API_NOT_ALLOWED = [
     "outpath_samples",
     "outpath_grids",
     "sampler_index",
-    "do_not_save_samples",
-    "do_not_save_grid",
+    # "do_not_save_samples",
+    # "do_not_save_grid",
     "extra_generation_params",
     "overlay_images",
     "do_not_reload_embeddings",
@@ -100,13 +100,31 @@ class PydanticModelGenerator:
 StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
     "StableDiffusionProcessingTxt2Img",
     StableDiffusionProcessingTxt2Img,
-    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+    [
+        {"key": "sampler_index", "type": str, "default": "Euler"},
+        {"key": "script_name", "type": str, "default": None},
+        {"key": "script_args", "type": list, "default": []},
+        {"key": "send_images", "type": bool, "default": True},
+        {"key": "save_images", "type": bool, "default": False},
+        {"key": "alwayson_scripts", "type": dict, "default": {}},
+    ]
 ).generate_model()
 StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
     "StableDiffusionProcessingImg2Img",
     StableDiffusionProcessingImg2Img,
-    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+    [
+        {"key": "sampler_index", "type": str, "default": "Euler"},
+        {"key": "init_images", "type": list, "default": None},
+        {"key": "denoising_strength", "type": float, "default": 0.75},
+        {"key": "mask", "type": str, "default": None},
+        {"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
+        {"key": "script_name", "type": str, "default": None},
+        {"key": "script_args", "type": list, "default": []},
+        {"key": "send_images", "type": bool, "default": True},
+        {"key": "save_images", "type": bool, "default": False},
+        {"key": "alwayson_scripts", "type": dict, "default": {}},
+    ]
 ).generate_model()
 class TextToImageResponse(BaseModel):
@@ -267,3 +285,7 @@ class EmbeddingsResponse(BaseModel):
 class MemoryResponse(BaseModel):
     ram: dict = Field(title="RAM", description="System memory stats")
     cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
+
+class ScriptsList(BaseModel):
+    txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
+    img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
\ No newline at end of file
@@ -55,7 +55,7 @@ def setup_model(dirname):
                 if self.net is not None and self.face_helper is not None:
                     self.net.to(devices.device_codeformer)
                     return self.net, self.face_helper
-                model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
+                model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
                 if len(model_paths) != 0:
                     ckpt_path = model_paths[0]
                 else:
...
@@ -66,7 +66,7 @@ class Extension:
     def check_updates(self):
         repo = git.Repo(self.path)
-        for fetch in repo.remote().fetch("--dry-run"):
+        for fetch in repo.remote().fetch(dry_run=True):
             if fetch.flags != fetch.HEAD_UPTODATE:
                 self.can_update = True
                 self.status = "behind"
@@ -79,8 +79,8 @@ class Extension:
         repo = git.Repo(self.path)
         # Fix: `error: Your local changes to the following files would be overwritten by merge`,
         # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
-        repo.git.fetch('--all')
-        repo.git.reset('--hard', 'origin')
+        repo.git.fetch(all=True)
+        repo.git.reset('origin', hard=True)
 def list_extensions():
...
@@ -23,13 +23,14 @@ registered_param_bindings = []
 class ParamBinding:
-    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
+    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
         self.paste_button = paste_button
         self.tabname = tabname
         self.source_text_component = source_text_component
         self.source_image_component = source_image_component
         self.source_tabname = source_tabname
         self.override_settings_component = override_settings_component
+        self.paste_field_names = paste_field_names
 def reset():
@@ -134,7 +135,7 @@ def connect_paste_params_buttons():
         connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
         if binding.source_tabname is not None and fields is not None:
-            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
+            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
             binding.paste_button.click(
                 fn=lambda *x: x,
                 inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
@@ -292,6 +293,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
 settings_map = {}
 infotext_to_setting_name_mapping = [
     ('Clip skip', 'CLIP_stop_at_last_layers', ),
     ('Conditional mask weight', 'inpainting_mask_weight'),
@@ -300,7 +303,11 @@ infotext_to_setting_name_mapping = [
     ('Noise multiplier', 'initial_noise_multiplier'),
     ('Eta', 'eta_ancestral'),
     ('Eta DDIM', 'eta_ddim'),
-    ('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
+    ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
+    ('UniPC variant', 'uni_pc_variant'),
+    ('UniPC skip type', 'uni_pc_skip_type'),
+    ('UniPC order', 'uni_pc_order'),
+    ('UniPC lower order final', 'uni_pc_lower_order_final'),
 ]
...
@@ -556,7 +556,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         elif image_to_save.mode == 'I;16':
             image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
-        image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+        image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
         if opts.enable_pnginfo and info is not None:
             exif_bytes = piexif.dump({
@@ -573,6 +573,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         os.replace(temp_file_path, filename_without_extension + extension)
     fullfn_without_extension, extension = os.path.splitext(params.filename)
+    if hasattr(os, 'statvfs'):
+        max_name_len = os.statvfs(path).f_namemax
+        fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
+        params.filename = fullfn_without_extension + extension
+        fullfn = params.filename
     _atomically_save_image(image, fullfn_without_extension, extension)
     image.already_saved_as = fullfn
...
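The statvfs addition trims the filename stem so that stem plus extension fits the filesystem's name limit. The same trim in isolation (illustrative values; f_namemax is typically 255 on Linux filesystems, and os.statvfs only exists on POSIX):

    import os

    stem, ext = "x" * 300, ".png"
    if hasattr(os, 'statvfs'):
        max_name_len = os.statvfs('.').f_namemax
        stem = stem[:max_name_len - max(4, len(ext))]
    print(len(stem + ext))  # now within the limit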
@@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
         output_dtype = kwargs.get('dtype', input.dtype)
         if output_dtype == torch.int64:
             return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
-        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+        elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
             return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
     return cumsum_func(input, *args, **kwargs)
@@ -45,7 +45,6 @@ if has_mps:
         CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
     elif version.parse(torch.__version__) > version.parse("1.13.1"):
         cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
-        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
         cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
         CondFunc('torch.cumsum', cumsum_fix_func, None)
         CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
...
@@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread):
         self.data = defaultdict(int)
         try:
-            torch.cuda.mem_get_info()
+            self.cuda_mem_get_info()
             torch.cuda.memory_stats(self.device)
         except Exception as e: # AMD or whatever
             print(f"Warning: caught exception '{e}', memory monitor disabled")
             self.disabled = True
+    def cuda_mem_get_info(self):
+        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
+        return torch.cuda.mem_get_info(index)
+
     def run(self):
         if self.disabled:
             return
@@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread):
                 self.run_flag.clear()
                 continue
-            self.data["min_free"] = torch.cuda.mem_get_info()[0]
+            self.data["min_free"] = self.cuda_mem_get_info()[0]
             while self.run_flag.is_set():
-                free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
+                free, total = self.cuda_mem_get_info()
                 self.data["min_free"] = min(self.data["min_free"], free)
                 time.sleep(1 / self.opts.memmon_poll_rate)
@@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread):
     def read(self):
         if not self.disabled:
-            free, total = torch.cuda.mem_get_info()
+            free, total = self.cuda_mem_get_info()
             self.data["free"] = free
             self.data["total"] = total
...
@@ -6,7 +6,7 @@ from urllib.parse import urlparse
 from basicsr.utils.download_util import load_file_from_url
 from modules import shared
-from modules.upscaler import Upscaler
+from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
 from modules.paths import script_path, models_path
@@ -169,4 +169,8 @@ def load_upscalers():
             scaler = cls(commandline_options.get(cmd_name, None))
             datas += scaler.scalers
-    shared.sd_upscalers = datas
+    shared.sd_upscalers = sorted(
+        datas,
+        # Special case for UpscalerNone keeps it at the beginning of the list.
+        key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
+    )
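The sort key pins the built-in None/Lanczos/Nearest upscalers to the front by mapping them to the empty string; everything else sorts case-insensitively by name. An equivalent toy version:

    builtin = {"None", "Lanczos", "Nearest"}
    names = ["SwinIR 4x", "ESRGAN 4x", "None", "Lanczos", "LDSR"]
    print(sorted(names, key=lambda n: "" if n in builtin else n.lower()))
    # ['None', 'Lanczos', 'ESRGAN 4x', 'LDSR', 'SwinIR 4x']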
+from .sampler import UniPCSampler
+"""SAMPLING ONLY."""
+
+import torch
+
+from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
+from modules import shared, devices
+
+
+class UniPCSampler(object):
+    def __init__(self, model, **kwargs):
+        super().__init__()
+        self.model = model
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+        self.before_sample = None
+        self.after_sample = None
+        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != devices.device:
+                attr = attr.to(devices.device)
+        setattr(self, name, attr)
+
+    def set_hooks(self, before_sample, after_sample, after_update):
+        self.before_sample = before_sample
+        self.after_sample = after_sample
+        self.after_update = after_update
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        # print(f'Data shape for UniPC sampling is {size}')
+
+        device = self.model.betas.device
+        if x_T is None:
+            img = torch.randn(size, device=device)
+        else:
+            img = x_T
+
+        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+        # SD 1.X is "noise", SD 2.X is "v"
+        model_type = "v" if self.model.parameterization == "v" else "noise"
+
+        model_fn = model_wrapper(
+            lambda x, t, c: self.model.apply_model(x, t, c),
+            ns,
+            model_type=model_type,
+            guidance_type="classifier-free",
+            #condition=conditioning,
+            #unconditional_condition=unconditional_conditioning,
+            guidance_scale=unconditional_guidance_scale,
+        )
+
+        uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
+        x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
+
+        return x.to(device), None
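For reference, the "classifier-free" guidance type that model_wrapper is configured with combines the conditional and unconditional model outputs in the standard CFG way; a sketch on dummy tensors (not this file's code):

    import torch

    guidance_scale = 7.5  # corresponds to unconditional_guidance_scale above
    noise_uncond = torch.randn(1, 4, 64, 64)  # stands in for model(x, t, unconditional_conditioning)
    noise_cond = torch.randn(1, 4, 64, 64)    # stands in for model(x, t, conditioning)
    noise = noise_uncond + guidance_scale * (noise_cond - noise_uncond)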
This diff is collapsed.
@@ -597,6 +597,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         if state.job_count == -1:
             state.job_count = p.n_iter
+        extra_network_data = None
         for n in range(p.n_iter):
             p.iteration = n
@@ -620,6 +621,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+            if p.scripts is not None:
+                p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+
             if len(prompts) == 0:
                 break
@@ -753,7 +757,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if opts.grid_save:
                 images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
-    if not p.disable_extra_networks:
+    if not p.disable_extra_networks and extra_network_data:
         extra_networks.deactivate(p, extra_network_data)
     devices.torch_gc()
@@ -944,7 +948,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         shared.state.nextjob()
-        img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM
+        img2img_sampler_name = self.sampler_name
+        if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+            img2img_sampler_name = 'DDIM'
+
         if self.hr_sampler == '---':
             pass
...
@@ -29,7 +29,7 @@ class ImageSaveParams:
 class CFGDenoiserParams:
-    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
+    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
         self.x = x
         """Latent image representation in the process of being denoised"""
@@ -44,6 +44,12 @@ class CFGDenoiserParams:
         self.total_sampling_steps = total_sampling_steps
         """Total number of sampling steps planned"""
+        self.text_cond = text_cond
+        """ Encoder hidden states of text conditioning from prompt"""
+
+        self.text_uncond = text_uncond
+        """ Encoder hidden states of text conditioning from negative prompt"""
+
 class CFGDenoisedParams:
...
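An extension can read the new fields from an on_cfg_denoiser callback; registration via modules.script_callbacks is real, the callback body is illustrative:

    from modules import script_callbacks

    def log_cond_shapes(params):
        # params.text_cond / params.text_uncond are the encoder hidden states added here
        print(params.sampling_step, params.text_cond.shape, params.text_uncond.shape)

    script_callbacks.on_cfg_denoiser(log_cond_shapes)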
@@ -33,6 +33,11 @@ class Script:
     parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
     """
+    paste_field_names = None
+    """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
+    various "Send to <X>" buttons when clicked
+    """
+
     def title(self):
         """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -80,6 +85,20 @@ class Script:
         pass
+    def before_process_batch(self, p, *args, **kwargs):
+        """
+        Called before extra networks are parsed from the prompt, so you can add
+        new extra network keywords to the prompt with this callback.
+
+        **kwargs will have those items:
+          - batch_number - index of current batch, from 0 to number of batches-1
+          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+          - seeds - list of seeds for current batch
+          - subseeds - list of subseeds for current batch
+        """
+
+        pass
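Since the hook fires before extra networks are parsed, an always-on script can still inject network keywords into the prompt. A hypothetical example (the Script base class and hook name are from this diff; the script itself is made up):

    import modules.scripts as scripts

    class ForceLora(scripts.Script):
        def title(self):
            return "Force a LoRA on every batch"

        def show(self, is_img2img):
            return scripts.AlwaysVisible

        def before_process_batch(self, p, *args, **kwargs):
            # mutate the list contents in place; don't change the number of entries
            kwargs["prompts"][:] = [f"{x} <lora:myLora:0.8>" for x in kwargs["prompts"]]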
     def process_batch(self, p, *args, **kwargs):
         """
         Same as process(), but called for every batch.
@@ -256,6 +275,7 @@ class ScriptRunner:
         self.alwayson_scripts = []
         self.titles = []
         self.infotext_fields = []
+        self.paste_field_names = []
     def initialize_scripts(self, is_img2img):
         from modules import scripts_auto_postprocessing
@@ -304,6 +324,9 @@ class ScriptRunner:
             if script.infotext_fields is not None:
                 self.infotext_fields += script.infotext_fields
+            if script.paste_field_names is not None:
+                self.paste_field_names += script.paste_field_names
+
             inputs += controls
             inputs_alwayson += [script.alwayson for _ in controls]
             script.args_to = len(inputs)
@@ -388,6 +411,15 @@ class ScriptRunner:
                 print(f"Error running process: {script.filename}", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
+    def before_process_batch(self, p, **kwargs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.before_process_batch(p, *script_args, **kwargs)
+            except Exception:
+                print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+
     def process_batch(self, p, **kwargs):
         for script in self.alwayson_scripts:
             try:
...
@@ -37,11 +37,23 @@ def apply_optimizations():
     optimization_method = None
+    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
+
     if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
         print("Applying xformers cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
         optimization_method = 'xformers'
+    elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
+        print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
+        optimization_method = 'sdp-no-mem'
+    elif cmd_opts.opt_sdp_attention and can_use_sdp:
+        print("Applying scaled dot product cross attention optimization.")
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
+        optimization_method = 'sdp'
     elif cmd_opts.opt_sub_quad_attention:
         print("Applying sub-quadratic cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
...
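The can_use_sdp probe only checks that torch exposes scaled_dot_product_attention (i.e. torch 2.x); it can be run standalone:

    import torch

    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") \
        and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))
    print(torch.__version__, can_use_sdp)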
...@@ -346,6 +346,52 @@ def xformers_attention_forward(self, x, context=None, mask=None): ...@@ -346,6 +346,52 @@ def xformers_attention_forward(self, x, context=None, mask=None):
out = rearrange(out, 'b n h d -> b n (h d)', h=h) out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out) return self.to_out(out)
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
batch_size, sequence_length, inner_dim = x.shape
if mask is not None:
mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
head_dim = inner_dim // h
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
del q_in, k_in, v_in
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
hidden_states = hidden_states.to(dtype)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return scaled_dot_product_attention_forward(self, x, context, mask)
def cross_attention_attnblock_forward(self, x): def cross_attention_attnblock_forward(self, x):
h_ = x h_ = x
h_ = self.norm(h_) h_ = self.norm(h_)
...@@ -427,6 +473,30 @@ def xformers_attnblock_forward(self, x): ...@@ -427,6 +473,30 @@ def xformers_attnblock_forward(self, x):
except NotImplementedError: except NotImplementedError:
return cross_attention_attnblock_forward(self, x) return cross_attention_attnblock_forward(self, x)
def sdp_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
out = out.to(dtype)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out
def sdp_no_mem_attnblock_forward(self, x):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return sdp_attnblock_forward(self, x)
def sub_quad_attnblock_forward(self, x): def sub_quad_attnblock_forward(self, x):
h_ = x h_ = x
h_ = self.norm(h_) h_ = self.norm(h_)
......
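A self-contained sketch of the head bookkeeping used in scaled_dot_product_attention_forward and sdp_attnblock_forward, with made-up sizes; torch.nn.functional.scaled_dot_product_attention expects (batch, heads, seq_len, head_dim) inputs, and the result is reshaped back to (batch, seq_len, inner_dim) before the output projection:

import torch

batch, seq_len, heads, head_dim = 2, 77, 8, 40
inner_dim = heads * head_dim

q = torch.randn(batch, seq_len, inner_dim)
k = torch.randn(batch, seq_len, inner_dim)
v = torch.randn(batch, seq_len, inner_dim)

# (batch, seq_len, inner_dim) -> (batch, heads, seq_len, head_dim)
q, k, v = (t.view(batch, -1, heads, head_dim).transpose(1, 2) for t in (q, k, v))

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# back to (batch, seq_len, inner_dim) for the output projection
out = out.transpose(1, 2).reshape(batch, -1, inner_dim)
assert out.shape == (batch, seq_len, inner_dim)

Wrapping the call in torch.backends.cuda.sdp_kernel(enable_mem_efficient=False), as the no-mem variants above do, trades some memory for deterministic results.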
...@@ -210,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd): ...@@ -210,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd return pl_sd
def read_metadata_from_safetensors(filename):
import json
with open(filename, mode="rb") as file:
metadata_len = file.read(8)
metadata_len = int.from_bytes(metadata_len, "little")
json_start = file.read(2)
assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
json_data = json_start + file.read(metadata_len-2)
json_obj = json.loads(json_data)
res = {}
for k, v in json_obj.get("__metadata__", {}).items():
res[k] = v
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
except Exception as e:
pass
return res
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file) _, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":
......
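As a usage sketch (the file path is hypothetical): the helper can be pointed at any .safetensors file, since the format starts with an 8-byte little-endian length of the JSON header and user metadata lives under the header's "__metadata__" key:

from modules import sd_models

metadata = sd_models.read_metadata_from_safetensors("models/Lora/example.safetensors")  # hypothetical path
for key, value in metadata.items():
    print(f"{key}: {value}")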
...@@ -32,7 +32,7 @@ def set_samplers(): ...@@ -32,7 +32,7 @@ def set_samplers():
global samplers, samplers_for_img2img global samplers, samplers_for_img2img
hidden = set(shared.opts.hide_samplers) hidden = set(shared.opts.hide_samplers)
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS']) hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
samplers = [x for x in all_samplers if x.name not in hidden] samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
......
...@@ -7,19 +7,27 @@ import torch ...@@ -7,19 +7,27 @@ import torch
from modules.shared import state from modules.shared import state
from modules import sd_samplers_common, prompt_parser, shared from modules import sd_samplers_common, prompt_parser, shared
import modules.models.diffusion.uni_pc
samplers_data_compvis = [ samplers_data_compvis = [
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
] ]
class VanillaStableDiffusionSampler: class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model): def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model) self.sampler = constructor(sd_model)
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
self.is_plms = hasattr(self.sampler, 'p_sample_plms') self.is_plms = hasattr(self.sampler, 'p_sample_plms')
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
self.orig_p_sample_ddim = None
if self.is_plms:
self.orig_p_sample_ddim = self.sampler.p_sample_plms
elif self.is_ddim:
self.orig_p_sample_ddim = self.sampler.p_sample_ddim
self.mask = None self.mask = None
self.nmask = None self.nmask = None
self.init_latent = None self.init_latent = None
...@@ -45,6 +53,15 @@ class VanillaStableDiffusionSampler: ...@@ -45,6 +53,15 @@ class VanillaStableDiffusionSampler:
return self.last_latent return self.last_latent
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
return res
def before_sample(self, x, ts, cond, unconditional_conditioning):
if state.interrupted or state.skipped: if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException raise sd_samplers_common.InterruptedException
...@@ -76,7 +93,7 @@ class VanillaStableDiffusionSampler: ...@@ -76,7 +93,7 @@ class VanillaStableDiffusionSampler:
if self.mask is not None: if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts) img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec x = img_orig * self.mask + self.nmask * x
# Wrap the image conditioning back up since the DDIM code can accept the dict directly. # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later. # Note that they need to be lists because it just concatenates them later.
...@@ -84,12 +101,13 @@ class VanillaStableDiffusionSampler: ...@@ -84,12 +101,13 @@ class VanillaStableDiffusionSampler:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) return x, ts, cond, unconditional_conditioning
def update_step(self, last_latent):
if self.mask is not None: if self.mask is not None:
self.last_latent = self.init_latent * self.mask + self.nmask * res[1] self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
else: else:
self.last_latent = res[1] self.last_latent = last_latent
sd_samplers_common.store_latent(self.last_latent) sd_samplers_common.store_latent(self.last_latent)
...@@ -97,26 +115,51 @@ class VanillaStableDiffusionSampler: ...@@ -97,26 +115,51 @@ class VanillaStableDiffusionSampler:
state.sampling_step = self.step state.sampling_step = self.step
shared.total_tqdm.update() shared.total_tqdm.update()
return res def after_sample(self, x, ts, cond, uncond, res):
if not self.is_unipc:
self.update_step(res[1])
return x, ts, cond, uncond, res
def unipc_after_update(self, x, model_x):
self.update_step(x)
def initialize(self, p): def initialize(self, p):
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
if self.eta != 0.0: if self.eta != 0.0:
p.extra_generation_params["Eta DDIM"] = self.eta p.extra_generation_params["Eta DDIM"] = self.eta
if self.is_unipc:
keys = [
('UniPC variant', 'uni_pc_variant'),
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
]
for name, key in keys:
v = getattr(shared.opts, key)
if v != shared.opts.get_default(key):
p.extra_generation_params[name] = v
for fieldname in ['p_sample_ddim', 'p_sample_plms']: for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname): if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook) setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
if self.is_unipc:
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
self.mask = p.mask if hasattr(p, 'mask') else None self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps): def adjust_steps_if_invalid(self, p, num_steps):
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
num_steps = shared.opts.uni_pc_order
valid_step = 999 / (1000 // num_steps) valid_step = 999 / (1000 // num_steps)
if valid_step == math.floor(valid_step): if valid_step == math.floor(valid_step):
return int(valid_step) + 1 return int(valid_step) + 1
return num_steps return num_steps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
......
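The hook pattern above, reduced to its skeleton and assuming only some callable sampler step: inputs are rewritten before the original step runs (mask blending, wrapping conditioning into dicts), and progress is reported after it, with UniPC reporting through its own after-update hook instead:

def wrap_step(step_fn, before, after):
    # generic wrap-and-hook sketch of p_sample_ddim_hook above
    def hooked(x, cond, ts, **kwargs):
        x, ts, cond = before(x, ts, cond)
        res = step_fn(x, cond, ts, **kwargs)
        return after(res)
    return hooked

# usage sketch: sampler.p_sample_ddim = wrap_step(sampler.p_sample_ddim, before, after)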
...@@ -101,11 +101,13 @@ class CFGDenoiser(torch.nn.Module): ...@@ -101,11 +101,13 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
cfg_denoiser_callback(denoiser_params) cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma sigma_in = denoiser_params.sigma
tensor = denoiser_params.text_cond
uncond = denoiser_params.text_uncond
if tensor.shape[1] == uncond.shape[1]: if tensor.shape[1] == uncond.shape[1]:
if not is_edit_model: if not is_edit_model:
......
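With the two new fields, a script callback can inspect (or replace) the prompt conditioning as well as the latents; a minimal sketch, assuming the usual registration through modules.script_callbacks:

from modules import script_callbacks

def on_denoiser(params):
    # params now carries text_cond / text_uncond next to x, sigma and image_cond
    print("step", params.sampling_step, "cond", tuple(params.text_cond.shape))
    # reassigned values are read back by the sampler, as the lines above show

script_callbacks.on_cfg_denoiser(on_denoiser)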
...@@ -35,8 +35,11 @@ def model(): ...@@ -35,8 +35,11 @@ def model():
global sd_vae_approx_model global sd_vae_approx_model
if sd_vae_approx_model is None: if sd_vae_approx_model is None:
model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
sd_vae_approx_model = VAEApprox() sd_vae_approx_model = VAEApprox()
sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None)) if not os.path.exists(model_path):
model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
sd_vae_approx_model.eval() sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype) sd_vae_approx_model.to(devices.device, devices.dtype)
......
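The lookup order above as a standalone sketch: the models directory is preferred, the copy bundled in the source tree is the fallback, and weights are loaded to CPU unless running on CUDA:

import os
import torch
from modules import paths, devices

model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
if not os.path.exists(model_path):
    # fall back to the copy shipped inside the webui source tree
    model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")

state = torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)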
...@@ -69,6 +69,8 @@ parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size fo ...@@ -69,6 +69,8 @@ parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size fo
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
...@@ -114,7 +116,10 @@ parser.add_argument("--no-download-sd-model", action='store_true', help="don't d ...@@ -114,7 +116,10 @@ parser.add_argument("--no-download-sd-model", action='store_true', help="don't d
script_loading.preload_extensions(extensions.extensions_dir, parser) script_loading.preload_extensions(extensions.extensions_dir, parser)
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser) script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
cmd_opts = parser.parse_args() if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
cmd_opts = parser.parse_args()
else:
cmd_opts, _ = parser.parse_known_args()
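What the environment toggle changes, shown with plain argparse (the flag names here are made up):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--known-flag", action="store_true")  # hypothetical flag

# parse_args() exits the process on anything unrecognized;
# parse_known_args() returns the leftovers instead, which is what
# setting IGNORE_CMD_ARGS_ERRORS switches on above
args, unknown = parser.parse_known_args(["--known-flag", "--some-extension-flag"])
print(args.known_flag)  # True
print(unknown)          # ['--some-extension-flag']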
restricted_opts = { restricted_opts = {
"samples_filename_pattern", "samples_filename_pattern",
...@@ -305,6 +310,7 @@ def list_samplers(): ...@@ -305,6 +310,7 @@ def list_samplers():
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
tab_names = []
options_templates = {} options_templates = {}
...@@ -327,9 +333,11 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" ...@@ -327,9 +333,11 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"), "export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number), "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number), "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
"img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number),
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"), "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"), "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
...@@ -440,6 +448,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), ...@@ -440,6 +448,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('extra_networks', "Extra Networks"), { options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}), "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
})) }))
...@@ -460,6 +469,7 @@ options_templates.update(options_section(('ui', "User interface"), { ...@@ -460,6 +469,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
...@@ -485,6 +495,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" ...@@ -485,6 +495,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}),
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
})) }))
options_templates.update(options_section(('postprocessing', "Postprocessing"), { options_templates.update(options_section(('postprocessing', "Postprocessing"), {
...@@ -559,6 +573,15 @@ class Options: ...@@ -559,6 +573,15 @@ class Options:
return True return True
def get_default(self, key):
"""returns the default value for the key"""
data_label = self.data_labels.get(key)
if data_label is None:
return None
return data_label.default
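This is what lets the UniPC initialization earlier in this diff record only settings the user actually changed; a usage sketch with one of the new keys:

from modules import shared

key = 'uni_pc_order'
value = getattr(shared.opts, key)
if value != shared.opts.get_default(key):
    print(f"{key} overridden: {value}")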
def save(self, filename): def save(self, filename):
assert not cmd_opts.freeze_settings, "saving settings is disabled" assert not cmd_opts.freeze_settings, "saving settings is disabled"
...@@ -691,6 +714,7 @@ class TotalTQDM: ...@@ -691,6 +714,7 @@ class TotalTQDM:
def clear(self): def clear(self):
if self._tqdm is not None: if self._tqdm is not None:
self._tqdm.refresh()
self._tqdm.close() self._tqdm.close()
self._tqdm = None self._tqdm = None
......
...@@ -33,3 +33,6 @@ class Timer: ...@@ -33,3 +33,6 @@ class Timer:
res += ")" res += ")"
return res return res
def reset(self):
self.__init__()
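A usage sketch for the timer, matching how webui.py uses it later in this diff (the summary format shown is illustrative):

from modules import timer

t = timer.Timer()
# ... load something expensive ...
t.record("load model")
# ... build the interface ...
t.record("create ui")
print(t.summary())  # e.g. "2.1s (load model: 1.6s, create ui: 0.5s)"

t.reset()  # start fresh without constructing a new Timer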
...@@ -957,7 +957,7 @@ def create_ui(): ...@@ -957,7 +957,7 @@ def create_ui():
) )
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
...@@ -1581,6 +1581,10 @@ def create_ui(): ...@@ -1581,6 +1581,10 @@ def create_ui():
extensions_interface = ui_extensions.create_ui() extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")] interfaces += [(extensions_interface, "Extensions", "extensions")]
shared.tab_names = []
for _interface, label, _ifid in interfaces:
shared.tab_names.append(label)
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings", variant="compact"): with gr.Row(elem_id="quicksettings", variant="compact"):
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
...@@ -1591,6 +1595,8 @@ def create_ui(): ...@@ -1591,6 +1595,8 @@ def create_ui():
with gr.Tabs(elem_id="tabs") as tabs: with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces: for interface, label, ifid in interfaces:
if label in shared.opts.hidden_tabs:
continue
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render() interface.render()
...@@ -1763,7 +1769,8 @@ def create_ui(): ...@@ -1763,7 +1769,8 @@ def create_ui():
def reload_javascript(): def reload_javascript():
head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}?{os.path.getmtime("script.js")}"></script>\n' script_js = os.path.join(script_path, "script.js")
head = f'<script type="text/javascript" src="file={os.path.abspath(script_js)}?{os.path.getmtime(script_js)}"></script>\n'
inline = f"{localization.localization_js(shared.opts.localization)};" inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None: if cmd_opts.theme is not None:
...@@ -1772,6 +1779,9 @@ def reload_javascript(): ...@@ -1772,6 +1779,9 @@ def reload_javascript():
for script in modules.scripts.list_scripts("javascript", ".js"): for script in modules.scripts.list_scripts("javascript", ".js"):
head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n' head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
for script in modules.scripts.list_scripts("javascript", ".mjs"):
head += f'<script type="module" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
head += f'<script type="text/javascript">{inline}</script>\n' head += f'<script type="text/javascript">{inline}</script>\n'
def template_response(*args, **kwargs): def template_response(*args, **kwargs):
......
...@@ -198,9 +198,16 @@ Requested path was: {f} ...@@ -198,9 +198,16 @@ Requested path was: {f}
html_info = gr.HTML(elem_id=f'html_info_{tabname}') html_info = gr.HTML(elem_id=f'html_info_{tabname}')
html_log = gr.HTML(elem_id=f'html_log_{tabname}') html_log = gr.HTML(elem_id=f'html_log_{tabname}')
paste_field_names = []
if tabname == "txt2img":
paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
elif tabname == "img2img":
paste_field_names = modules.scripts.scripts_img2img.paste_field_names
for paste_tabname, paste_button in buttons.items(): for paste_tabname, paste_button in buttons.items():
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
paste_field_names=paste_field_names
)) ))
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
...@@ -304,7 +304,7 @@ def create_ui(): ...@@ -304,7 +304,7 @@ def create_ui():
with gr.TabItem("Available"): with gr.TabItem("Available"):
with gr.Row(): with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary") refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False) available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
......
...@@ -30,8 +30,8 @@ def add_pages_to_demo(app): ...@@ -30,8 +30,8 @@ def add_pages_to_demo(app):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower() ext = os.path.splitext(filename)[1].lower()
if ext not in (".png", ".jpg"): if ext not in (".png", ".jpg", ".webp"):
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.") raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
# would profit from returning 304 # would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
...@@ -124,19 +124,56 @@ class ExtraNetworksPage: ...@@ -124,19 +124,56 @@ class ExtraNetworksPage:
if onclick is None: if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
metadata_button = ""
metadata = item.get("metadata")
if metadata:
metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"'
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick={metadata_onclick}></div>"
args = { args = {
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
"prompt": item.get("prompt", None), "prompt": item.get("prompt", None),
"tabname": json.dumps(tabname), "tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]), "local_preview": json.dumps(item["local_preview"]),
"name": item["name"], "name": item["name"],
"description": (item.get("description") or ""),
"card_clicked": onclick, "card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""), "search_term": item.get("search_term", ""),
"metadata_button": metadata_button,
} }
return self.card_page.format(**args) return self.card_page.format(**args)
def find_preview(self, path):
"""
Find a preview PNG for a given path (without extension) and call link_preview on it.
"""
preview_extensions = ["png", "jpg", "webp"]
if shared.opts.samples_format not in preview_extensions:
preview_extensions.append(shared.opts.samples_format)
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
for file in potential_files:
if os.path.isfile(file):
return self.link_preview(file)
return None
def find_description(self, path):
"""
Find and read a description file for a given path (without extension).
"""
for file in [f"{path}.txt", f"{path}.description.txt"]:
try:
with open(file, "r", encoding="utf-8", errors="replace") as f:
return f.read()
except OSError:
pass
return None
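Any ExtraNetworksPage subclass can share these lookups; a minimal list_items sketch using both helpers (the items dict is hypothetical):

import os
from modules import shared, ui_extra_networks

my_items = {"example": "models/Example/example.pt"}  # hypothetical source of items

class ExamplePage(ui_extra_networks.ExtraNetworksPage):
    def list_items(self):
        for name, filename in my_items.items():
            path, ext = os.path.splitext(filename)
            yield {
                "name": name,
                "filename": path,
                "preview": self.find_preview(path),          # tries png/jpg/webp plus samples_format
                "description": self.find_description(path),  # tries .txt and .description.txt
                "search_term": self.search_terms_from_path(filename),
                "local_preview": f"{path}.preview.{shared.opts.samples_format}",
            }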
def intialize(): def intialize():
extra_pages.clear() extra_pages.clear()
...@@ -183,7 +220,6 @@ def create_ui(container, button, tabname): ...@@ -183,7 +220,6 @@ def create_ui(container, button, tabname):
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
...@@ -194,7 +230,6 @@ def create_ui(container, button, tabname): ...@@ -194,7 +230,6 @@ def create_ui(container, button, tabname):
state_visible = gr.State(value=False) state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
def refresh(): def refresh():
res = [] res = []
......
import html import html
import json import json
import os import os
import urllib.parse
from modules import shared, ui_extra_networks, sd_models from modules import shared, ui_extra_networks, sd_models
...@@ -17,21 +16,14 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): ...@@ -17,21 +16,14 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
checkpoint: sd_models.CheckpointInfo checkpoint: sd_models.CheckpointInfo
for name, checkpoint in sd_models.checkpoints_list.items(): for name, checkpoint in sd_models.checkpoints_list.items():
path, ext = os.path.splitext(checkpoint.filename) path, ext = os.path.splitext(checkpoint.filename)
previews = [path + ".png", path + ".preview.png"]
preview = None
for file in previews:
if os.path.isfile(file):
preview = self.link_preview(file)
break
yield { yield {
"name": checkpoint.name_for_extra, "name": checkpoint.name_for_extra,
"filename": path, "filename": path,
"preview": preview, "preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
"local_preview": path + ".png", "local_preview": f"{path}.{shared.opts.samples_format}",
} }
def allowed_directories_for_previews(self): def allowed_directories_for_previews(self):
......
...@@ -14,21 +14,15 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): ...@@ -14,21 +14,15 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def list_items(self): def list_items(self):
for name, path in shared.hypernetworks.items(): for name, path in shared.hypernetworks.items():
path, ext = os.path.splitext(path) path, ext = os.path.splitext(path)
previews = [path + ".png", path + ".preview.png"]
preview = None
for file in previews:
if os.path.isfile(file):
preview = self.link_preview(file)
break
yield { yield {
"name": name, "name": name,
"filename": path, "filename": path,
"preview": preview, "preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(path), "search_term": self.search_terms_from_path(path),
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"), "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": path + ".png", "local_preview": f"{path}.preview.{shared.opts.samples_format}",
} }
def allowed_directories_for_previews(self): def allowed_directories_for_previews(self):
......
import json import json
import os import os
from modules import ui_extra_networks, sd_hijack from modules import ui_extra_networks, sd_hijack, shared
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
...@@ -15,19 +15,14 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): ...@@ -15,19 +15,14 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def list_items(self): def list_items(self):
for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
path, ext = os.path.splitext(embedding.filename) path, ext = os.path.splitext(embedding.filename)
preview_file = path + ".preview.png"
preview = None
if os.path.isfile(preview_file):
preview = self.link_preview(preview_file)
yield { yield {
"name": embedding.name, "name": embedding.name,
"filename": embedding.filename, "filename": embedding.filename,
"preview": preview, "preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(embedding.filename), "search_term": self.search_terms_from_path(embedding.filename),
"prompt": json.dumps(embedding.name), "prompt": json.dumps(embedding.name),
"local_preview": path + ".preview.png", "local_preview": f"{path}.preview.{shared.opts.samples_format}",
} }
def allowed_directories_for_previews(self): def allowed_directories_for_previews(self):
......
...@@ -23,8 +23,8 @@ torchdiffeq==0.2.3 ...@@ -23,8 +23,8 @@ torchdiffeq==0.2.3
kornia==0.6.7 kornia==0.6.7
lark==1.1.2 lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.27 GitPython==3.1.30
torchsde==0.2.5 torchsde==0.2.5
safetensors==0.2.7 safetensors==0.2.7
httpcore<=0.15 httpcore<=0.15
fastapi==0.90.1 fastapi==0.94.0
...@@ -100,7 +100,7 @@ class Script(scripts.Script): ...@@ -100,7 +100,7 @@ class Script(scripts.Script):
processed = process_images(p) processed = process_images(p)
grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[1].height, prompt_matrix_parts, margin_size) grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)
processed.images.insert(0, grid) processed.images.insert(0, grid)
processed.index_of_first_image = 1 processed.index_of_first_image = 1
processed.infotexts.insert(0, processed.infotexts[0]) processed.infotexts.insert(0, processed.infotexts[0])
......
...@@ -362,6 +362,46 @@ input[type="range"]{ ...@@ -362,6 +362,46 @@ input[type="range"]{
height: 100%; height: 100%;
} }
.popup-metadata{
color: black;
background: white;
display: inline-block;
padding: 1em;
white-space: pre-wrap;
}
.global-popup{
display: flex;
position: fixed;
z-index: 1001;
left: 0;
top: 0;
width: 100%;
height: 100%;
overflow: auto;
background-color: rgba(20, 20, 20, 0.95);
}
.global-popup-close:before {
content: "×";
}
.global-popup-close{
position: fixed;
right: 0.25em;
top: 0;
cursor: pointer;
color: white;
font-size: 32pt;
}
.global-popup-inner{
display: inline-block;
margin: auto;
padding: 2em;
}
#lightboxModal{ #lightboxModal{
display: none; display: none;
position: fixed; position: fixed;
...@@ -436,9 +476,7 @@ input[type="range"]{ ...@@ -436,9 +476,7 @@ input[type="range"]{
#modalImage { #modalImage {
display: block; display: block;
margin-left: auto; margin: auto;
margin-right: auto;
margin-top: auto;
width: auto; width: auto;
} }
...@@ -839,6 +877,27 @@ footer { ...@@ -839,6 +877,27 @@ footer {
margin-left: 0.5em; margin-left: 0.5em;
} }
.extra-network-cards .card .metadata-button:before, .extra-network-thumbs .card .metadata-button:before{
content: "🛈";
}
.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{
display: none;
position: absolute;
right: 0;
color: white;
text-shadow: 2px 2px 3px black;
padding: 0.25em;
font-size: 22pt;
}
.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{
display: inline-block;
}
.extra-network-cards .card .metadata-button:hover, .extra-network-thumbs .card .metadata-button:hover{
color: red;
}
.extra-network-thumbs { .extra-network-thumbs {
display: flex; display: flex;
flex-flow: row wrap; flex-flow: row wrap;
...@@ -856,7 +915,7 @@ footer { ...@@ -856,7 +915,7 @@ footer {
} }
.extra-network-thumbs .card:hover .additional a { .extra-network-thumbs .card:hover .additional a {
display: block; display: inline-block;
} }
.extra-network-thumbs .actions .additional a { .extra-network-thumbs .actions .additional a {
...@@ -939,6 +998,17 @@ footer { ...@@ -939,6 +998,17 @@ footer {
line-break: anywhere; line-break: anywhere;
} }
.extra-network-cards .card .actions .description {
display: block;
max-height: 3em;
white-space: pre-wrap;
line-height: 1.1;
}
.extra-network-cards .card .actions .description:hover {
max-height: none;
}
.extra-network-cards .card .actions:hover .additional{ .extra-network-cards .card .actions:hover .additional{
display: block; display: block;
} }
......
import os
import unittest import unittest
import requests import requests
from gradio.processing_utils import encode_pil_to_base64 from gradio.processing_utils import encode_pil_to_base64
from PIL import Image from PIL import Image
from modules.paths import script_path
class TestExtrasWorking(unittest.TestCase): class TestExtrasWorking(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -19,7 +21,7 @@ class TestExtrasWorking(unittest.TestCase): ...@@ -19,7 +21,7 @@ class TestExtrasWorking(unittest.TestCase):
"upscaler_1": "None", "upscaler_1": "None",
"upscaler_2": "None", "upscaler_2": "None",
"extras_upscaler_2_visibility": 0, "extras_upscaler_2_visibility": 0,
"image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
} }
def test_simple_upscaling_performed(self): def test_simple_upscaling_performed(self):
...@@ -31,7 +33,7 @@ class TestPngInfoWorking(unittest.TestCase): ...@@ -31,7 +33,7 @@ class TestPngInfoWorking(unittest.TestCase):
def setUp(self): def setUp(self):
self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image" self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image"
self.png_info = { self.png_info = {
"image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
} }
def test_png_info_performed(self): def test_png_info_performed(self):
...@@ -42,7 +44,7 @@ class TestInterrogateWorking(unittest.TestCase): ...@@ -42,7 +44,7 @@ class TestInterrogateWorking(unittest.TestCase):
def setUp(self): def setUp(self):
self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image" self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image"
self.interrogate = { self.interrogate = {
"image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")), "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))),
"model": "clip" "model": "clip"
} }
......
import os
import unittest import unittest
import requests import requests
from gradio.processing_utils import encode_pil_to_base64 from gradio.processing_utils import encode_pil_to_base64
from PIL import Image from PIL import Image
from modules.paths import script_path
class TestImg2ImgWorking(unittest.TestCase): class TestImg2ImgWorking(unittest.TestCase):
def setUp(self): def setUp(self):
self.url_img2img = "http://localhost:7860/sdapi/v1/img2img" self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
self.simple_img2img = { self.simple_img2img = {
"init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))], "init_images": [encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))],
"resize_mode": 0, "resize_mode": 0,
"denoising_strength": 0.75, "denoising_strength": 0.75,
"mask": None, "mask": None,
...@@ -47,11 +49,11 @@ class TestImg2ImgWorking(unittest.TestCase): ...@@ -47,11 +49,11 @@ class TestImg2ImgWorking(unittest.TestCase):
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
def test_inpainting_masked_performed(self): def test_inpainting_masked_performed(self):
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
def test_inpainting_with_inverted_masked_performed(self): def test_inpainting_with_inverted_masked_performed(self):
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
self.simple_img2img["inpainting_mask_invert"] = True self.simple_img2img["inpainting_mask_invert"] = True
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
......
...@@ -66,6 +66,8 @@ class TestTxt2ImgWorking(unittest.TestCase): ...@@ -66,6 +66,8 @@ class TestTxt2ImgWorking(unittest.TestCase):
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
self.simple_txt2img["sampler_index"] = "DDIM" self.simple_txt2img["sampler_index"] = "DDIM"
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
self.simple_txt2img["sampler_index"] = "UniPC"
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_multiple_batches_performed(self): def test_txt2img_multiple_batches_performed(self):
self.simple_txt2img["n_iter"] = 2 self.simple_txt2img["n_iter"] = 2
......
import unittest import unittest
import requests import requests
import time import time
import os
from modules.paths import script_path
def run_tests(proc, test_dir): def run_tests(proc, test_dir):
...@@ -15,8 +17,8 @@ def run_tests(proc, test_dir): ...@@ -15,8 +17,8 @@ def run_tests(proc, test_dir):
break break
if proc.poll() is None: if proc.poll() is None:
if test_dir is None: if test_dir is None:
test_dir = "test" test_dir = os.path.join(script_path, "test")
suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test") suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir=test_dir)
result = unittest.TextTestRunner(verbosity=2).run(suite) result = unittest.TextTestRunner(verbosity=2).run(suite)
return len(result.failures) + len(result.errors) return len(result.failures) + len(result.errors)
else: else:
......
...@@ -12,11 +12,22 @@ from packaging import version ...@@ -12,11 +12,22 @@ from packaging import version
import logging import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints from modules import paths, timer, import_hook, errors
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call startup_timer = timer.Timer()
import torch import torch
startup_timer.record("import torch")
import gradio
startup_timer.record("import gradio")
import ldm.modules.encoders.modules
startup_timer.record("import ldm")
from modules import extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__: if ".dev" in torch.__version__ or "+git" in torch.__version__:
...@@ -30,7 +41,6 @@ import modules.gfpgan_model as gfpgan ...@@ -30,7 +41,6 @@ import modules.gfpgan_model as gfpgan
import modules.img2img import modules.img2img
import modules.lowvram import modules.lowvram
import modules.paths
import modules.scripts import modules.scripts
import modules.sd_hijack import modules.sd_hijack
import modules.sd_models import modules.sd_models
...@@ -45,6 +55,8 @@ from modules import modelloader ...@@ -45,6 +55,8 @@ from modules import modelloader
from modules.shared import cmd_opts from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork import modules.hypernetworks.hypernetwork
startup_timer.record("other imports")
if cmd_opts.server_name: if cmd_opts.server_name:
server_name = cmd_opts.server_name server_name = cmd_opts.server_name
...@@ -88,6 +100,7 @@ def initialize(): ...@@ -88,6 +100,7 @@ def initialize():
extensions.list_extensions() extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list extensions")
if cmd_opts.ui_debug_mode: if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
...@@ -96,16 +109,28 @@ def initialize(): ...@@ -96,16 +109,28 @@ def initialize():
modelloader.cleanup_models() modelloader.cleanup_models()
modules.sd_models.setup_model() modules.sd_models.setup_model()
startup_timer.record("list SD models")
codeformer.setup_model(cmd_opts.codeformer_models_path) codeformer.setup_model(cmd_opts.codeformer_models_path)
startup_timer.record("setup codeformer")
gfpgan.setup_model(cmd_opts.gfpgan_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path)
startup_timer.record("setup gfpgan")
modelloader.list_builtin_upscalers() modelloader.list_builtin_upscalers()
startup_timer.record("list builtin upscalers")
modules.scripts.load_scripts() modules.scripts.load_scripts()
startup_timer.record("load scripts")
modelloader.load_upscalers() modelloader.load_upscalers()
startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list() modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
modules.textual_inversion.textual_inversion.list_textual_inversion_templates() modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
try: try:
modules.sd_models.load_model() modules.sd_models.load_model()
...@@ -114,6 +139,7 @@ def initialize(): ...@@ -114,6 +139,7 @@ def initialize():
print("", file=sys.stderr) print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr) print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1) exit(1)
startup_timer.record("load SD checkpoint")
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
...@@ -121,8 +147,10 @@ def initialize(): ...@@ -121,8 +147,10 @@ def initialize():
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
startup_timer.record("opts onchange")
shared.reload_hypernetworks() shared.reload_hypernetworks()
startup_timer.record("reload hypernets")
ui_extra_networks.intialize() ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
...@@ -131,6 +159,7 @@ def initialize(): ...@@ -131,6 +159,7 @@ def initialize():
extra_networks.initialize() extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
startup_timer.record("extra networks")
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
@@ -144,6 +173,7 @@ def initialize():
             print("TLS setup invalid, running webui without TLS")
         else:
             print("Running with TLS")
+    startup_timer.record("TLS")
     # make the program just exit at ctrl+c without waiting for anything
     def sigint_handler(sig, frame):
@@ -153,13 +183,16 @@ def initialize():
     signal.signal(signal.SIGINT, sigint_handler)
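The startup_timer.record(...) calls introduced above assume a small timer object exposing record, summary, and reset; the real one is defined elsewhere in the webui codebase. A minimal sketch of that interface, where everything beyond those three method names is an assumption:

import time

class Timer:
    """Accumulates named wall-clock intervals between record() calls (sketch only)."""

    def __init__(self):
        self.start = time.time()
        self.records = {}
        self.total = 0.0

    def record(self, category):
        # Attribute the time since the last checkpoint to this category.
        now = time.time()
        elapsed = now - self.start
        self.records[category] = self.records.get(category, 0.0) + elapsed
        self.total += elapsed
        self.start = now

    def summary(self):
        # e.g. "12.3s (load scripts: 2.1s, create ui: 4.0s)" -- format is illustrative
        parts = ", ".join(f"{name}: {seconds:.1f}s" for name, seconds in self.records.items())
        return f"{self.total:.1f}s ({parts})"

    def reset(self):
        self.__init__()

startup_timer = Timer()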
-def setup_cors(app):
+def setup_middleware(app):
+    app.middleware_stack = None  # reset current middleware to allow modifying user provided list
+    app.add_middleware(GZipMiddleware, minimum_size=1000)
     if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
         app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
     elif cmd_opts.cors_allow_origins:
         app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
     elif cmd_opts.cors_allow_origins_regex:
         app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
+    app.build_middleware_stack()  # rebuild middleware stack on-the-fly
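Starlette refuses add_middleware once the middleware stack has been built, which is why setup_middleware clears app.middleware_stack first and rebuilds it at the end; that is what lets the same function run both on a fresh FastAPI app and on the already-launched Gradio app. A rough standalone illustration of the pattern (imports are real; the order of operations simply mirrors the function above):

from fastapi import FastAPI
from starlette.middleware.gzip import GZipMiddleware

app = FastAPI()
app.middleware_stack = None                        # unfreeze the middleware list
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.build_middleware_stack()                       # rebuild so the new middleware takes effect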
 def create_api(app):
@@ -183,12 +216,12 @@ def api_only():
     initialize()
     app = FastAPI()
-    setup_cors(app)
-    app.add_middleware(GZipMiddleware, minimum_size=1000)
+    setup_middleware(app)
     api = create_api(app)
     modules.script_callbacks.app_started_callback(None, app)
+    print(f"Startup time: {startup_timer.summary()}.")
     api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
@@ -199,21 +232,24 @@ def webui():
     while 1:
         if shared.opts.clean_temp_dir_at_start:
             ui_tempdir.cleanup_tmpdr()
+            startup_timer.record("cleanup temp dir")
         modules.script_callbacks.before_ui_callback()
+        startup_timer.record("scripts before_ui_callback")
         shared.demo = modules.ui.create_ui()
+        startup_timer.record("create ui")
         if cmd_opts.gradio_queue:
             shared.demo.queue(64)
         gradio_auth_creds = []
         if cmd_opts.gradio_auth:
-            gradio_auth_creds += cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',')
+            gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
         if cmd_opts.gradio_auth_path:
             with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
                 for line in file.readlines():
-                    gradio_auth_creds += [x.strip() for x in line.split(',')]
+                    gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
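Both branches now filter out empty entries, so stray commas and blank lines in the --gradio-auth value or the auth file no longer produce empty credentials. A hypothetical value run through the same expression:

raw = '"alice:secret, bob:hunter2,"'  # hypothetical --gradio-auth value
creds = [x.strip() for x in raw.strip('"').replace('\n', '').split(',') if x.strip()]
print(creds)  # ['alice:secret', 'bob:hunter2'] -- the trailing comma adds nothing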
         app, local_url, share_url = shared.demo.launch(
             share=cmd_opts.share,
@@ -229,15 +265,15 @@ def webui():
         # after initial launch, disable --autolaunch for subsequent restarts
         cmd_opts.autolaunch = False
+        startup_timer.record("gradio launch")
         # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
         # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
         # running web ui and do whatever the attacker wants, including installing an extension and
         # running its code. We disable this here. Suggested by RyotaK.
         app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
-        setup_cors(app)
-        app.add_middleware(GZipMiddleware, minimum_size=1000)
+        setup_middleware(app)
         modules.progress.setup_progress_api(app)
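Since the filter above matches on the middleware class name, one can sanity-check that Gradio's permissive CORS layer is really gone before the user-configured middleware is re-applied; a one-line check under the same assumption:

assert all(m.cls.__name__ != 'CORSMiddleware' for m in app.user_middleware)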
@@ -247,28 +283,42 @@ def webui():
         ui_extra_networks.add_pages_to_demo(app)
         modules.script_callbacks.app_started_callback(shared.demo, app)
+        startup_timer.record("scripts app_started_callback")
+        print(f"Startup time: {startup_timer.summary()}.")
         wait_on_server(shared.demo)
         print('Restarting UI...')
+        startup_timer.reset()
         sd_samplers.set_samplers()
         modules.script_callbacks.script_unloaded_callback()
         extensions.list_extensions()
+        startup_timer.record("list extensions")
         localization.list_localizations(cmd_opts.localizations_dir)
         modelloader.forbid_loaded_nonbuiltin_upscalers()
         modules.scripts.reload_scripts()
+        startup_timer.record("load scripts")
         modules.script_callbacks.model_loaded_callback(shared.sd_model)
+        startup_timer.record("model loaded callback")
         modelloader.load_upscalers()
+        startup_timer.record("load upscalers")
         for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
             importlib.reload(module)
+        startup_timer.record("reload script modules")
         modules.sd_models.list_models()
+        startup_timer.record("list SD models")
         shared.reload_hypernetworks()
+        startup_timer.record("reload hypernetworks")
         ui_extra_networks.intialize()
         ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
@@ -277,6 +327,7 @@ def webui():
         extra_networks.initialize()
         extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+        startup_timer.record("initialize extra networks")
 if __name__ == "__main__":