Commit 63a2f8d8 authored by Karun, committed by GitHub

Merge branch 'master' into master

parents ca2b8faa 70615448
@@ -157,5 +157,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
 - Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
 - Security advice - RyotaK
+- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
 - (You)
@@ -3,7 +3,9 @@ import os
 import re
 import torch
-from modules import shared, devices, sd_models
+from modules import shared, devices, sd_models, errors
+
+metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
 re_digits = re.compile(r"\d+")
 re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
@@ -43,6 +45,23 @@ class LoraOnDisk:
     def __init__(self, name, filename):
         self.name = name
         self.filename = filename
+        self.metadata = {}
+
+        _, ext = os.path.splitext(filename)
+        if ext.lower() == ".safetensors":
+            try:
+                self.metadata = sd_models.read_metadata_from_safetensors(filename)
+            except Exception as e:
+                errors.display(e, f"reading lora {filename}")
+
+        if self.metadata:
+            m = {}
+
+            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
+                m[k] = v
+
+            self.metadata = m
+
+        self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None)  # those are cover images and they are too big to display in UI as text
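For context on the new metadata read: safetensors files start with an 8-byte little-endian header length followed by a JSON header, whose optional `__metadata__` key holds string-to-string pairs. A minimal standalone sketch of such a reader, assuming only that published layout (an illustration, not the project's `sd_models.read_metadata_from_safetensors` implementation):

```python
# Sketch only: read the __metadata__ map from a .safetensors file,
# assuming the standard safetensors header layout.
import json
import struct

def read_safetensors_metadata(filename):
    with open(filename, "rb") as file:
        header_size = struct.unpack("<Q", file.read(8))[0]  # u64, little-endian
        header = json.loads(file.read(header_size))         # JSON tensor index
    return header.get("__metadata__", {})                   # optional string map
```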
 class LoraModule:
@@ -159,6 +178,7 @@ def load_loras(names, multipliers=None):

 def lora_forward(module, input, res):
+    input = devices.cond_cast_unet(input)
     if len(loaded_loras) == 0:
         return res
...
@@ -15,21 +15,15 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
     def list_items(self):
         for name, lora_on_disk in lora.available_loras.items():
             path, ext = os.path.splitext(lora_on_disk.filename)
-            previews = [path + ".png", path + ".preview.png"]
-
-            preview = None
-            for file in previews:
-                if os.path.isfile(file):
-                    preview = self.link_preview(file)
-                    break
-
             yield {
                 "name": name,
                 "filename": path,
-                "preview": preview,
+                "preview": self.find_preview(path),
+                "description": self.find_description(path),
                 "search_term": self.search_terms_from_path(lora_on_disk.filename),
                 "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
-                "local_preview": path + ".png",
+                "local_preview": f"{path}.{shared.opts.samples_format}",
+                "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
             }

     def allowed_directories_for_previews(self):
...
@@ -89,22 +89,15 @@ function checkBrackets(evt, textArea, counterElt) {
 function setupBracketChecking(id_prompt, id_counter){
     var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
     var counter = gradioApp().getElementById(id_counter)
     textarea.addEventListener("input", function(evt){
         checkBrackets(evt, textarea, counter)
     });
 }

-var shadowRootLoaded = setInterval(function() {
-    var shadowRoot = document.querySelector('gradio-app').shadowRoot;
-    if(! shadowRoot) return false;
-
-    var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
-    if(shadowTextArea.length < 1) return false;
-
-    clearInterval(shadowRootLoaded);
-
+onUiLoaded(function(){
     setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
     setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
-    setupBracketChecking('img2img_prompt', 'imgimg_token_counter')
+    setupBracketChecking('img2img_prompt', 'img2img_token_counter')
     setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
-}, 1000);
+})
\ No newline at end of file
@@ -7,6 +7,7 @@
         <span style="display:none" class='search_term'>{search_term}</span>
     </div>
     <span class='name'>{name}</span>
+    <span class='description'>{description}</span>
 </div>
</div>
@@ -43,7 +43,7 @@ contextMenuInit = function(){
     })

-    gradioApp().getRootNode().appendChild(contextMenu)
+    gradioApp().appendChild(contextMenu)

     let menuWidth = contextMenu.offsetWidth + 4;
     let menuHeight = contextMenu.offsetHeight + 4;
...
 function keyupEditAttention(event){
     let target = event.originalTarget || event.composedPath()[0];
-    if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
+    if (! target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
     if (! (event.metaKey || event.ctrlKey)) return;

     let isPlus = event.key == "ArrowUp"
...
@@ -5,12 +5,10 @@ function setupExtraNetworksForTab(tabname){
     var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
     var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
     var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
-    var close = gradioApp().getElementById(tabname+'_extra_close')

     search.classList.add('search')
     tabs.appendChild(search)
     tabs.appendChild(refresh)
-    tabs.appendChild(close)

     search.addEventListener("input", function(evt){
         searchTerm = search.value.toLowerCase()
@@ -78,7 +76,7 @@ function cardClicked(tabname, textToAdd, allowNegativePrompt){
     var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")

     if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
-        textarea.value = textarea.value + " " + textToAdd
+        textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd
     }

     updateInput(textarea)
@@ -104,4 +102,76 @@ function extraNetworksSearchButton(tabs_id, event){
     searchTextarea.value = text
     updateInput(searchTextarea)
 }
\ No newline at end of file
+var globalPopup = null;
+var globalPopupInner = null;
+
+function popup(contents){
+    if(! globalPopup){
+        globalPopup = document.createElement('div')
+        globalPopup.onclick = function(){ globalPopup.style.display = "none"; };
+        globalPopup.classList.add('global-popup');
+
+        var close = document.createElement('div')
+        close.classList.add('global-popup-close');
+        close.onclick = function(){ globalPopup.style.display = "none"; };
+        close.title = "Close";
+        globalPopup.appendChild(close)
+
+        globalPopupInner = document.createElement('div')
+        globalPopupInner.onclick = function(event){ event.stopPropagation(); return false; };
+        globalPopupInner.classList.add('global-popup-inner');
+        globalPopup.appendChild(globalPopupInner)
+
+        gradioApp().appendChild(globalPopup);
+    }
+
+    globalPopupInner.innerHTML = '';
+    globalPopupInner.appendChild(contents);
+
+    globalPopup.style.display = "flex";
+}
+
+function extraNetworksShowMetadata(text){
+    elem = document.createElement('pre')
+    elem.classList.add('popup-metadata');
+    elem.textContent = text;
+
+    popup(elem);
+}
+
+function requestGet(url, data, handler, errorHandler){
+    var xhr = new XMLHttpRequest();
+    var args = Object.keys(data).map(function(k){ return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]) }).join('&')
+    xhr.open("GET", url + "?" + args, true);
+
+    xhr.onreadystatechange = function () {
+        if (xhr.readyState === 4) {
+            if (xhr.status === 200) {
+                try {
+                    var js = JSON.parse(xhr.responseText);
+                    handler(js)
+                } catch (error) {
+                    console.error(error);
+                    errorHandler()
+                }
+            } else{
+                errorHandler()
+            }
+        }
+    };
+    var js = JSON.stringify(data);
+    xhr.send(js);
+}
+
+function extraNetworksRequestMetadata(extraPage, cardName){
+    showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
+
+    requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
+        if(data && data.metadata){
+            extraNetworksShowMetadata(data.metadata)
+        } else{
+            showError()
+        }
+    }, showError)
+}
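The `requestGet` helper above just issues a GET with URL-encoded query parameters. A hedged Python equivalent of the metadata request it performs (the endpoint path and the `page`/`item` parameters come straight from `extraNetworksRequestMetadata`; the base URL and the page name "Lora" are illustrative assumptions):

```python
# Sketch only: fetch extra-network card metadata the same way the JS above does.
import json
import urllib.parse
import urllib.request

def fetch_card_metadata(base_url, page, item):
    query = urllib.parse.urlencode({"page": page, "item": item})
    with urllib.request.urlopen(f"{base_url}/sd_extra_networks/metadata?{query}") as resp:
        return json.load(resp).get("metadata")  # None when the server has no metadata

# e.g. fetch_card_metadata("http://127.0.0.1:7860", "Lora", "my_lora")  # hypothetical names
```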
@@ -6,6 +6,7 @@ titles = {
     "GFPGAN": "Restore low quality faces using GFPGAN neural network",
     "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
     "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
+    "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
     "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",

     "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
@@ -17,7 +18,7 @@ titles = {
     "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
     "\u{1f4c2}": "Open images output directory",
     "\u{1f4be}": "Save style",
-    "\u{1f5d1}": "Clear prompt",
+    "\u{1f5d1}\ufe0f": "Clear prompt",
     "\u{1f4cb}": "Apply selected styles to current prompt",
     "\u{1f4d2}": "Paste available values into the field",
     "\u{1f3b4}": "Show extra networks",
@@ -39,8 +40,7 @@ titles = {
     "Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",

     "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
-    "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",

     "Skip": "Stop processing current image and continue processing.",
     "Interrupt": "Stop processing images and return any results accumulated so far.",
     "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
@@ -70,8 +70,10 @@ titles = {
     "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
     "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",

-    "Loopback": "Process an image, use it as an input, repeat.",
-    "Loops": "How many times to repeat processing an image and using it as input for the next iteration",
+    "Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
+    "Loops": "How many times to process an image. Each output is used as the input of the next loop. If set to 1, behavior will be as if this script were not used.",
+    "Final denoising strength": "The denoising strength for the final loop of each image in the batch.",
+    "Denoising strength curve": "The denoising curve controls the rate of denoising strength change each loop. Aggressive: Most of the change will happen towards the start of the loops. Linear: Change will be constant through all loops. Lazy: Most of the change will happen towards the end of the loops.",

     "Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
     "Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
...
@@ -11,7 +11,7 @@ function showModal(event) {
     if (modalImage.style.display === 'none') {
         lb.style.setProperty('background-image', 'url(' + source.src + ')');
     }
-    lb.style.display = "block";
+    lb.style.display = "flex";
     lb.focus()

     const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
@@ -50,7 +50,7 @@ function updateOnBackgroundChange() {
 }

 function modalImageSwitch(offset) {
-    var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
+    var allgalleryButtons = gradioApp().querySelectorAll(".gradio-gallery .thumbnail-item")
     var galleryButtons = []
     allgalleryButtons.forEach(function(elem) {
         if (elem.parentElement.offsetParent) {
@@ -59,7 +59,7 @@ function modalImageSwitch(offset) {
     })

     if (galleryButtons.length > 1) {
-        var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
+        var allcurrentButtons = gradioApp().querySelectorAll(".gradio-gallery .thumbnail-item.selected")
        var currentButton = null
         allcurrentButtons.forEach(function(elem) {
             if (elem.parentElement.offsetParent) {
@@ -136,37 +136,29 @@ function modalKeyHandler(event) {
     }
 }
-function showGalleryImage() {
-    setTimeout(function() {
-        fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
-
-        if (fullImg_preview != null) {
-            fullImg_preview.forEach(function function_name(e) {
-                if (e.dataset.modded)
-                    return;
-
-                e.dataset.modded = true;
-                if(e && e.parentElement.tagName == 'DIV'){
-                    e.style.cursor='pointer'
-                    e.style.userSelect='none'
-
-                    var isFirefox = isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
-
-                    // For Firefox, listening on click first switched to next image then shows the lightbox.
-                    // If you know how to fix this without switching to mousedown event, please.
-                    // For other browsers the event is click to make it possiblr to drag picture.
-                    var event = isFirefox ? 'mousedown' : 'click'
-
-                    e.addEventListener(event, function (evt) {
-                        if(!opts.js_modal_lightbox || evt.button != 0) return;
-                        modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
-                        evt.preventDefault()
-                        showModal(evt)
-                    }, true);
-                }
-            });
-        }
-    }, 100);
+function setupImageForLightbox(e) {
+    if (e.dataset.modded)
+        return;
+
+    e.dataset.modded = true;
+    e.style.cursor='pointer'
+    e.style.userSelect='none'
+
+    var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
+
+    // For Firefox, listening on click first switched to next image then shows the lightbox.
+    // If you know how to fix this without switching to mousedown event, please.
+    // For other browsers the event is click to make it possiblr to drag picture.
+    var event = isFirefox ? 'mousedown' : 'click'
+
+    e.addEventListener(event, function (evt) {
+        if(!opts.js_modal_lightbox || evt.button != 0) return;
+
+        modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
+        evt.preventDefault()
+        showModal(evt)
+    }, true);
 }
 function modalZoomSet(modalImage, enable) {
@@ -199,21 +191,21 @@ function modalTileImageToggle(event) {
 }

 function galleryImageHandler(e) {
-    if (e && e.parentElement.tagName == 'BUTTON') {
+    //if (e && e.parentElement.tagName == 'BUTTON') {
         e.onclick = showGalleryImage;
-    }
+    //}
 }

 onUiUpdate(function() {
-    fullImg_preview = gradioApp().querySelectorAll('img.w-full')
+    fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
     if (fullImg_preview != null) {
-        fullImg_preview.forEach(galleryImageHandler);
+        fullImg_preview.forEach(setupImageForLightbox);
     }
     updateOnBackgroundChange();
 })
 document.addEventListener("DOMContentLoaded", function() {
-    const modalFragment = document.createDocumentFragment();
+    //const modalFragment = document.createDocumentFragment();
     const modal = document.createElement('div')
     modal.onclick = closeModal;
     modal.id = "lightboxModal";
@@ -277,9 +269,9 @@ document.addEventListener("DOMContentLoaded", function() {
     modal.appendChild(modalNext)

-    gradioApp().getRootNode().appendChild(modal)
+    gradioApp().appendChild(modal)

-    document.body.appendChild(modalFragment);
+    document.body.appendChild(modal);
 });
@@ -15,7 +15,7 @@ onUiUpdate(function(){
         }
     }

-    const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
+    const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] img.h-full.w-full.overflow-hidden');

     if (galleryPreviews == null) return;
...
 // code related to showing and updating progressbar shown as the image is being made

-galleries = {}
-storedGallerySelections = {}
-galleryObservers = {}
-
 function rememberGallerySelection(id_gallery){
-    storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery)
+
 }

 function getGallerySelectedIndex(id_gallery){
-    let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
-    let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
-
-    let currentlySelectedIndex = -1
-    galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } })
-
-    return currentlySelectedIndex
-}
-
-// this is a workaround for https://github.com/gradio-app/gradio/issues/2984
-function check_gallery(id_gallery){
-    let gallery = gradioApp().getElementById(id_gallery)
-    // if gallery has no change, no need to setting up observer again.
-    if (gallery && galleries[id_gallery] !== gallery){
-        galleries[id_gallery] = gallery;
-        if(galleryObservers[id_gallery]){
-            galleryObservers[id_gallery].disconnect();
-        }
-        storedGallerySelections[id_gallery] = -1
-        galleryObservers[id_gallery] = new MutationObserver(function (){
-            let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
-            let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
-            let currentlySelectedIndex = getGallerySelectedIndex(id_gallery)
-            prevSelectedIndex = storedGallerySelections[id_gallery]
-            storedGallerySelections[id_gallery] = -1
-
-            if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
-                // automatically re-open previously selected index (if exists)
-                activeElement = gradioApp().activeElement;
-                let scrollX = window.scrollX;
-                let scrollY = window.scrollY;
-
-                galleryButtons[prevSelectedIndex].click();
-                showGalleryImage();
-
-                // When the gallery button is clicked, it gains focus and scrolls itself into view
-                // We need to scroll back to the previous position
-                setTimeout(function (){
-                    window.scrollTo(scrollX, scrollY);
-                }, 50);
-
-                if(activeElement){
-                    // i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
-                    // if someone has a better solution please by all means
-                    setTimeout(function (){
-                        activeElement.focus({
-                            preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it
-                        })
-                    }, 1);
-                }
-            }
-        })
-        galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
-    }
-}
-
-onUiUpdate(function(){
-    check_gallery('txt2img_gallery')
-    check_gallery('img2img_gallery')
-})
+
+}

 function request(url, data, handler, errorHandler){
     var xhr = new XMLHttpRequest();
     var url = url;
@@ -139,7 +74,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
     var divProgress = document.createElement('div')
     divProgress.className='progressDiv'
-    divProgress.style.display = opts.show_progressbar ? "" : "none"
+    divProgress.style.display = opts.show_progressbar ? "block" : "none"
     var divInner = document.createElement('div')
     divInner.className='progress'
...
@@ -86,7 +86,7 @@ function get_tab_index(tabId){
     var res = 0

     gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){
-        if(button.className.indexOf('bg-white') != -1)
+        if(button.className.indexOf('selected') != -1)
             res = i
     })
@@ -255,7 +255,6 @@ onUiUpdate(function(){
         }

         prompt.parentElement.insertBefore(counter, prompt)
-        counter.classList.add("token-counter")
         prompt.parentElement.style.position = "relative"

         promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
...
@@ -8,6 +8,14 @@ import platform
 import argparse
 import json

+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument("--ui-settings-file", type=str, default='config.json')
+parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.realpath(__file__)))
+args, _ = parser.parse_known_args(sys.argv)
+
+script_path = os.path.dirname(__file__)
+data_path = args.data_dir
+
 dir_repos = "repositories"
 dir_extensions = "extensions"
 python = sys.executable
@@ -122,7 +130,7 @@ def is_installed(package):

 def repo_dir(name):
-    return os.path.join(dir_repos, name)
+    return os.path.join(script_path, dir_repos, name)


 def run_python(code, desc=None, errdesc=None):
@@ -161,7 +169,17 @@ def git_clone(url, dir, name, commithash=None):
     if commithash is not None:
         run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")

+
+def git_pull_recursive(dir):
+    for subdir, _, _ in os.walk(dir):
+        if os.path.exists(os.path.join(subdir, '.git')):
+            try:
+                output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
+                print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
+            except subprocess.CalledProcessError as e:
+                print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
+
+
 def version_check(commit):
     try:
         import requests
@@ -205,7 +223,7 @@ def list_extensions(settings_file):
     disabled_extensions = set(settings.get('disabled_extensions', []))

-    return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
+    return [x for x in os.listdir(os.path.join(data_path, dir_extensions)) if x not in disabled_extensions]


 def run_extensions_installers(settings_file):
@@ -213,7 +231,7 @@ def run_extensions_installers(settings_file):
         return

     for dirname_extension in list_extensions(settings_file):
-        run_extension_installer(os.path.join(dir_extensions, dirname_extension))
+        run_extension_installer(os.path.join(data_path, dir_extensions, dirname_extension))


 def prepare_environment():
@@ -242,11 +260,8 @@ def prepare_environment():

     sys.argv += shlex.split(commandline_args)

-    parser = argparse.ArgumentParser(add_help=False)
-    parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
-    args, _ = parser.parse_known_args(sys.argv)
-
     sys.argv, _ = extract_arg(sys.argv, '-f')
+    sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
     sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
     sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
     sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
@@ -295,7 +310,7 @@ def prepare_environment():
     if not is_installed("pyngrok") and ngrok:
         run_pip("install pyngrok", "ngrok")

-    os.makedirs(dir_repos, exist_ok=True)
+    os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)

     git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
     git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
@@ -304,14 +319,19 @@ def prepare_environment():
     git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)

     if not is_installed("lpips"):
-        run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
+        run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")

-    run_pip(f"install -r {requirements_file}", "requirements for Web UI")
+    if not os.path.isfile(requirements_file):
+        requirements_file = os.path.join(script_path, requirements_file)
+
+    run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")

     run_extensions_installers(settings_file=args.ui_settings_file)

     if update_check:
         version_check(commit)

+    if update_all_extensions:
+        git_pull_recursive(os.path.join(data_path, dir_extensions))
+
     if "--exit" in sys.argv:
         print("Exiting because of --exit argument")
@@ -327,7 +347,7 @@ def tests(test_dir):
         sys.argv.append("--api")
     if "--ckpt" not in sys.argv:
         sys.argv.append("--ckpt")
-        sys.argv.append("./test/test_files/empty.pt")
+        sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
     if "--skip-torch-cuda-test" not in sys.argv:
         sys.argv.append("--skip-torch-cuda-test")
     if "--disable-nan-check" not in sys.argv:
@@ -336,7 +356,7 @@ def tests(test_dir):
     print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")

     os.environ['COMMANDLINE_ARGS'] = ""
-    with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
+    with open(os.path.join(script_path, 'test/stdout.txt'), "w", encoding="utf8") as stdout, open(os.path.join(script_path, 'test/stderr.txt'), "w", encoding="utf8") as stderr:
         proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)

     import test.server_poll
...
@@ -14,8 +14,8 @@ API_NOT_ALLOWED = [
     "outpath_samples",
     "outpath_grids",
     "sampler_index",
-    "do_not_save_samples",
-    "do_not_save_grid",
+    # "do_not_save_samples",
+    # "do_not_save_grid",
     "extra_generation_params",
     "overlay_images",
     "do_not_reload_embeddings",
@@ -100,13 +100,31 @@ class PydanticModelGenerator:
 StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
     "StableDiffusionProcessingTxt2Img",
     StableDiffusionProcessingTxt2Img,
-    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+    [
+        {"key": "sampler_index", "type": str, "default": "Euler"},
+        {"key": "script_name", "type": str, "default": None},
+        {"key": "script_args", "type": list, "default": []},
+        {"key": "send_images", "type": bool, "default": True},
+        {"key": "save_images", "type": bool, "default": False},
+        {"key": "alwayson_scripts", "type": dict, "default": {}},
+    ]
 ).generate_model()

 StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
     "StableDiffusionProcessingImg2Img",
     StableDiffusionProcessingImg2Img,
-    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+    [
+        {"key": "sampler_index", "type": str, "default": "Euler"},
+        {"key": "init_images", "type": list, "default": None},
+        {"key": "denoising_strength", "type": float, "default": 0.75},
+        {"key": "mask", "type": str, "default": None},
+        {"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
+        {"key": "script_name", "type": str, "default": None},
+        {"key": "script_args", "type": list, "default": []},
+        {"key": "send_images", "type": bool, "default": True},
+        {"key": "save_images", "type": bool, "default": False},
+        {"key": "alwayson_scripts", "type": dict, "default": {}},
+    ]
 ).generate_model()

 class TextToImageResponse(BaseModel):
@@ -267,3 +285,7 @@ class EmbeddingsResponse(BaseModel):
 class MemoryResponse(BaseModel):
     ram: dict = Field(title="RAM", description="System memory stats")
     cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
+
+class ScriptsList(BaseModel):
+    txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
+    img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
\ No newline at end of file
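The new `send_images`/`save_images`/`alwayson_scripts` keys end up as optional JSON fields on the generated request models. A hedged sketch of a request that uses them (`/sdapi/v1/txt2img` is the webui API route served by this model; the host/port and prompt are assumptions):

```python
# Sketch only: call txt2img with the new optional fields from this diff.
import json
import urllib.request

payload = {
    "prompt": "a photo of a corgi",  # illustrative
    "steps": 20,
    "sampler_index": "Euler",
    "send_images": False,  # don't return base64 images in the response
    "save_images": True,   # have the server save outputs itself
    "alwayson_scripts": {},
}

req = urllib.request.Request(
    "http://127.0.0.1:7860/sdapi/v1/txt2img",  # assumed local instance
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
# with urllib.request.urlopen(req) as resp:
#     result = json.load(resp)
```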
@@ -55,7 +55,7 @@ def setup_model(dirname):
             if self.net is not None and self.face_helper is not None:
                 self.net.to(devices.device_codeformer)
                 return self.net, self.face_helper
-            model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
+            model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
             if len(model_paths) != 0:
                 ckpt_path = model_paths[0]
             else:
...
@@ -66,7 +66,7 @@ class Extension:
     def check_updates(self):
         repo = git.Repo(self.path)
-        for fetch in repo.remote().fetch("--dry-run"):
+        for fetch in repo.remote().fetch(dry_run=True):
             if fetch.flags != fetch.HEAD_UPTODATE:
                 self.can_update = True
                 self.status = "behind"
@@ -79,8 +79,8 @@ class Extension:
         repo = git.Repo(self.path)
         # Fix: `error: Your local changes to the following files would be overwritten by merge`,
         # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
-        repo.git.fetch('--all')
-        repo.git.reset('--hard', 'origin')
+        repo.git.fetch(all=True)
+        repo.git.reset('origin', hard=True)

 def list_extensions():
...
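The switch from string flags to keyword arguments relies on GitPython translating kwargs into CLI flags (`dry_run=True` becomes `--dry-run`, `all=True` becomes `--all`), which keeps flags out of the positions where GitPython expects refspecs. A small sketch of the equivalences, assuming the current directory is a git repository:

```python
# Sketch only: GitPython kwargs vs. raw CLI flags.
import git

repo = git.Repo(".")                 # assumes cwd is a git repo
repo.remote().fetch(dry_run=True)    # ~ git fetch --dry-run
repo.git.fetch(all=True)             # ~ git fetch --all
repo.git.reset("origin", hard=True)  # ~ git reset --hard origin
```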
@@ -23,13 +23,14 @@ registered_param_bindings = []

 class ParamBinding:
-    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
+    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
         self.paste_button = paste_button
         self.tabname = tabname
         self.source_text_component = source_text_component
         self.source_image_component = source_image_component
         self.source_tabname = source_tabname
         self.override_settings_component = override_settings_component
+        self.paste_field_names = paste_field_names


 def reset():
@@ -134,7 +135,7 @@ def connect_paste_params_buttons():
         connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)

         if binding.source_tabname is not None and fields is not None:
-            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
+            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
             binding.paste_button.click(
                 fn=lambda *x: x,
                 inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
@@ -288,6 +289,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
 settings_map = {}

 infotext_to_setting_name_mapping = [
     ('Clip skip', 'CLIP_stop_at_last_layers', ),
     ('Conditional mask weight', 'inpainting_mask_weight'),
@@ -296,7 +299,11 @@ infotext_to_setting_name_mapping = [
     ('Noise multiplier', 'initial_noise_multiplier'),
     ('Eta', 'eta_ancestral'),
     ('Eta DDIM', 'eta_ddim'),
-    ('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
+    ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
+    ('UniPC variant', 'uni_pc_variant'),
+    ('UniPC skip type', 'uni_pc_skip_type'),
+    ('UniPC order', 'uni_pc_order'),
+    ('UniPC lower order final', 'uni_pc_lower_order_final'),
 ]
@@ -394,9 +401,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
     button.click(
         fn=paste_func,
-        _js=f"recalculate_prompts_{tabname}",
         inputs=[input_comp],
         outputs=[x[0] for x in paste_fields],
     )
+    button.click(
+        fn=None,
+        _js=f"recalculate_prompts_{tabname}",
+        inputs=[],
+        outputs=[],
+    )
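Splitting this into two `.click()` calls uses the gradio 3.x pattern where an event listener with `fn=None` runs only its `_js` function, and handlers chained on the same button fire in registration order. A minimal hedged sketch of the pattern (component names are illustrative):

```python
# Sketch only: gradio 3.x button with a Python handler plus a JS-only handler.
import gradio as gr

with gr.Blocks() as demo:
    box = gr.Textbox()
    btn = gr.Button("Paste")
    btn.click(fn=lambda: "pasted", inputs=[], outputs=[box])     # Python side
    btn.click(fn=None, _js="() => console.log('recalculate')",  # JS-only side
              inputs=[], outputs=[])
```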
@@ -556,7 +556,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         elif image_to_save.mode == 'I;16':
             image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")

-        image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+        image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)

         if opts.enable_pnginfo and info is not None:
             exif_bytes = piexif.dump({
@@ -573,6 +573,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         os.replace(temp_file_path, filename_without_extension + extension)

     fullfn_without_extension, extension = os.path.splitext(params.filename)
+    if hasattr(os, 'statvfs'):
+        max_name_len = os.statvfs(path).f_namemax
+        fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
+        params.filename = fullfn_without_extension + extension
+        fullfn = params.filename
     _atomically_save_image(image, fullfn_without_extension, extension)

     image.already_saved_as = fullfn
@@ -582,9 +587,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         ratio = image.width / image.height

         if oversize and ratio > 1:
-            image = image.resize((opts.target_side_length, image.height * opts.target_side_length // image.width), LANCZOS)
+            image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
         elif oversize:
-            image = image.resize((image.width * opts.target_side_length // image.height, opts.target_side_length), LANCZOS)
+            image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)

         try:
             _atomically_save_image(image, fullfn_without_extension, ".jpg")
@@ -640,6 +645,8 @@ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}
 def image_data(data):
+    import gradio as gr
+
     try:
         image = Image.open(io.BytesIO(data))
         textinfo, _ = read_info_from_image(image)
@@ -655,7 +662,7 @@ def image_data(data):
     except Exception:
         pass

-    return '', None
+    return gr.update(), None


 def flatten(img, bgcolor):
...
 import torch
+import platform
 from modules import paths
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
@@ -23,7 +24,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
         output_dtype = kwargs.get('dtype', input.dtype)
         if output_dtype == torch.int64:
             return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
-        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+        elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
             return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
     return cumsum_func(input, *args, **kwargs)
@@ -32,6 +33,10 @@ if has_mps:
     # MPS fix for randn in torchsde
     CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')

+    if platform.mac_ver()[0].startswith("13.2."):
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
+        CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
+
     if version.parse(torch.__version__) < version.parse("1.13"):
         # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
@@ -45,9 +50,10 @@ if has_mps:
         CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
     elif version.parse(torch.__version__) > version.parse("1.13.1"):
         cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
-        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
         cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
         CondFunc('torch.cumsum', cumsum_fix_func, None)
         CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
         CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
+        if version.parse(torch.__version__) == version.parse("2.0"):
+            # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
+            CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
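`CondFunc` here is the project's monkey-patch helper: given a dotted path, a substitute, and a predicate, it replaces the target so that the substitute (which receives the original function first) runs only when the predicate passes. A rough standalone sketch of that idea, not the project's exact code:

```python
# Sketch only: conditionally monkey-patch a function at a dotted path.
import importlib

def cond_patch(path, sub_func, cond_func):
    parts = path.split('.')
    owner = importlib.import_module(parts[0])
    for attr in parts[1:-1]:          # walk attributes, so 'torch.Tensor.cumsum' works
        owner = getattr(owner, attr)
    orig = getattr(owner, parts[-1])

    def wrapper(*args, **kwargs):
        if cond_func is None or cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)  # substitute sees the original first
        return orig(*args, **kwargs)

    setattr(owner, parts[-1], wrapper)
```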
@@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread):
         self.data = defaultdict(int)

         try:
-            torch.cuda.mem_get_info()
+            self.cuda_mem_get_info()
             torch.cuda.memory_stats(self.device)
         except Exception as e:  # AMD or whatever
             print(f"Warning: caught exception '{e}', memory monitor disabled")
             self.disabled = True

+    def cuda_mem_get_info(self):
+        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
+        return torch.cuda.mem_get_info(index)
+
     def run(self):
         if self.disabled:
             return
@@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread):
                 self.run_flag.clear()
                 continue

-            self.data["min_free"] = torch.cuda.mem_get_info()[0]
+            self.data["min_free"] = self.cuda_mem_get_info()[0]

             while self.run_flag.is_set():
-                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
+                free, total = self.cuda_mem_get_info()
                 self.data["min_free"] = min(self.data["min_free"], free)
                 time.sleep(1 / self.opts.memmon_poll_rate)
@@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread):
     def read(self):
         if not self.disabled:
-            free, total = torch.cuda.mem_get_info()
+            free, total = self.cuda_mem_get_info()
             self.data["free"] = free
             self.data["total"] = total
...
@@ -4,9 +4,8 @@ import shutil
 import importlib
 from urllib.parse import urlparse

-from basicsr.utils.download_util import load_file_from_url
 from modules import shared
-from modules.upscaler import Upscaler
+from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
 from modules.paths import script_path, models_path
@@ -59,6 +58,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
     if model_url is not None and len(output) == 0:
         if download_name is not None:
+            from basicsr.utils.download_util import load_file_from_url
             dl = load_file_from_url(model_url, model_path, True, download_name)
             output.append(dl)
         else:
@@ -169,4 +169,8 @@ def load_upscalers():
             scaler = cls(commandline_options.get(cmd_name, None))
             datas += scaler.scalers

-    shared.sd_upscalers = datas
+    shared.sd_upscalers = sorted(
+        datas,
+        # Special case for UpscalerNone keeps it at the beginning of the list.
+        key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
+    )
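The sort key maps the built-in pseudo-upscalers to the empty string so they sort ahead of every real model name, while `sorted`'s stability preserves their relative order. A quick standalone illustration of the trick:

```python
# Sketch only: empty-string keys float selected entries to the front of a stable sort.
names = ["SwinIR", "None", "Lanczos", "Nearest", "ESRGAN"]
builtin = {"None", "Lanczos", "Nearest"}
print(sorted(names, key=lambda n: "" if n in builtin else n.lower()))
# ['None', 'Lanczos', 'Nearest', 'ESRGAN', 'SwinIR']
```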
from .sampler import UniPCSampler
"""SAMPLING ONLY."""
import torch
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
from modules import shared, devices
class UniPCSampler(object):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.before_sample = None
self.after_sample = None
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != devices.device:
attr = attr.to(devices.device)
setattr(self, name, attr)
def set_hooks(self, before_sample, after_sample, after_update):
self.before_sample = before_sample
self.after_sample = after_sample
self.after_update = after_update
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
# print(f'Data shape for UniPC sampling is {size}')
device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
# SD 1.x models predict noise ("noise"); SD 2.x v-prediction models use "v"
model_type = "v" if self.model.parameterization == "v" else "noise"
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type=model_type,
guidance_type="classifier-free",
#condition=conditioning,
#unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
return x.to(device), None
...@@ -583,6 +583,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -583,6 +583,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.job_count == -1: if state.job_count == -1:
state.job_count = p.n_iter state.job_count = p.n_iter
extra_network_data = None
for n in range(p.n_iter): for n in range(p.n_iter):
p.iteration = n p.iteration = n
...@@ -597,6 +598,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -597,6 +598,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if p.scripts is not None:
p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
if len(prompts) == 0: if len(prompts) == 0:
break break
...@@ -685,6 +689,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -685,6 +689,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text image.info["parameters"] = text
output_images.append(image) output_images.append(image)
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
image_mask = p.mask_for_overlay.convert('RGB')
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
if opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
if opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
if opts.return_mask:
output_images.append(image_mask)
if opts.return_mask_composite:
output_images.append(image_mask_composite)
del x_samples_ddim del x_samples_ddim
devices.torch_gc() devices.torch_gc()
...@@ -709,7 +729,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -709,7 +729,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
if not p.disable_extra_networks: if not p.disable_extra_networks and extra_network_data:
extra_networks.deactivate(p, extra_network_data) extra_networks.deactivate(p, extra_network_data)
devices.torch_gc() devices.torch_gc()
...@@ -888,7 +908,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -888,7 +908,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob() shared.state.nextjob()
img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch to DDIM img2img_sampler_name = self.sampler_name
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
img2img_sampler_name = 'DDIM'
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
......
...@@ -29,7 +29,7 @@ class ImageSaveParams: ...@@ -29,7 +29,7 @@ class ImageSaveParams:
class CFGDenoiserParams: class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps): def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x self.x = x
"""Latent image representation in the process of being denoised""" """Latent image representation in the process of being denoised"""
...@@ -44,6 +44,12 @@ class CFGDenoiserParams: ...@@ -44,6 +44,12 @@ class CFGDenoiserParams:
self.total_sampling_steps = total_sampling_steps self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned""" """Total number of sampling steps planned"""
self.text_cond = text_cond
""" Encoder hidden states of text conditioning from prompt"""
self.text_uncond = text_uncond
""" Encoder hidden states of text conditioning from negative prompt"""
class CFGDenoisedParams: class CFGDenoisedParams:
......
...@@ -33,6 +33,11 @@ class Script: ...@@ -33,6 +33,11 @@ class Script:
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
""" """
paste_field_names = None
"""if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
various "Send to <X>" buttons when clicked
"""
def title(self): def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu.""" """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
...@@ -80,6 +85,20 @@ class Script: ...@@ -80,6 +85,20 @@ class Script:
pass pass
def before_process_batch(self, p, *args, **kwargs):
"""
Called before extra networks are parsed from the prompt, so you can add
new extra network keywords to the prompt with this callback.
**kwargs will have these items:
- batch_number - index of current batch, from 0 to number of batches-1
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
- seeds - list of seeds for current batch
- subseeds - list of subseeds for current batch
"""
pass
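For illustration, an always-on script could use this hook to append an extra network keyword before the prompt is parsed; the class and the <lora:...> name below are hypothetical:
from modules import scripts
class AddLoraScript(scripts.Script):
    def title(self):
        return "Add LoRA to every batch"
    def show(self, is_img2img):
        return scripts.AlwaysVisible
    def before_process_batch(self, p, *args, **kwargs):
        prompts = kwargs["prompts"]
        # mutate in place; changing the number of entries would likely break things
        for i in range(len(prompts)):
            prompts[i] += " <lora:example:0.8>"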
def process_batch(self, p, *args, **kwargs): def process_batch(self, p, *args, **kwargs):
""" """
Same as process(), but called for every batch. Same as process(), but called for every batch.
...@@ -220,7 +239,15 @@ def load_scripts(): ...@@ -220,7 +239,15 @@ def load_scripts():
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
for scriptfile in sorted(scripts_list): def orderby(basedir):
# 1st webui, 2nd extensions-builtin, 3rd extensions
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
for key in priority:
if basedir.startswith(key):
return priority[key]
return 9999
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
try: try:
if scriptfile.basedir != paths.script_path: if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path sys.path = [scriptfile.basedir] + sys.path
...@@ -256,6 +283,7 @@ class ScriptRunner: ...@@ -256,6 +283,7 @@ class ScriptRunner:
self.alwayson_scripts = [] self.alwayson_scripts = []
self.titles = [] self.titles = []
self.infotext_fields = [] self.infotext_fields = []
self.paste_field_names = []
def initialize_scripts(self, is_img2img): def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing from modules import scripts_auto_postprocessing
...@@ -304,6 +332,9 @@ class ScriptRunner: ...@@ -304,6 +332,9 @@ class ScriptRunner:
if script.infotext_fields is not None: if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields self.infotext_fields += script.infotext_fields
if script.paste_field_names is not None:
self.paste_field_names += script.paste_field_names
inputs += controls inputs += controls
inputs_alwayson += [script.alwayson for _ in controls] inputs_alwayson += [script.alwayson for _ in controls]
script.args_to = len(inputs) script.args_to = len(inputs)
...@@ -388,6 +419,15 @@ class ScriptRunner: ...@@ -388,6 +419,15 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr) print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs)
except Exception:
print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def process_batch(self, p, **kwargs): def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts: for script in self.alwayson_scripts:
try: try:
...@@ -481,6 +521,18 @@ def reload_scripts(): ...@@ -481,6 +521,18 @@ def reload_scripts():
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner() scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
def add_classes_to_gradio_component(comp):
"""
this adds gradio-* to the component for css styling (e.g. gradio-button for gr.Button), as well as some others
"""
comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
if getattr(comp, 'multiselect', False):
comp.elem_classes.append('multiselect')
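For example, assuming a stock gr.Button whose get_block_name() returns "button":
import gradio as gr
btn = gr.Button("Generate")
add_classes_to_gradio_component(btn)
# btn.elem_classes now begins with "gradio-button"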
def IOComponent_init(self, *args, **kwargs): def IOComponent_init(self, *args, **kwargs):
if scripts_current is not None: if scripts_current is not None:
scripts_current.before_component(self, **kwargs) scripts_current.before_component(self, **kwargs)
...@@ -489,6 +541,8 @@ def IOComponent_init(self, *args, **kwargs): ...@@ -489,6 +541,8 @@ def IOComponent_init(self, *args, **kwargs):
res = original_IOComponent_init(self, *args, **kwargs) res = original_IOComponent_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
script_callbacks.after_component_callback(self, **kwargs) script_callbacks.after_component_callback(self, **kwargs)
if scripts_current is not None: if scripts_current is not None:
......
...@@ -109,7 +109,7 @@ class ScriptPostprocessingRunner: ...@@ -109,7 +109,7 @@ class ScriptPostprocessingRunner:
inputs = [] inputs = []
for script in self.scripts_in_preferred_order(): for script in self.scripts_in_preferred_order():
with gr.Box() as group: with gr.Row() as group:
self.create_script_ui(script, inputs) self.create_script_ui(script, inputs)
script.group = group script.group = group
......
...@@ -37,11 +37,23 @@ def apply_optimizations(): ...@@ -37,11 +37,23 @@ def apply_optimizations():
optimization_method = None optimization_method = None
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.") print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
optimization_method = 'xformers' optimization_method = 'xformers'
elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
optimization_method = 'sdp-no-mem'
elif cmd_opts.opt_sdp_attention and can_use_sdp:
print("Applying scaled dot product cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
optimization_method = 'sdp'
elif cmd_opts.opt_sub_quad_attention: elif cmd_opts.opt_sub_quad_attention:
print("Applying sub-quadratic cross attention optimization.") print("Applying sub-quadratic cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
......
...@@ -337,7 +337,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): ...@@ -337,7 +337,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
dtype = q.dtype dtype = q.dtype
if shared.opts.upcast_attn: if shared.opts.upcast_attn:
q, k = q.float(), k.float() q, k, v = q.float(), k.float(), v.float()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
...@@ -346,6 +346,52 @@ def xformers_attention_forward(self, x, context=None, mask=None): ...@@ -346,6 +346,52 @@ def xformers_attention_forward(self, x, context=None, mask=None):
out = rearrange(out, 'b n h d -> b n (h d)', h=h) out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out) return self.to_out(out)
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
batch_size, sequence_length, inner_dim = x.shape
if mask is not None:
mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
head_dim = inner_dim // h
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
del q_in, k_in, v_in
dtype = q.dtype
if shared.opts.upcast_attn:
q, k, v = q.float(), k.float(), v.float()
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
hidden_states = hidden_states.to(dtype)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return scaled_dot_product_attention_forward(self, x, context, mask)
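For reference, a small self-contained check (not part of this diff; assumes torch 2.x) that scaled_dot_product_attention with no mask matches the explicit softmax attention the hijack replaces:
import torch
q, k, v = (torch.randn(1, 2, 16, 8) for _ in range(3))  # (batch, heads, seq_len, head_dim)
ref = torch.softmax(q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5), dim=-1) @ v
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(ref, out, atol=1e-5)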
def cross_attention_attnblock_forward(self, x): def cross_attention_attnblock_forward(self, x):
h_ = x h_ = x
h_ = self.norm(h_) h_ = self.norm(h_)
...@@ -427,6 +473,30 @@ def xformers_attnblock_forward(self, x): ...@@ -427,6 +473,30 @@ def xformers_attnblock_forward(self, x):
except NotImplementedError: except NotImplementedError:
return cross_attention_attnblock_forward(self, x) return cross_attention_attnblock_forward(self, x)
def sdp_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
out = out.to(dtype)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out
def sdp_no_mem_attnblock_forward(self, x):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return sdp_attnblock_forward(self, x)
def sub_quad_attnblock_forward(self, x): def sub_quad_attnblock_forward(self, x):
h_ = x h_ = x
h_ = self.norm(h_) h_ = self.norm(h_)
......
...@@ -67,7 +67,7 @@ def hijack_ddpm_edit(): ...@@ -67,7 +67,7 @@ def hijack_ddpm_edit():
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.1"): if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
......
...@@ -105,9 +105,15 @@ def checkpoint_tiles(): ...@@ -105,9 +105,15 @@ def checkpoint_tiles():
def list_models(): def list_models():
checkpoints_list.clear() checkpoints_list.clear()
checkpoint_alisases.clear() checkpoint_alisases.clear()
model_list = modelloader.load_models(model_path=model_path, model_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors", command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
cmd_ckpt = shared.cmd_opts.ckpt cmd_ckpt = shared.cmd_opts.ckpt
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
model_url = None
else:
model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
if os.path.exists(cmd_ckpt): if os.path.exists(cmd_ckpt):
checkpoint_info = CheckpointInfo(cmd_ckpt) checkpoint_info = CheckpointInfo(cmd_ckpt)
checkpoint_info.register() checkpoint_info.register()
...@@ -172,7 +178,7 @@ def select_checkpoint(): ...@@ -172,7 +178,7 @@ def select_checkpoint():
return checkpoint_info return checkpoint_info
chckpoint_dict_replacements = { checkpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
...@@ -180,7 +186,7 @@ chckpoint_dict_replacements = { ...@@ -180,7 +186,7 @@ chckpoint_dict_replacements = {
def transform_checkpoint_dict_key(k): def transform_checkpoint_dict_key(k):
for text, replacement in chckpoint_dict_replacements.items(): for text, replacement in checkpoint_dict_replacements.items():
if k.startswith(text): if k.startswith(text):
k = replacement + k[len(text):] k = replacement + k[len(text):]
...@@ -204,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd): ...@@ -204,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd return pl_sd
def read_metadata_from_safetensors(filename):
import json
with open(filename, mode="rb") as file:
metadata_len = file.read(8)
metadata_len = int.from_bytes(metadata_len, "little")
json_start = file.read(2)
assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
json_data = json_start + file.read(metadata_len-2)
json_obj = json.loads(json_data)
res = {}
for k, v in json_obj.get("__metadata__", {}).items():
res[k] = v
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
except Exception as e:
pass
return res
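For context, the safetensors layout this relies on: the file starts with a little-endian uint64 giving the length of the JSON header that immediately follows, and user metadata lives under that header's "__metadata__" key. A hedged usage sketch (the path is hypothetical):
metadata = sd_models.read_metadata_from_safetensors("models/Lora/example.safetensors")
print(metadata)  # training metadata embedded by whatever tool wrote the file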
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file) _, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":
...@@ -464,7 +494,7 @@ def reload_model_weights(sd_model=None, info=None): ...@@ -464,7 +494,7 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config: if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model del sd_model
checkpoints_loaded.clear() checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"]) load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return shared.sd_model return shared.sd_model
try: try:
...@@ -487,3 +517,23 @@ def reload_model_weights(sd_model=None, info=None): ...@@ -487,3 +517,23 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.") print(f"Weights loaded in {timer.summary()}.")
return sd_model return sd_model
def unload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
timer = Timer()
if shared.sd_model:
# shared.sd_model.cond_stage_model.to(devices.cpu)
# shared.sd_model.first_stage_model.to(devices.cpu)
shared.sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
shared.sd_model = None
sd_model = None
gc.collect()
devices.torch_gc()
torch.cuda.empty_cache()
print(f"Unloaded weights {timer.summary()}.")
return sd_model
\ No newline at end of file
...@@ -32,7 +32,7 @@ def set_samplers(): ...@@ -32,7 +32,7 @@ def set_samplers():
global samplers, samplers_for_img2img global samplers, samplers_for_img2img
hidden = set(shared.opts.hide_samplers) hidden = set(shared.opts.hide_samplers)
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS']) hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
samplers = [x for x in all_samplers if x.name not in hidden] samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
......
...@@ -7,19 +7,27 @@ import torch ...@@ -7,19 +7,27 @@ import torch
from modules.shared import state from modules.shared import state
from modules import sd_samplers_common, prompt_parser, shared from modules import sd_samplers_common, prompt_parser, shared
import modules.models.diffusion.uni_pc
samplers_data_compvis = [ samplers_data_compvis = [
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
] ]
class VanillaStableDiffusionSampler: class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model): def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model) self.sampler = constructor(sd_model)
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
self.is_plms = hasattr(self.sampler, 'p_sample_plms') self.is_plms = hasattr(self.sampler, 'p_sample_plms')
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
self.orig_p_sample_ddim = None
if self.is_plms:
self.orig_p_sample_ddim = self.sampler.p_sample_plms
elif self.is_ddim:
self.orig_p_sample_ddim = self.sampler.p_sample_ddim
self.mask = None self.mask = None
self.nmask = None self.nmask = None
self.init_latent = None self.init_latent = None
...@@ -45,6 +53,15 @@ class VanillaStableDiffusionSampler: ...@@ -45,6 +53,15 @@ class VanillaStableDiffusionSampler:
return self.last_latent return self.last_latent
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
return res
def before_sample(self, x, ts, cond, unconditional_conditioning):
if state.interrupted or state.skipped: if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException raise sd_samplers_common.InterruptedException
...@@ -76,7 +93,7 @@ class VanillaStableDiffusionSampler: ...@@ -76,7 +93,7 @@ class VanillaStableDiffusionSampler:
if self.mask is not None: if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts) img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec x = img_orig * self.mask + self.nmask * x
# Wrap the image conditioning back up since the DDIM code can accept the dict directly. # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later. # Note that they need to be lists because it just concatenates them later.
...@@ -84,12 +101,13 @@ class VanillaStableDiffusionSampler: ...@@ -84,12 +101,13 @@ class VanillaStableDiffusionSampler:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) return x, ts, cond, unconditional_conditioning
def update_step(self, last_latent):
if self.mask is not None: if self.mask is not None:
self.last_latent = self.init_latent * self.mask + self.nmask * res[1] self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
else: else:
self.last_latent = res[1] self.last_latent = last_latent
sd_samplers_common.store_latent(self.last_latent) sd_samplers_common.store_latent(self.last_latent)
...@@ -97,26 +115,51 @@ class VanillaStableDiffusionSampler: ...@@ -97,26 +115,51 @@ class VanillaStableDiffusionSampler:
state.sampling_step = self.step state.sampling_step = self.step
shared.total_tqdm.update() shared.total_tqdm.update()
return res def after_sample(self, x, ts, cond, uncond, res):
if not self.is_unipc:
self.update_step(res[1])
return x, ts, cond, uncond, res
def unipc_after_update(self, x, model_x):
self.update_step(x)
def initialize(self, p): def initialize(self, p):
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
if self.eta != 0.0: if self.eta != 0.0:
p.extra_generation_params["Eta DDIM"] = self.eta p.extra_generation_params["Eta DDIM"] = self.eta
if self.is_unipc:
keys = [
('UniPC variant', 'uni_pc_variant'),
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
]
for name, key in keys:
v = getattr(shared.opts, key)
if v != shared.opts.get_default(key):
p.extra_generation_params[name] = v
for fieldname in ['p_sample_ddim', 'p_sample_plms']: for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname): if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook) setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
if self.is_unipc:
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
self.mask = p.mask if hasattr(p, 'mask') else None self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps): def adjust_steps_if_invalid(self, p, num_steps):
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
num_steps = shared.opts.uni_pc_order
valid_step = 999 / (1000 // num_steps) valid_step = 999 / (1000 // num_steps)
if valid_step == math.floor(valid_step): if valid_step == math.floor(valid_step):
return int(valid_step) + 1 return int(valid_step) + 1
return num_steps return num_steps
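Worked example of the adjustment above: for num_steps=27, 1000 // 27 == 37 and 999 / 37 == 27.0 exactly, so the method returns 28; for num_steps=20, 999 / 50 == 19.98 is not an integer, so 20 is kept as-is.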
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
......
...@@ -101,11 +101,13 @@ class CFGDenoiser(torch.nn.Module): ...@@ -101,11 +101,13 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
cfg_denoiser_callback(denoiser_params) cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma sigma_in = denoiser_params.sigma
tensor = denoiser_params.text_cond
uncond = denoiser_params.text_uncond
if tensor.shape[1] == uncond.shape[1]: if tensor.shape[1] == uncond.shape[1]:
if not is_edit_model: if not is_edit_model:
......
...@@ -35,8 +35,11 @@ def model(): ...@@ -35,8 +35,11 @@ def model():
global sd_vae_approx_model global sd_vae_approx_model
if sd_vae_approx_model is None: if sd_vae_approx_model is None:
model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
sd_vae_approx_model = VAEApprox() sd_vae_approx_model = VAEApprox()
sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None)) if not os.path.exists(model_path):
model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
sd_vae_approx_model.eval() sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype) sd_vae_approx_model.to(devices.device, devices.dtype)
......
...@@ -115,7 +115,7 @@ class PersonalizedBase(Dataset): ...@@ -115,7 +115,7 @@ class PersonalizedBase(Dataset):
weight /= weight.mean() weight /= weight.mean()
elif use_weight: elif use_weight:
#If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
weight = torch.ones([channels] + latent_size) weight = torch.ones(latent_sample.shape)
else: else:
weight = None weight = None
......
...@@ -152,7 +152,11 @@ class EmbeddingDatabase: ...@@ -152,7 +152,11 @@ class EmbeddingDatabase:
name = data.get('name', name) name = data.get('name', name)
else: else:
data = extract_image_data_embed(embed_image) data = extract_image_data_embed(embed_image)
name = data.get('name', name) if data:
name = data.get('name', name)
else:
# if data is None, it means this is not an embedding, just a preview image
return
elif ext in ['.BIN', '.PT']: elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu") data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']: elif ext in ['.SAFETENSORS']:
......
...@@ -33,3 +33,6 @@ class Timer: ...@@ -33,3 +33,6 @@ class Timer:
res += ")" res += ")"
return res return res
def reset(self):
self.__init__()
...@@ -129,8 +129,8 @@ Requested path was: {f} ...@@ -129,8 +129,8 @@ Requested path was: {f}
generation_info = None generation_info = None
with gr.Column(): with gr.Column():
with gr.Row(elem_id=f"image_buttons_{tabname}"): with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
if tabname != "extras": if tabname != "extras":
save = gr.Button('Save', elem_id=f'save_{tabname}') save = gr.Button('Save', elem_id=f'save_{tabname}')
...@@ -160,6 +160,7 @@ Requested path was: {f} ...@@ -160,6 +160,7 @@ Requested path was: {f}
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }", _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
inputs=[generation_info, html_info, html_info], inputs=[generation_info, html_info, html_info],
outputs=[html_info, html_info], outputs=[html_info, html_info],
show_progress=False,
) )
save.click( save.click(
...@@ -198,9 +199,16 @@ Requested path was: {f} ...@@ -198,9 +199,16 @@ Requested path was: {f}
html_info = gr.HTML(elem_id=f'html_info_{tabname}') html_info = gr.HTML(elem_id=f'html_info_{tabname}')
html_log = gr.HTML(elem_id=f'html_log_{tabname}') html_log = gr.HTML(elem_id=f'html_log_{tabname}')
paste_field_names = []
if tabname == "txt2img":
paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
elif tabname == "img2img":
paste_field_names = modules.scripts.scripts_img2img.paste_field_names
for paste_tabname, paste_button in buttons.items(): for paste_tabname, paste_button in buttons.items():
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
paste_field_names=paste_field_names
)) ))
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
import gradio as gr import gradio as gr
class ToolButton(gr.Button, gr.components.FormComponent): class FormComponent:
"""Small button with single emoji as text, fits inside gradio forms""" def get_expected_parent(self):
return gr.components.Form
def __init__(self, **kwargs):
super().__init__(variant="tool", **kwargs)
def get_block_name(self): gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
return "button"
class ToolButtonTop(gr.Button, gr.components.FormComponent): class ToolButton(FormComponent, gr.Button):
"""Small button with single emoji as text, with extra margin at top, fits inside gradio forms""" """Small button with single emoji as text, fits inside gradio forms"""
def __init__(self, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(variant="tool-top", **kwargs) classes = kwargs.pop("elem_classes", [])
super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
def get_block_name(self): def get_block_name(self):
return "button" return "button"
class FormRow(gr.Row, gr.components.FormComponent): class FormRow(FormComponent, gr.Row):
"""Same as gr.Row but fits inside gradio forms""" """Same as gr.Row but fits inside gradio forms"""
def get_block_name(self): def get_block_name(self):
return "row" return "row"
class FormGroup(gr.Group, gr.components.FormComponent): class FormColumn(FormComponent, gr.Column):
"""Same as gr.Column but fits inside gradio forms"""
def get_block_name(self):
return "column"
class FormGroup(FormComponent, gr.Group):
"""Same as gr.Row but fits inside gradio forms""" """Same as gr.Row but fits inside gradio forms"""
def get_block_name(self): def get_block_name(self):
return "group" return "group"
class FormHTML(gr.HTML, gr.components.FormComponent): class FormHTML(FormComponent, gr.HTML):
"""Same as gr.HTML but fits inside gradio forms""" """Same as gr.HTML but fits inside gradio forms"""
def get_block_name(self): def get_block_name(self):
return "html" return "html"
class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): class FormColorPicker(FormComponent, gr.ColorPicker):
"""Same as gr.ColorPicker but fits inside gradio forms""" """Same as gr.ColorPicker but fits inside gradio forms"""
def get_block_name(self): def get_block_name(self):
return "colorpicker" return "colorpicker"
class DropdownMulti(gr.Dropdown): class DropdownMulti(FormComponent, gr.Dropdown):
"""Same as gr.Dropdown but always multiselect""" """Same as gr.Dropdown but always multiselect"""
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(multiselect=True, **kwargs) super().__init__(multiselect=True, **kwargs)
......
import json import json
import os.path import os.path
import shutil
import sys import sys
import time import time
import traceback import traceback
...@@ -141,22 +140,20 @@ def install_extension_from_url(dirname, url): ...@@ -141,22 +140,20 @@ def install_extension_from_url(dirname, url):
try: try:
shutil.rmtree(tmpdir, True) shutil.rmtree(tmpdir, True)
with git.Repo.clone_from(url, tmpdir) as repo:
repo = git.Repo.clone_from(url, tmpdir) repo.remote().fetch()
repo.remote().fetch() for submodule in repo.submodules:
submodule.update()
try: try:
os.rename(tmpdir, target_dir) os.rename(tmpdir, target_dir)
except OSError as err: except OSError as err:
# TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
# Shouldn't cause any new issues at least but we probably want to handle it there too.
if err.errno == errno.EXDEV: if err.errno == errno.EXDEV:
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
# Since we can't use a rename, do the slower but more versatile shutil.move() # Since we can't use a rename, do the slower but more versatile shutil.move()
shutil.move(tmpdir, target_dir) shutil.move(tmpdir, target_dir)
else: else:
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled. # Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
raise(err) raise err
import launch import launch
launch.run_extension_installer(target_dir) launch.run_extension_installer(target_dir)
...@@ -244,7 +241,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column): ...@@ -244,7 +241,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column):
hidden += 1 hidden += 1
continue continue
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">""" install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"Install" if not existing else "Installed"}</button>"""
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags]) tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
...@@ -304,7 +301,7 @@ def create_ui(): ...@@ -304,7 +301,7 @@ def create_ui():
with gr.TabItem("Available"): with gr.TabItem("Available"):
with gr.Row(): with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary") refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False) available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
......
...@@ -22,21 +22,37 @@ def register_page(page): ...@@ -22,21 +22,37 @@ def register_page(page):
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], []))) allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
def add_pages_to_demo(app): def fetch_file(filename: str = ""):
def fetch_file(filename: str = ""): from starlette.responses import FileResponse
from starlette.responses import FileResponse
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
if ext not in (".png", ".jpg", ".webp"):
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]): def get_metadata(page: str = "", item: str = ""):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") from starlette.responses import JSONResponse
ext = os.path.splitext(filename)[1].lower() page = next(iter([x for x in extra_pages if x.name == page]), None)
if ext not in (".png", ".jpg"): if page is None:
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.") return JSONResponse({})
# would profit from returning 304 metadata = page.metadata.get(item)
return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) if metadata is None:
return JSONResponse({})
return JSONResponse({"metadata": metadata})
def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
class ExtraNetworksPage: class ExtraNetworksPage:
...@@ -45,6 +61,7 @@ class ExtraNetworksPage: ...@@ -45,6 +61,7 @@ class ExtraNetworksPage:
self.name = title.lower() self.name = title.lower()
self.card_page = shared.html("extra-networks-card.html") self.card_page = shared.html("extra-networks-card.html")
self.allow_negative_prompt = False self.allow_negative_prompt = False
self.metadata = {}
def refresh(self): def refresh(self):
pass pass
...@@ -66,6 +83,8 @@ class ExtraNetworksPage: ...@@ -66,6 +83,8 @@ class ExtraNetworksPage:
view = shared.opts.extra_networks_default_view view = shared.opts.extra_networks_default_view
items_html = '' items_html = ''
self.metadata = {}
subdirs = {} subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True): for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
...@@ -86,12 +105,16 @@ class ExtraNetworksPage: ...@@ -86,12 +105,16 @@ class ExtraNetworksPage:
subdirs = {"": 1, **subdirs} subdirs = {"": 1, **subdirs}
subdirs_html = "".join([f""" subdirs_html = "".join([f"""
<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'> <button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
{html.escape(subdir if subdir!="" else "all")} {html.escape(subdir if subdir!="" else "all")}
</button> </button>
""" for subdir in subdirs]) """ for subdir in subdirs])
for item in self.list_items(): for item in self.list_items():
metadata = item.get("metadata")
if metadata:
self.metadata[item["name"]] = metadata
items_html += self.create_html_for_item(item, tabname) items_html += self.create_html_for_item(item, tabname)
if items_html == '': if items_html == '':
...@@ -124,9 +147,13 @@ class ExtraNetworksPage: ...@@ -124,9 +147,13 @@ class ExtraNetworksPage:
if onclick is None: if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
height = f"height: {shared.opts.extra_networks_card_height}em;" if shared.opts.extra_networks_card_height else '' height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
width = f"width: {shared.opts.extra_networks_card_width}em;" if shared.opts.extra_networks_card_width else '' width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else '' background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
metadata_button = ""
metadata = item.get("metadata")
if metadata:
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata({json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
args = { args = {
"style": f"'{height}{width}{background_image}'", "style": f"'{height}{width}{background_image}'",
...@@ -134,13 +161,44 @@ class ExtraNetworksPage: ...@@ -134,13 +161,44 @@ class ExtraNetworksPage:
"tabname": json.dumps(tabname), "tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]), "local_preview": json.dumps(item["local_preview"]),
"name": item["name"], "name": item["name"],
"description": (item.get("description") or ""),
"card_clicked": onclick, "card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""), "search_term": item.get("search_term", ""),
"metadata_button": metadata_button,
} }
return self.card_page.format(**args) return self.card_page.format(**args)
def find_preview(self, path):
"""
Find a preview image for a given path (without extension) and call link_preview on it.
"""
preview_extensions = ["png", "jpg", "webp"]
if shared.opts.samples_format not in preview_extensions:
preview_extensions.append(shared.opts.samples_format)
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
for file in potential_files:
if os.path.isfile(file):
return self.link_preview(file)
return None
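Illustrative behavior (path hypothetical): for models/Lora/foo this checks foo.png, foo.preview.png, foo.jpg, foo.preview.jpg, foo.webp, foo.preview.webp (plus the configured samples format, if different) and returns a preview link for the first file that exists, or None.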
def find_description(self, path):
"""
Find and read a description file for a given path (without extension).
"""
for file in [f"{path}.txt", f"{path}.description.txt"]:
try:
with open(file, "r", encoding="utf-8", errors="replace") as f:
return f.read()
except OSError:
pass
return None
def intialize(): def intialize():
extra_pages.clear() extra_pages.clear()
...@@ -182,12 +240,12 @@ def create_ui(container, button, tabname): ...@@ -182,12 +240,12 @@ def create_ui(container, button, tabname):
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
for page in ui.stored_extra_pages: for page in ui.stored_extra_pages:
with gr.Tab(page.title): with gr.Tab(page.title):
page_elem = gr.HTML(page.create_html(ui.tabname)) page_elem = gr.HTML(page.create_html(ui.tabname))
ui.pages.append(page_elem) ui.pages.append(page_elem)
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
...@@ -198,7 +256,6 @@ def create_ui(container, button, tabname): ...@@ -198,7 +256,6 @@ def create_ui(container, button, tabname):
state_visible = gr.State(value=False) state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
def refresh(): def refresh():
res = [] res = []
......
 import html
 import json
 import os
-import urllib.parse

 from modules import shared, ui_extra_networks, sd_models
@@ -17,21 +16,14 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
         checkpoint: sd_models.CheckpointInfo
         for name, checkpoint in sd_models.checkpoints_list.items():
             path, ext = os.path.splitext(checkpoint.filename)
-            previews = [path + ".png", path + ".preview.png"]
-
-            preview = None
-            for file in previews:
-                if os.path.isfile(file):
-                    preview = self.link_preview(file)
-                    break
-
             yield {
                 "name": checkpoint.name_for_extra,
                 "filename": path,
-                "preview": preview,
+                "preview": self.find_preview(path),
+                "description": self.find_description(path),
                 "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
                 "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
-                "local_preview": path + ".png",
+                "local_preview": f"{path}.{shared.opts.samples_format}",
             }

     def allowed_directories_for_previews(self):
...
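With `find_preview`/`find_description` on the base class, a page only has to yield item dicts. A rough sketch of what a third-party page might look like after this refactor, assuming the webui environment is importable (the `ExtraNetworksPageMyThings` class, the `my_things` registry, and its directory are all hypothetical):

```python
import json
import os

from modules import ui_extra_networks

# Hypothetical registry mapping a display name to a file on disk.
my_things = {"example": "models/things/example.safetensors"}


class ExtraNetworksPageMyThings(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('My Things')

    def list_items(self):
        for name, filename in my_things.items():
            path, ext = os.path.splitext(filename)
            yield {
                "name": name,
                "filename": path,
                "preview": self.find_preview(path),          # shared lookup across png/jpg/webp
                "description": self.find_description(path),  # "<path>.txt" / "<path>.description.txt"
                "search_term": self.search_terms_from_path(filename),
                "prompt": json.dumps(name),
                "local_preview": f"{path}.preview.png",
            }

    def allowed_directories_for_previews(self):
        return ["models/things"]
```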
@@ -14,21 +14,15 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
     def list_items(self):
         for name, path in shared.hypernetworks.items():
             path, ext = os.path.splitext(path)
-            previews = [path + ".png", path + ".preview.png"]
-
-            preview = None
-            for file in previews:
-                if os.path.isfile(file):
-                    preview = self.link_preview(file)
-                    break
-
             yield {
                 "name": name,
                 "filename": path,
-                "preview": preview,
+                "preview": self.find_preview(path),
+                "description": self.find_description(path),
                 "search_term": self.search_terms_from_path(path),
                 "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
-                "local_preview": path + ".png",
+                "local_preview": f"{path}.preview.{shared.opts.samples_format}",
             }

     def allowed_directories_for_previews(self):
...
 import json
 import os

-from modules import ui_extra_networks, sd_hijack
+from modules import ui_extra_networks, sd_hijack, shared


 class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
@@ -15,19 +15,14 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
     def list_items(self):
         for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
             path, ext = os.path.splitext(embedding.filename)
-            preview_file = path + ".preview.png"
-
-            preview = None
-            if os.path.isfile(preview_file):
-                preview = self.link_preview(preview_file)
-
             yield {
                 "name": embedding.name,
                 "filename": embedding.filename,
-                "preview": preview,
+                "preview": self.find_preview(path),
+                "description": self.find_description(path),
                 "search_term": self.search_terms_from_path(embedding.filename),
                 "prompt": json.dumps(embedding.name),
-                "local_preview": path + ".preview.png",
+                "local_preview": f"{path}.preview.{shared.opts.samples_format}",
             }

     def allowed_directories_for_previews(self):
...
@@ -3,7 +3,7 @@ transformers==4.25.1
 accelerate==0.12.0
 basicsr==1.4.2
 gfpgan==1.3.8
-gradio==3.16.2
+gradio==3.23
 numpy==1.23.3
 Pillow==9.4.0
 realesrgan==0.3.0
@@ -23,8 +23,8 @@ torchdiffeq==0.2.3
 kornia==0.6.7
 lark==1.1.2
 inflection==0.5.1
-GitPython==3.1.27
+GitPython==3.1.30
 torchsde==0.2.5
 safetensors==0.2.7
 httpcore<=0.15
-fastapi==0.90.1
+fastapi==0.94.0
 function gradioApp() {
     const elems = document.getElementsByTagName('gradio-app')
-    const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
-    return !!gradioShadowRoot ? gradioShadowRoot : document;
+    const elem = elems.length == 0 ? document : elems[0]
+
+    elem.getElementById = function(id){ return document.getElementById(id) }
+    return elem.shadowRoot ? elem.shadowRoot : elem
 }

 function get_uiCurrentTab() {
...
-import numpy as np
-from tqdm import trange
+import math

-import modules.scripts as scripts
 import gradio as gr
+import modules.scripts as scripts

-from modules import processing, shared, sd_samplers, images
+from modules import deepbooru, images, processing, shared
 from modules.processing import Processed
-from modules.sd_samplers import samplers
-from modules.shared import opts, cmd_opts, state


 class Script(scripts.Script):
@@ -19,37 +16,68 @@ class Script(scripts.Script):
     def ui(self, is_img2img):
         loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
-        denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor"))
+        final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength"))
+        denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear")
+        append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None")

-        return [loops, denoising_strength_change_factor]
+        return [loops, final_denoising_strength, denoising_curve, append_interrogation]

-    def run(self, p, loops, denoising_strength_change_factor):
+    def run(self, p, loops, final_denoising_strength, denoising_curve, append_interrogation):
         processing.fix_seed(p)
         batch_count = p.n_iter
         p.extra_generation_params = {
-            "Denoising strength change factor": denoising_strength_change_factor,
+            "Final denoising strength": final_denoising_strength,
+            "Denoising curve": denoising_curve
         }

         p.batch_size = 1
         p.n_iter = 1

-        output_images, info = None, None
+        info = None
         initial_seed = None
         initial_info = None
+        initial_denoising_strength = p.denoising_strength

         grids = []
         all_images = []
         original_init_image = p.init_images
+        original_prompt = p.prompt
+        original_inpainting_fill = p.inpainting_fill
         state.job_count = loops * batch_count

         initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]

-        for n in range(batch_count):
-            history = []
+        def calculate_denoising_strength(loop):
+            strength = initial_denoising_strength
+
+            if loops == 1:
+                return strength
+
+            progress = loop / (loops - 1)
+            match denoising_curve:
+                case "Aggressive":
+                    strength = math.sin((progress) * math.pi * 0.5)
+
+                case "Lazy":
+                    strength = 1 - math.cos((progress) * math.pi * 0.5)
+
+                case _:
+                    strength = progress
+
+            change = (final_denoising_strength - initial_denoising_strength) * strength
+            return initial_denoising_strength + change
+
+        history = []
+
+        for n in range(batch_count):
             # Reset to original init image at the start of each batch
             p.init_images = original_init_image

+            # Reset to original denoising strength
+            p.denoising_strength = initial_denoising_strength
+
+            last_image = None
+
             for i in range(loops):
                 p.n_iter = 1
                 p.batch_size = 1
@@ -58,30 +86,57 @@ class Script(scripts.Script):
                 if opts.img2img_color_correction:
                     p.color_corrections = initial_color_corrections

+                if append_interrogation != "None":
+                    p.prompt = original_prompt + ", " if original_prompt != "" else ""
+                    if append_interrogation == "CLIP":
+                        p.prompt += shared.interrogator.interrogate(p.init_images[0])
+                    elif append_interrogation == "DeepBooru":
+                        p.prompt += deepbooru.model.tag(p.init_images[0])
+
                 state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}"

                 processed = processing.process_images(p)

+                # Generation cancelled.
+                if state.interrupted:
+                    break
+
                 if initial_seed is None:
                     initial_seed = processed.seed
                     initial_info = processed.info

-                init_img = processed.images[0]
-
-                p.init_images = [init_img]
                 p.seed = processed.seed + 1
-                p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1)
-                history.append(processed.images[0])
+                p.denoising_strength = calculate_denoising_strength(i + 1)
+
+                if state.skipped:
+                    break
+
+                last_image = processed.images[0]
+                p.init_images = [last_image]
+                p.inpainting_fill = 1 # Set "masked content" to "original" for next loop.
+
+                if batch_count == 1:
+                    history.append(last_image)
+                    all_images.append(last_image)
+
+            if batch_count > 1 and not state.skipped and not state.interrupted:
+                history.append(last_image)
+                all_images.append(last_image)
+
+            p.inpainting_fill = original_inpainting_fill
+
+            if state.interrupted:
+                break

+        if len(history) > 1:
             grid = images.image_grid(history, rows=1)
             if opts.grid_save:
                 images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p)

-            grids.append(grid)
-            all_images += history
-
-        if opts.return_grid:
-            all_images = grids + all_images
+            if opts.return_grid:
+                grids.append(grid)
+                all_images = grids + all_images

         processed = Processed(p, all_images, initial_seed, initial_info)
...
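For reference, the new curve logic interpolates from the run's initial denoising strength to the chosen final one, with the curve shaping how fast it moves. It can be exercised on its own; a minimal sketch (function name and sample values invented; the `match` statement requires Python 3.10+):

```python
import math

def denoising_strength(loop, loops, initial, final, curve="Linear"):
    # Mirrors calculate_denoising_strength: map the loop index to [0, 1],
    # shape it with the chosen curve, then interpolate initial -> final.
    if loops == 1:
        return initial
    progress = loop / (loops - 1)
    match curve:
        case "Aggressive":
            shaped = math.sin(progress * math.pi * 0.5)      # moves fast early
        case "Lazy":
            shaped = 1 - math.cos(progress * math.pi * 0.5)  # moves slow early
        case _:
            shaped = progress                                # "Linear"
    return initial + (final - initial) * shaped

# Illustrative values: 4 loops, strength going from 0.75 down to 0.5.
for curve in ("Aggressive", "Linear", "Lazy"):
    values = [round(denoising_strength(loop, 4, 0.75, 0.5, curve), 3) for loop in range(4)]
    print(curve, values)
```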
@@ -17,22 +17,24 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
     def ui(self):
         selected_tab = gr.State(value=0)

-        with gr.Tabs(elem_id="extras_resize_mode"):
-            with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
-                upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
-
-            with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
-                with FormRow():
-                    upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
-                    upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
-                    upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
-
-        with FormRow():
-            extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
-
-        with FormRow():
-            extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
-            extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
+        with gr.Column():
+            with FormRow():
+                with gr.Tabs(elem_id="extras_resize_mode"):
+                    with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
+                        upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
+
+                    with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
+                        with FormRow():
+                            upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
+                            upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
+                            upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
+
+            with FormRow():
+                extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+
+            with FormRow():
+                extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+                extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")

         tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
         tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
...
@@ -100,7 +100,7 @@ class Script(scripts.Script):
         processed = process_images(p)

         grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
-        grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[1].height, prompt_matrix_parts, margin_size)
+        grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)
         processed.images.insert(0, grid)
         processed.index_of_first_image = 1
         processed.infotexts.insert(0, processed.infotexts[0])
...
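As context for the fixed line, the `rows=1 << ((len(prompt_matrix_parts) - 1) // 2)` expression on the line above it keeps the matrix grid roughly square; a quick illustration (assuming the matrix's usual 2**(n-1) on/off combinations of the optional prompt parts):

```python
# Illustrative only: n prompt parts = base prompt + (n - 1) optional parts,
# giving 2**(n-1) image combinations; rows = 1 << ((n - 1) // 2) keeps the
# resulting grid close to square (cols = images // rows).
for n in range(1, 8):
    images_count = 2 ** (n - 1)
    rows = 1 << ((n - 1) // 2)
    print(f"parts={n}  images={images_count}  grid={images_count // rows}x{rows}")
```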
+import os
 import unittest
 import requests
 from gradio.processing_utils import encode_pil_to_base64
 from PIL import Image
+from modules.paths import script_path


 class TestExtrasWorking(unittest.TestCase):
     def setUp(self):
@@ -19,7 +21,7 @@ class TestExtrasWorking(unittest.TestCase):
             "upscaler_1": "None",
             "upscaler_2": "None",
             "extras_upscaler_2_visibility": 0,
-            "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
+            "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
         }

     def test_simple_upscaling_performed(self):
@@ -31,7 +33,7 @@ class TestPngInfoWorking(unittest.TestCase):
     def setUp(self):
         self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image"
         self.png_info = {
-            "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
+            "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
         }

     def test_png_info_performed(self):
@@ -42,7 +44,7 @@ class TestInterrogateWorking(unittest.TestCase):
     def setUp(self):
         self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image"
         self.interrogate = {
-            "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")),
+            "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))),
             "model": "clip"
         }
...
+import os
 import unittest
 import requests
 from gradio.processing_utils import encode_pil_to_base64
 from PIL import Image
+from modules.paths import script_path


 class TestImg2ImgWorking(unittest.TestCase):
     def setUp(self):
         self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
         self.simple_img2img = {
-            "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))],
+            "init_images": [encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))],
             "resize_mode": 0,
             "denoising_strength": 0.75,
             "mask": None,
@@ -47,11 +49,11 @@ class TestImg2ImgWorking(unittest.TestCase):
         self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)

     def test_inpainting_masked_performed(self):
-        self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
+        self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
         self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)

     def test_inpainting_with_inverted_masked_performed(self):
-        self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
+        self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
         self.simple_img2img["inpainting_mask_invert"] = True
         self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
...
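The test changes above all apply the same pattern: anchor fixture paths to `script_path` (the webui root, per `modules.paths`) so the suite works from any working directory. In isolation the idea looks like this (the `script_path` stand-in below is an assumption for the sketch):

```python
import os

# Stand-in for modules.paths.script_path, which points at the webui root;
# here we assume this file itself lives in that root.
script_path = os.path.dirname(os.path.abspath(__file__))

asset = os.path.join(script_path, "test", "test_files", "img2img_basic.png")
print(asset)  # same absolute path regardless of the current working directory
```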
@@ -66,6 +66,8 @@ class TestTxt2ImgWorking(unittest.TestCase):
         self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
         self.simple_txt2img["sampler_index"] = "DDIM"
         self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+        self.simple_txt2img["sampler_index"] = "UniPC"
+        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

     def test_txt2img_multiple_batches_performed(self):
         self.simple_txt2img["n_iter"] = 2
...
 import unittest
 import requests
 import time
+import os
+from modules.paths import script_path


 def run_tests(proc, test_dir):
@@ -15,8 +17,8 @@ def run_tests(proc, test_dir):
             break
     if proc.poll() is None:
         if test_dir is None:
-            test_dir = "test"
-        suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
+            test_dir = os.path.join(script_path, "test")
+        suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir=test_dir)
         result = unittest.TextTestRunner(verbosity=2).run(suite)
         return len(result.failures) + len(result.errors)
     else:
...
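Passing `top_level_dir=test_dir` (instead of the hardcoded `"test"`) keeps `unittest` discovery consistent with the now-absolute `test_dir`, so test modules import correctly no matter where the process was started. A standalone sketch, under the assumption that the tests live in a `test/` directory next to this file:

```python
import os
import unittest

# Assumption for the sketch: the suite lives in <root>/test next to this file.
test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test")

# top_level_dir=test_dir makes discovery treat the test directory itself as
# the import root, rather than the current working directory.
suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir=test_dir)
result = unittest.TextTestRunner(verbosity=2).run(suite)
print(len(result.failures) + len(result.errors))
```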