Commit ee3d63b6 authored by InvincibleDude, committed by GitHub

Merge branch 'master' into master

parents 44c0e6b9 00dab8f1
@@ -37,20 +37,20 @@ body:
     id: what-should
     attributes:
       label: What should have happened?
-      description: tell what you think the normal behavior should be
+      description: Tell what you think the normal behavior should be
     validations:
       required: true
   - type: input
     id: commit
     attributes:
       label: Commit where the problem happens
-      description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+      description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
     validations:
       required: true
   - type: dropdown
     id: platforms
     attributes:
-      label: What platforms do you use to access UI ?
+      label: What platforms do you use to access the UI ?
       multiple: true
       options:
         - Windows
@@ -74,10 +74,27 @@ body:
     id: cmdargs
     attributes:
       label: Command Line Arguments
-      description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
+      description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
       render: Shell
+    validations:
+      required: true
+  - type: textarea
+    id: extensions
+    attributes:
+      label: List of extensions
+      description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise.
+    validations:
+      required: true
+  - type: textarea
+    id: logs
+    attributes:
+      label: Console logs
+      description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
+      render: Shell
+    validations:
+      required: true
   - type: textarea
     id: misc
     attributes:
-      label: Additional information, context and logs
-      description: Please provide us with any relevant additional info, context or log output.
+      label: Additional information
+      description: Please provide us with any relevant additional info or context.
@@ -17,7 +17,7 @@ A browser interface based on Gradio library for Stable Diffusion.
 - a man in a (tuxedo:1.21) - alternative syntax
 - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
 - Loopback, run img2img processing multiple times
-- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
+- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
 - Textual Inversion
 - have as many embeddings as you want and use any names you like for them
 - use multiple embeddings with different numbers of vectors per token
@@ -155,6 +155,8 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
 - xformers - https://github.com/facebookresearch/xformers
 - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
+- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
+- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
 - Security advice - RyotaK
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
 - (You)
+# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
+# See more details in LICENSE.
+model:
+  base_learning_rate: 1.0e-04
+  target: modules.models.diffusion.ddpm_edit.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: edited
+    cond_stage_key: edit
+    # image_size: 64
+    # image_size: 32
+    image_size: 16
+    channels: 4
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: hybrid
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: true
+    load_ema: true
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 0 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 128
+    num_workers: 1
+    wrap: false
+    validation:
+      target: edit_dataset.EditDataset
+      params:
+        path: data/clip-filtered-dataset
+        cache_dir: data/
+        cache_name: data_10k
+        split: val
+        min_text_sim: 0.2
+        min_image_sim: 0.75
+        min_direction_sim: 0.2
+        max_samples_per_prompt: 1
+        min_resize_res: 512
+        max_resize_res: 512
+        crop_res: 512
+        output_as_edit: False
+        real_input: True
 model:
-  base_learning_rate: 1.0e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  base_learning_rate: 7.5e-05
+  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
   params:
-    parameterization: "v"
     linear_start: 0.00085
     linear_end: 0.0120
     num_timesteps_cond: 1
@@ -12,29 +11,36 @@ model:
     cond_stage_key: "txt"
     image_size: 64
     channels: 4
-    cond_stage_trainable: false
-    conditioning_key: crossattn
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: hybrid # important
     monitor: val/loss_simple_ema
     scale_factor: 0.18215
-    use_ema: False # we set this to false because this is an inference only config
+    finetune_keys: null
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
     unet_config:
       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
-        use_checkpoint: True
-        use_fp16: True
         image_size: 32 # unused
-        in_channels: 4
+        in_channels: 9  # 4 data + 4 downscaled image + 1 mask
         out_channels: 4
         model_channels: 320
         attention_resolutions: [ 4, 2, 1 ]
         num_res_blocks: 2
         channel_mult: [ 1, 2, 4, 4 ]
-        num_head_channels: 64 # need to fix for flash-attn
+        num_heads: 8
         use_spatial_transformer: True
-        use_linear_in_transformer: True
         transformer_depth: 1
-        context_dim: 1024
+        context_dim: 768
+        use_checkpoint: True
         legacy: False
     first_stage_config:
@@ -43,7 +49,6 @@ model:
       embed_dim: 4
       monitor: val/rec_loss
       ddconfig:
-        #attn_type: "vanilla-xformers"
         double_z: true
         z_channels: 4
         resolution: 256
@@ -62,7 +67,4 @@ model:
           target: torch.nn.Identity
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-      params:
-        freeze: True
-        layer: "penultimate"
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
\ No newline at end of file
-from modules import extra_networks
+from modules import extra_networks, shared
 import lora

 class ExtraNetworkLora(extra_networks.ExtraNetwork):
@@ -6,6 +6,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
         super().__init__('lora')

     def activate(self, p, params_list):
+        additional = shared.opts.sd_lora
+
+        if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+            p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
+            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
+
         names = []
         multipliers = []
         for params in params_list:
...
@@ -166,7 +166,10 @@ def lora_forward(module, input, res):
     for lora in loaded_loras:
         module = lora.modules.get(lora_layer_name, None)
         if module is not None:
-            res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
+                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+            else:
+                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

     return res
...
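For reference, the hunk above only changes whether the low-rank delta is computed from the layer's input or its output. A minimal, self-contained sketch of the underlying LoRA arithmetic (standalone PyTorch; the class and variable names here are illustrative, not the extension's actual API):

```python
import torch
import torch.nn as nn

class LoraLinear(nn.Module):
    """Frozen base layer plus a low-rank update: y = base(x) + up(down(x)) * multiplier * (alpha / rank)."""
    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 4.0, multiplier: float = 1.0):
        super().__init__()
        self.base = base  # pretrained layer, left untouched
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # zero-init so the update starts as a no-op
        self.alpha = alpha
        self.multiplier = multiplier

    def forward(self, x):
        res = self.base(x)
        # alpha / rank mirrors module.alpha / module.up.weight.shape[1] in the diff
        scale = self.alpha / self.up.weight.shape[1]
        return res + self.up(self.down(x)) * self.multiplier * scale

layer = LoraLinear(nn.Linear(16, 16))
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```

The new `lora_apply_to_outputs` option simply feeds `res` instead of the input into `down()`, which is only shape-safe when input and output have the same dimensions; hence the `res.shape == input.shape` guard.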
 import torch
+import gradio as gr

 import lora
 import extra_networks_lora
 import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared

 def unload():
@@ -28,3 +29,10 @@ torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
 script_callbacks.on_before_ui(before_ui)
+
+shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
+    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
+    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
+}))
@@ -20,13 +20,14 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         preview = None
         for file in previews:
             if os.path.isfile(file):
-                preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+                preview = self.link_preview(file)
                 break

         yield {
             "name": name,
             "filename": path,
             "preview": preview,
+            "search_term": self.search_terms_from_path(lora_on_disk.filename),
             "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
             "local_preview": path + ".png",
         }
...
@@ -4,6 +4,7 @@
 <ul>
     <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
 </ul>
+<span style="display:none" class='search_term'>{search_term}</span>
 </div>
 <span class='name'>{name}</span>
 </div>
...
 function extensions_apply(_, _){
-    disable = []
-    update = []
+    var disable = []
+    var update = []
     gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
         if(x.name.startsWith("enable_") && ! x.checked)
             disable.push(x.name.substr(7))
@@ -16,11 +17,24 @@ function extensions_apply(_, _){
 }
 function extensions_check(){
+    var disable = []
+
+    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
+        if(x.name.startsWith("enable_") && ! x.checked)
+            disable.push(x.name.substr(7))
+    })
+
     gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
         x.innerHTML = "Loading..."
     })

-    return []
+    var id = randomId()
+    requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
+
+    })
+
+    return [id, JSON.stringify(disable)]
 }

 function install_extension_from_index(button, url){
...
@@ -16,7 +16,7 @@ function setupExtraNetworksForTab(tabname){
         searchTerm = search.value.toLowerCase()

         gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
-            text = elem.querySelector('.name').textContent.toLowerCase()
+            text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
             elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
         })
     });
@@ -48,10 +48,39 @@ function setupExtraNetworks(){
 onUiLoaded(setupExtraNetworks)

+var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
+var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
+
+function tryToRemoveExtraNetworkFromPrompt(textarea, text){
+    var m = text.match(re_extranet)
+    if(! m) return false
+
+    var partToSearch = m[1]
+    var replaced = false
+    var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
+        m = found.match(re_extranet);
+        if(m[1] == partToSearch){
+            replaced = true;
+            return ""
+        }
+        return found;
+    })
+
+    if(replaced){
+        textarea.value = newTextareaText
+        return true;
+    }
+
+    return false
+}
+
 function cardClicked(tabname, textToAdd, allowNegativePrompt){
     var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")

-    textarea.value = textarea.value + " " + textToAdd
+    if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
+        textarea.value = textarea.value + " " + textToAdd
+    }

     updateInput(textarea)
 }
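The new `tryToRemoveExtraNetworkFromPrompt` makes clicking a card toggle its `<lora:name:weight>` tag instead of always appending it. A rough Python translation of the same regex logic, handy for testing the behavior outside the browser (a hypothetical helper, not part of the codebase):

```python
import re

re_extranet = re.compile(r"<([^:]+:[^:]+):[\d.]+>")
re_extranet_g = re.compile(r"\s+<([^:]+:[^:]+):[\d.]+>")

def toggle_extra_network(prompt: str, text_to_add: str) -> str:
    m = re_extranet.search(text_to_add)
    if not m:
        return prompt + " " + text_to_add  # not an extra-network tag; just append

    part_to_search = m.group(1)  # e.g. "lora:filename", weight ignored
    replaced = []

    def drop_if_same(match):
        if re_extranet.search(match.group(0)).group(1) == part_to_search:
            replaced.append(True)
            return ""
        return match.group(0)

    new_prompt = re_extranet_g.sub(drop_if_same, prompt)
    return new_prompt if replaced else prompt + " " + text_to_add

print(toggle_extra_network("a cat <lora:foo:0.8>", "<lora:foo:1>"))  # -> "a cat"
print(toggle_extra_network("a cat", "<lora:foo:1>"))                 # -> "a cat <lora:foo:1>"
```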
@@ -67,3 +96,12 @@ function saveCardPreview(event, tabname, filename){
     event.stopPropagation()
     event.preventDefault()
 }
+
+function extraNetworksSearchButton(tabs_id, event){
+    searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
+    button = event.target
+    text = button.classList.contains("search-all") ? "" : button.textContent.trim()
+
+    searchTextarea.value = text
+    updateInput(searchTextarea)
+}
\ No newline at end of file
@@ -50,7 +50,7 @@ titles = {
     "None": "Do not do anything special",
     "Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)",
-    "X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
+    "X/Y/Z plot": "Create grid(s) where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows",
     "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
     "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
...
@@ -309,3 +309,10 @@ function updateInput(target){
     Object.defineProperty(e, "target", {value: target})
     target.dispatchEvent(e);
 }
+
+var desiredCheckpointName = null;
+
+function selectCheckpoint(name){
+    desiredCheckpointName = name;
+    gradioApp().getElementById('change_checkpoint').click()
+}
@@ -17,6 +17,37 @@ stored_commit_hash = None
 skip_install = False

+def check_python_version():
+    is_windows = platform.system() == "Windows"
+    major = sys.version_info.major
+    minor = sys.version_info.minor
+    micro = sys.version_info.micro
+
+    if is_windows:
+        supported_minors = [10]
+    else:
+        supported_minors = [7, 8, 9, 10, 11]
+
+    if not (major == 3 and minor in supported_minors):
+        import modules.errors
+
+        modules.errors.print_error_explanation(f"""
+INCOMPATIBLE PYTHON VERSION
+
+This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
+If you encounter an error with "RuntimeError: Couldn't install torch." message,
+or any other error regarding unsuccessful package (library) installation,
+please downgrade (or upgrade) to the latest version of 3.10 Python
+and delete current Python and "venv" folder in WebUI's directory.
+
+You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
+
+{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
+
+Use --skip-python-version-check to suppress this warning.
+""")
+
 def commit_hash():
     global stored_commit_hash
@@ -48,10 +79,19 @@ def extract_opt(args, name):
     return args, is_present, opt

-def run(command, desc=None, errdesc=None, custom_env=None):
+def run(command, desc=None, errdesc=None, custom_env=None, live=False):
     if desc is not None:
         print(desc)

+    if live:
+        result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
+        if result.returncode != 0:
+            raise RuntimeError(f"""{errdesc or 'Error running command'}.
+Command: {command}
+Error code: {result.returncode}""")
+
+        return ""
+
     result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)

     if result.returncode != 0:
@@ -108,18 +148,18 @@ def git_clone(url, dir, name, commithash=None):
         if commithash is None:
             return

-        current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
+        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
         if current_hash == commithash:
             return

-        run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
-        run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+        run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
+        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
         return

     run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")

     if commithash is not None:
-        run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
+        run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")

 def version_check(commit):
@@ -207,6 +247,7 @@ def prepare_environment():
     sys.argv, _ = extract_arg(sys.argv, '-f')
     sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
+    sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
     sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
     sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
     sys.argv, update_check = extract_arg(sys.argv, '--update-check')
@@ -215,13 +256,16 @@ def prepare_environment():
     xformers = '--xformers' in sys.argv
     ngrok = '--ngrok' in sys.argv

+    if not skip_python_version_check:
+        check_python_version()
+
     commit = commit_hash()

     print(f"Python {sys.version}")
     print(f"Commit hash: {commit}")

     if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
-        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
+        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)

     if not skip_torch_cuda_test:
         run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
...
@@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, find_checkpoint_config
+from modules.sd_models import checkpoints_list
+from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
 from typing import List
@@ -387,7 +388,7 @@ class Api:
         ]

     def get_sd_models(self):
-        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+        return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]

     def get_hypernetworks(self):
         return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
...
@@ -228,7 +228,7 @@ class SDModelItem(BaseModel):
     hash: Optional[str] = Field(title="Short hash")
     sha256: Optional[str] = Field(title="sha256 hash")
     filename: str = Field(title="Filename")
-    config: str = Field(title="Config file")
+    config: Optional[str] = Field(title="Config file")

 class HypernetworkItem(BaseModel):
     name: str = Field(title="Name")
...
@@ -8,7 +8,7 @@ import torch
 import modules.face_restoration
 import modules.shared
 from modules import shared, devices, modelloader
-from modules.paths import script_path, models_path
+from modules.paths import models_path

 # codeformer people made a choice to include modified basicsr library to their project which makes
 # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
...
@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from modules import devices
+
 # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
@@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module):
         t_358, = inputs
         t_359 = t_358.permute(*[0, 3, 1, 2])
         t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
-        t_360 = self.n_Conv_0(t_359_padded)
+        t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
         t_361 = F.relu(t_360)
         t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
         t_362 = self.n_MaxPool_0(t_361)
...
@@ -34,14 +34,18 @@ def get_cuda_device_string():
         return "cuda"

-def get_optimal_device():
+def get_optimal_device_name():
     if torch.cuda.is_available():
-        return torch.device(get_cuda_device_string())
+        return get_cuda_device_string()

     if has_mps():
-        return torch.device("mps")
+        return "mps"

-    return cpu
+    return "cpu"

+def get_optimal_device():
+    return torch.device(get_optimal_device_name())

 def get_device_for(task):
@@ -79,6 +83,16 @@ cpu = torch.device("cpu")
 device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
+dtype_unet = torch.float16
+unet_needs_upcast = False
+
+def cond_cast_unet(input):
+    return input.to(dtype_unet) if unet_needs_upcast else input
+
+def cond_cast_float(input):
+    return input.float() if unet_needs_upcast else input

 def randn(seed, shape):
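These helpers underpin the new upcast-sampling path: UNet weights stay in float16, while tensors are conditionally cast at the boundaries. A toy demonstration of the pattern (standalone; it reuses the helper names from the diff but not webui's actual module state, where `unet_needs_upcast` is derived from command-line options):

```python
import torch

dtype_unet = torch.float16
unet_needs_upcast = True  # assumption for this demo; webui computes this from settings

def cond_cast_unet(tensor):
    # cast inputs down to the UNet's storage dtype only when upcasting is active
    return tensor.to(dtype_unet) if unet_needs_upcast else tensor

def cond_cast_float(tensor):
    # cast conditioning up to float32 only when upcasting is active
    return tensor.float() if unet_needs_upcast else tensor

latent = torch.randn(1, 4, 8, 8)
print(cond_cast_unet(latent).dtype)          # torch.float16
print(cond_cast_float(latent.half()).dtype)  # torch.float32
```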
@@ -106,6 +120,10 @@ def autocast(disable=False):
     return torch.autocast("cuda")

+def without_autocast(disable=False):
+    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
+
 class NansException(Exception):
     pass
@@ -123,7 +141,7 @@ def test_for_nans(x, where):
         message = "A tensor with all NaNs was produced in Unet."

         if not shared.cmd_opts.no_half:
-            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
+            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."

     elif where == "vae":
         message = "A tensor with all NaNs was produced in VAE."
@@ -133,6 +151,8 @@ def test_for_nans(x, where):
     else:
         message = "A tensor with all NaNs was produced."

+    message += " Use --disable-nan-check commandline argument to disable this check."
+
     raise NansException(message)
@@ -187,6 +207,3 @@ if has_mps():
         cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
         torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
         torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
-        orig_narrow = torch.narrow
-        torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
@@ -7,9 +7,11 @@ import git
 from modules import paths, shared

 extensions = []
-extensions_dir = os.path.join(paths.script_path, "extensions")
+extensions_dir = os.path.join(paths.data_path, "extensions")
 extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")

+if not os.path.exists(extensions_dir):
+    os.makedirs(extensions_dir)
+
 def active():
     return [x for x in extensions if x.enabled]
...
-from modules import extra_networks
+from modules import extra_networks, shared, extra_networks
 from modules.hypernetworks import hypernetwork

@@ -7,6 +7,12 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
         super().__init__('hypernet')

     def activate(self, p, params_list):
+        additional = shared.opts.sd_hypernetwork
+
+        if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+            p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
+            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
+
         names = []
         multipliers = []
         for params in params_list:
...
@@ -6,7 +6,7 @@ import shutil
 import torch
 import tqdm

-from modules import shared, images, sd_models, sd_vae
+from modules import shared, images, sd_models, sd_vae, sd_models_config
 from modules.ui_common import plaintext_to_html
 import gradio as gr
 import safetensors.torch
@@ -37,7 +37,7 @@ def run_pnginfo(image):
 def create_config(ckpt_result, config_source, a, b, c):
     def config(x):
-        res = sd_models.find_checkpoint_config(x) if x else None
+        res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
         return res if res != shared.sd_default_config else None

     if config_source == 0:
@@ -132,6 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None

     result_is_inpainting_model = False
+    result_is_instruct_pix2pix_model = False

     if theta_func2:
         shared.state.textinfo = f"Loading B"
@@ -185,14 +186,19 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
             if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
                 if a.shape[1] == 4 and b.shape[1] == 9:
                     raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
+                if a.shape[1] == 4 and b.shape[1] == 8:
+                    raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")

-                assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
-                theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
-                result_is_inpainting_model = True
+                if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
+                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
+                    result_is_instruct_pix2pix_model = True
+                else:
+                    assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
+                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
+                    result_is_inpainting_model = True
             else:
                 theta_0[key] = theta_func2(a, b, multiplier)

             theta_0[key] = to_half(theta_0[key], save_as_half)

             shared.state.sampling_step += 1
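The branch above exists because the first UNet convolution takes a different number of input channels per model family: 4 for normal checkpoints, 9 for inpainting (4 latent + 4 masked-image latent + 1 mask), and 8 for instruct-pix2pix. Merging therefore blends only the overlapping first four channels and keeps model A's extra ones. A toy illustration, assuming `theta_func2` is a weighted sum (shapes are illustrative):

```python
import torch

def weighted_sum(a, b, multiplier):
    # blend two tensors; multiplier=0 keeps a, multiplier=1 keeps b
    return a * (1 - multiplier) + b * multiplier

a = torch.randn(320, 8, 3, 3)  # instruct-pix2pix first conv weight (8 input channels)
b = torch.randn(320, 4, 3, 3)  # normal model's first conv weight (4 input channels)

merged = a.clone()
merged[:, 0:4, :, :] = weighted_sum(a[:, 0:4, :, :], b, 0.5)  # blend only the shared channels
print(merged.shape)  # torch.Size([320, 8, 3, 3]) - extra channels kept from model A
```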
@@ -226,6 +232,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     filename = filename_generator() if custom_name == '' else custom_name
     filename += ".inpainting" if result_is_inpainting_model else ""
+    filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
     filename += "." + checkpoint_format

     output_modelname = os.path.join(ckpt_dir, filename)
...
@@ -6,14 +6,13 @@ import re
 from pathlib import Path

 import gradio as gr
-from modules.shared import script_path
+from modules.paths import data_path
 from modules import shared, ui_tempdir, script_callbacks
 import tempfile
 from PIL import Image

-re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
+re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
 re_param = re.compile(re_param_code)
-re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
 re_imagesize = re.compile(r"^(\d+)x(\d+)$")
 re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
 type_of_gr_update = type(gr.update())
@@ -243,7 +242,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     done_with_prompt = False

     *lines, lastline = x.strip().split("\n")
-    if not re_params.match(lastline):
+    if len(re_param.findall(lastline)) < 3:
         lines.append(lastline)
         lastline = ''
@@ -261,6 +260,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     res["Negative prompt"] = negative_prompt

     for k, v in re_param.findall(lastline):
+        v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
         m = re_imagesize.match(v)
         if m is not None:
             res[k+"-1"] = m.group(1)
@@ -293,7 +293,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
 def connect_paste(button, paste_fields, input_comp, jsfunc=None):
     def paste_func(prompt):
         if not prompt and not shared.cmd_opts.hide_ui_dir_config:
-            filename = os.path.join(script_path, "params.txt")
+            filename = os.path.join(data_path, "params.txt")
             if os.path.exists(filename):
                 with open(filename, "r", encoding="utf8") as file:
                     prompt = file.read()
...
@@ -6,12 +6,11 @@ import facexlib
 import gfpgan

 import modules.face_restoration
-from modules import shared, devices, modelloader
-from modules.paths import models_path
+from modules import paths, shared, devices, modelloader

 model_dir = "GFPGAN"
 user_path = None
-model_path = os.path.join(models_path, model_dir)
+model_path = os.path.join(paths.models_path, model_dir)
 model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
 have_gfpgan = False
 loaded_gfpgan_model = None
...
@@ -4,8 +4,10 @@ import os.path
 import filelock

+from modules.paths import data_path

-cache_filename = "cache.json"
+cache_filename = os.path.join(data_path, "cache.json")
 cache_data = None
...
@@ -36,6 +36,8 @@ def image_grid(imgs, batch_size=1, rows=None):
     else:
         rows = math.sqrt(len(imgs))
         rows = round(rows)
+    if rows > len(imgs):
+        rows = len(imgs)

     cols = math.ceil(len(imgs) / rows)
@@ -195,7 +197,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
     ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
                         ver_texts]

-    pad_top = max(hor_text_heights) + line_spacing * 2
+    pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2

     result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
     result.paste(im, (pad_left, pad_top))
...
@@ -16,11 +16,18 @@ import modules.images as images
 import modules.scripts

-def process_batch(p, input_dir, output_dir, args):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
     processing.fix_seed(p)

     images = shared.listfiles(input_dir)

+    is_inpaint_batch = False
+    if inpaint_mask_dir:
+        inpaint_masks = shared.listfiles(inpaint_mask_dir)
+        is_inpaint_batch = len(inpaint_masks) > 0
+    if is_inpaint_batch:
+        print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
+
     print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")

     save_normally = output_dir == ''
@@ -43,6 +50,15 @@ def process_batch(p, input_dir, output_dir, args):
         img = ImageOps.exif_transpose(img)
         p.init_images = [img] * p.batch_size

+        if is_inpaint_batch:
+            # try to find corresponding mask for an image using simple filename matching
+            mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
+            # if not found use first one ("same mask for all images" use-case)
+            if not mask_image_path in inpaint_masks:
+                mask_image_path = inpaint_masks[0]
+            mask_image = Image.open(mask_image_path)
+            p.image_mask = mask_image
+
         proc = modules.scripts.scripts_img2img.run(p, *args)
         if proc is None:
             proc = process_images(p)
@@ -59,7 +75,7 @@ def process_batch(p, input_dir, output_dir, args):
                     processed_image.save(os.path.join(output_dir, filename))

-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args):
     is_batch = mode == 5

     if mode == 0: # img2img
@@ -139,7 +155,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
     if is_batch:
         assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"

-        process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args)
+        process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)

         processed = Processed(p, [], p.seed, "")
     else:
...
@@ -12,7 +12,7 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode

 import modules.shared as shared
-from modules import devices, paths, lowvram, modelloader, errors
+from modules import devices, paths, shared, lowvram, modelloader, errors

 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'
@@ -82,9 +82,16 @@ class InterrogateModels:
         return self.loaded_categories

+    def create_fake_fairscale(self):
+        class FakeFairscale:
+            def checkpoint_wrapper(self):
+                pass
+
+        sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
+
     def load_blip_model(self):
-        with paths.Prioritize("BLIP"):
-            import models.blip
+        self.create_fake_fairscale()
+        import models.blip

         files = modelloader.load_models(
             model_path=os.path.join(paths.models_path, "BLIP"),
...
@@ -4,7 +4,15 @@ import sys
 import modules.safe

 script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-models_path = os.path.join(script_path, "models")
+
+# Parse the --data-dir flag first so we can use it as a base for our other argument default values
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
+cmd_opts_pre = parser.parse_known_args()[0]
+data_path = cmd_opts_pre.data_dir
+models_path = os.path.join(data_path, "models")
+# data_path = cmd_opts_pre.data
+
 sys.path.insert(0, script_path)

 # search for directory of stable diffusion in following places
...
@@ -13,10 +13,11 @@ from skimage import exposure
 from typing import Any, Dict, List, Optional

 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
+import modules.paths as paths
 import modules.face_restoration
 import modules.images as images
 import modules.styles
@@ -184,7 +185,12 @@ class StableDiffusionProcessing:
         conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
         return conditioning

-    def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
+    def edit_image_conditioning(self, source_image):
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+
+        return conditioning_image
+
+    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
         self.is_using_inpainting_conditioning = True

         # Handle the different mask inputs
@@ -203,7 +209,7 @@ class StableDiffusionProcessing:
         # Create another latent image, this time with a masked version of the original input.
         # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
-        conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
+        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
         conditioning_image = torch.lerp(
             source_image,
             source_image * (1.0 - conditioning_mask),
return image_conditioning return image_conditioning
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None): def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
source_image = devices.cond_cast_float(source_image)
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
# identify itself with a field common to all models. The conditioning_key is also hybrid. # identify itself with a field common to all models. The conditioning_key is also hybrid.
if isinstance(self.sd_model, LatentDepth2ImageDiffusion): if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
return self.depth2img_image_conditioning(source_image) return self.depth2img_image_conditioning(source_image)
if self.sd_model.cond_stage_key == "edit":
return self.edit_image_conditioning(source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}: if self.sampler.conditioning_key in {'hybrid', 'concat'}:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
@@ -439,8 +450,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Size": f"{p.width}x{p.height}",
         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
-        "Batch size": (None if p.batch_size < 2 else p.batch_size),
-        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
         "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
...@@ -580,10 +589,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -580,10 +589,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.autocast(): with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds) p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
# for OSX, loading the model during sampling changes the generated picture, so it is loaded here
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
if not p.disable_extra_networks: if not p.disable_extra_networks:
extra_networks.activate(p, extra_network_data) extra_networks.activate(p, extra_network_data)
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "") processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0)) file.write(processed.infotext(p, 0))
...@@ -634,7 +647,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -634,7 +647,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1: if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}" shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
if type(p) == StableDiffusionProcessingTxt2Img: if type(p) == StableDiffusionProcessingTxt2Img:
if p.enable_hr: if p.enable_hr:
if p.hr_prompt != '': if p.hr_prompt != '':
...@@ -684,6 +698,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -684,6 +698,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image = Image.fromarray(x_sample) image = Image.fromarray(x_sample)
if p.scripts is not None:
pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image(p, pp)
image = pp.image
if p.color_corrections is not None and i < len(p.color_corrections): if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
......
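The inpainting conditioning built above blends a masked and an unmasked copy of the source latents with torch.lerp. A minimal standalone sketch of that interpolation, with `weight` standing in for the inpainting mask weight option the real code reads:

import torch

def blend_conditioning(source_image, conditioning_mask, weight):
    # Match the mask's device and dtype to the image, then interpolate between
    # the unmasked image (weight=0.0) and the fully masked image (weight=1.0).
    conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
    return torch.lerp(source_image, source_image * (1.0 - conditioning_mask), weight)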
...@@ -46,7 +46,7 @@ class UpscalerRealESRGAN(Upscaler): ...@@ -46,7 +46,7 @@ class UpscalerRealESRGAN(Upscaler):
scale=info.scale, scale=info.scale,
model_path=info.local_data_path, model_path=info.local_data_path,
model=info.model(), model=info.model(),
half=not cmd_opts.no_half, half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
tile=opts.ESRGAN_tile, tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap, tile_pad=opts.ESRGAN_tile_overlap,
) )
......
import os import os
import sys import sys
import traceback import traceback
import importlib.util
from types import ModuleType from types import ModuleType
def load_module(path): def load_module(path):
with open(path, "r", encoding="utf8") as file: module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
text = file.read() module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
compiled = compile(text, path, 'exec')
module = ModuleType(os.path.basename(path))
exec(compiled, module.__dict__)
return module return module
......
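The rewritten loader goes through importlib's spec machinery instead of compile/exec: spec_from_file_location builds a spec for the file, module_from_spec creates an empty module from it, and exec_module runs the file's code inside that module. A hypothetical usage, with a made-up script path:

from modules import script_loading  # module path as used elsewhere in this commit

my_script = script_loading.load_module("extensions/example/scripts/example.py")
# Anything the file defines is now an attribute, e.g. my_script.SomeClass.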
...@@ -6,12 +6,16 @@ from collections import namedtuple ...@@ -6,12 +6,16 @@ from collections import namedtuple
import gradio as gr import gradio as gr
from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
AlwaysVisible = object() AlwaysVisible = object()
class PostprocessImageArgs:
def __init__(self, image):
self.image = image
class Script: class Script:
filename = None filename = None
args_from = None args_from = None
...@@ -65,7 +69,7 @@ class Script: ...@@ -65,7 +69,7 @@ class Script:
args contains all values returned by components from ui() args contains all values returned by components from ui()
""" """
raise NotImplementedError() pass
def process(self, p, *args): def process(self, p, *args):
""" """
...@@ -100,6 +104,13 @@ class Script: ...@@ -100,6 +104,13 @@ class Script:
pass pass
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
"""
pass
def postprocess(self, p, processed, *args): def postprocess(self, p, processed, *args):
""" """
This function is called after processing ends for AlwaysVisible scripts. This function is called after processing ends for AlwaysVisible scripts.
...@@ -247,11 +258,15 @@ class ScriptRunner: ...@@ -247,11 +258,15 @@ class ScriptRunner:
self.infotext_fields = [] self.infotext_fields = []
def initialize_scripts(self, is_img2img): def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
self.scripts.clear() self.scripts.clear()
self.alwayson_scripts.clear() self.alwayson_scripts.clear()
self.selectable_scripts.clear() self.selectable_scripts.clear()
for script_class, path, basedir, script_module in scripts_data: auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
script = script_class() script = script_class()
script.filename = path script.filename = path
script.is_txt2img = not is_img2img script.is_txt2img = not is_img2img
...@@ -330,9 +345,23 @@ class ScriptRunner: ...@@ -330,9 +345,23 @@ class ScriptRunner:
outputs=[script.group for script in self.selectable_scripts] outputs=[script.group for script in self.selectable_scripts]
) )
self.script_load_ctr = 0
def onload_script_visibility(params):
title = params.get('Script', None)
if title:
title_index = self.titles.index(title)
visibility = title_index == self.script_load_ctr
self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
return gr.update(visible=visibility)
else:
return gr.update(visible=False)
self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
return inputs return inputs
def run(self, p: StableDiffusionProcessing, *args): def run(self, p, *args):
script_index = args[0] script_index = args[0]
if script_index == 0: if script_index == 0:
...@@ -386,6 +415,15 @@ class ScriptRunner: ...@@ -386,6 +415,15 @@ class ScriptRunner:
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
except Exception:
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def before_component(self, component, **kwargs): def before_component(self, component, **kwargs):
for script in self.scripts: for script in self.scripts:
try: try:
......
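The new postprocess_image hook gives always-on scripts a chance to replace each image right after it is generated, before color correction and saving. A minimal sketch of a script that uses it; the class name and behavior are invented for illustration:

from PIL import ImageOps
from modules import scripts

class GrayscaleEveryImage(scripts.Script):
    def title(self):
        return "Grayscale every image"

    def show(self, is_img2img):
        # AlwaysVisible registers this as an always-on script, so the
        # postprocess_image hook runs for every generated image.
        return scripts.AlwaysVisible

    def postprocess_image(self, p, pp, *args):
        # pp is a PostprocessImageArgs; assigning pp.image replaces the result.
        pp.image = ImageOps.grayscale(pp.image).convert("RGB")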
from modules import scripts, scripts_postprocessing, shared
class ScriptPostprocessingForMainUI(scripts.Script):
def __init__(self, script_postproc):
self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
self.postprocessing_controls = None
def title(self):
return self.script.name
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
self.postprocessing_controls = self.script.ui()
return self.postprocessing_controls.values()
def postprocess_image(self, p, script_pp, *args):
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
pp.info = {}
self.script.process(pp, **args_dict)
p.extra_generation_params.update(pp.info)
script_pp.image = pp.image
def create_auto_preprocessing_script_data():
from modules import scripts
res = []
for name in shared.opts.postprocessing_enable_in_main_ui:
script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
if script is None:
continue
constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
return res
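This new module lets a scripts_postprocessing.ScriptPostprocessing run inside txt2img/img2img: each name listed in the postprocessing_enable_in_main_ui option is wrapped in ScriptPostprocessingForMainUI and registered alongside ordinary scripts. A hedged illustration, with a hypothetical script name:

from modules import shared

# After the next script reload, create_auto_preprocessing_script_data() wraps
# the postprocessing script whose .name matches and adds it to the main UI.
shared.opts.postprocessing_enable_in_main_ui = ["Upscale"]  # hypothetical name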
...@@ -46,6 +46,8 @@ class ScriptPostprocessing: ...@@ -46,6 +46,8 @@ class ScriptPostprocessing:
pass pass
def wrap_call(func, filename, funcname, *args, default=None, **kwargs): def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try: try:
res = func(*args, **kwargs) res = func(*args, **kwargs)
...@@ -68,6 +70,9 @@ class ScriptPostprocessingRunner: ...@@ -68,6 +70,9 @@ class ScriptPostprocessingRunner:
script: ScriptPostprocessing = script_class() script: ScriptPostprocessing = script_class()
script.filename = path script.filename = path
if script.name == "Simple Upscale":
continue
self.scripts.append(script) self.scripts.append(script)
def create_script_ui(self, script, inputs): def create_script_ui(self, script, inputs):
...@@ -87,12 +92,11 @@ class ScriptPostprocessingRunner: ...@@ -87,12 +92,11 @@ class ScriptPostprocessingRunner:
import modules.scripts import modules.scripts
self.initialize_scripts(modules.scripts.postprocessing_scripts_data) self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")] scripts_order = shared.opts.postprocessing_operation_order
def script_score(name): def script_score(name):
name = name.lower()
for i, possible_match in enumerate(scripts_order): for i, possible_match in enumerate(scripts_order):
if possible_match in name: if possible_match == name:
return i return i
return len(self.scripts) return len(self.scripts)
...@@ -145,3 +149,4 @@ class ScriptPostprocessingRunner: ...@@ -145,3 +149,4 @@ class ScriptPostprocessingRunner:
def image_changed(self): def image_changed(self):
for script in self.scripts_in_preferred_order(): for script in self.scripts_in_preferred_order():
script.image_changed() script.image_changed()
...@@ -131,6 +131,8 @@ class StableDiffusionModelHijack: ...@@ -131,6 +131,8 @@ class StableDiffusionModelHijack:
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped
undo_optimizations()
self.apply_circular(False) self.apply_circular(False)
self.layers = None self.layers = None
self.clip = None self.clip = None
...@@ -171,7 +173,7 @@ class EmbeddingsWithFixes(torch.nn.Module): ...@@ -171,7 +173,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
vecs = [] vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds): for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes: for offset, embedding in fixes:
emb = embedding.vec emb = devices.cond_cast_unet(embedding.vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
......
...@@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F ...@@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
return x_prev, pred_x0, e_t return x_prev, pred_x0, e_t
def should_hijack_inpainting(checkpoint_info):
from modules import sd_models
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
def do_inpainting_hijack(): def do_inpainting_hijack():
# p_sample_plms is needed because PLMS can't work with dicts as conditionings # p_sample_plms is needed because PLMS can't work with dicts as conditionings
......
import collections
import os.path
import sys
import gc
import time
def should_hijack_ip2p(checkpoint_info):
from modules import sd_models_config
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
...@@ -9,7 +9,7 @@ from torch import einsum ...@@ -9,7 +9,7 @@ from torch import einsum
from ldm.util import default from ldm.util import default
from einops import rearrange from einops import rearrange
from modules import shared, errors from modules import shared, errors, devices
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention from .sub_quadratic_attention import efficient_dot_product_attention
...@@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): ...@@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) dtype = q.dtype
for i in range(0, q.shape[0], 2): if shared.opts.upcast_attn:
end = i + 2 q, k, v = q.float(), k.float(), v.float()
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1) with devices.without_autocast(disable=not shared.opts.upcast_attn):
del s1 r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
del q, k, v
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) r1 = r1.to(dtype)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1 del r1
...@@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None): ...@@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
k_in = self.to_k(context_k) k_in = self.to_k(context_k)
v_in = self.to_v(context_v) v_in = self.to_v(context_v)
k_in *= self.scale dtype = q_in.dtype
if shared.opts.upcast_attn:
del context, x q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
mem_free_total = get_available_vram()
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64: with devices.without_autocast(disable=not shared.opts.upcast_attn):
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 k_in = k_in * self.scale
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') del context, x
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
for i in range(0, q.shape[1], slice_size): del q_in, k_in, v_in
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
s2 = s1.softmax(dim=-1, dtype=q.dtype) mem_free_total = get_available_vram()
del s1
gb = 1024 ** 3
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
del s2 modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
del q, k, v r1 = r1.to(dtype)
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1 del r1
...@@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): ...@@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
context = default(context, x) context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k) * self.scale k = self.to_k(context_k)
v = self.to_v(context_v) v = self.to_v(context_v)
del context, context_k, context_v, x del context, context_k, context_v, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) dtype = q.dtype
r = einsum_op(q, k, v) if shared.opts.upcast_attn:
q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k = k * self.scale
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = einsum_op(q, k, v)
r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
# -- End of code from https://github.com/invoke-ai/InvokeAI -- # -- End of code from https://github.com/invoke-ai/InvokeAI --
...@@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): ...@@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
x = x.to(dtype)
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
out_proj, dropout = self.to_out out_proj, dropout = self.to_out
...@@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ ...@@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
query_chunk_size = q_tokens query_chunk_size = q_tokens
kv_chunk_size = k_tokens kv_chunk_size = k_tokens
return efficient_dot_product_attention( with devices.without_autocast(disable=q.dtype == v.dtype):
q, return efficient_dot_product_attention(
k, q,
v, k,
query_chunk_size=q_chunk_size, v,
kv_chunk_size=kv_chunk_size, query_chunk_size=q_chunk_size,
kv_chunk_size_min = kv_chunk_size_min, kv_chunk_size=kv_chunk_size,
use_checkpoint=use_checkpoint, kv_chunk_size_min = kv_chunk_size_min,
) use_checkpoint=use_checkpoint,
)
def get_xformers_flash_attention_op(q, k, v): def get_xformers_flash_attention_op(q, k, v):
...@@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None): ...@@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in del q_in, k_in, v_in
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
out = out.to(dtype)
out = rearrange(out, 'b n h d -> b n (h d)', h=h) out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out) return self.to_out(out)
...@@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x): ...@@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x):
v = self.v(h_) v = self.v(h_)
b, c, h, w = q.shape b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
q = q.contiguous() q = q.contiguous()
k = k.contiguous() k = k.contiguous()
v = v.contiguous() v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
out = out.to(dtype)
out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out) out = self.proj_out(out)
return x + out return x + out
......
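The recurring change in this file is one pattern: when the upcast_attn option is on, queries and keys (and on some code paths values) are cast to float32 so the scale/softmax/matmul sequence runs in full precision, and the result is cast back to the original dtype. A standalone sketch of that pattern, with invented names:

import torch

def attention_upcast(q, k, v, scale, upcast=True):
    dtype = q.dtype
    if upcast:
        q, k = q.float(), k.float()
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale  # attention scores
    attn = sim.softmax(dim=-1)                                 # softmax in float32 when upcast
    out = torch.einsum('b i j, b j d -> b i d', attn, v.to(attn.dtype))
    return out.to(dtype)                                       # cast back to the caller's dtype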
import torch import torch
from packaging import version
from modules import devices
from modules.sd_hijack_utils import CondFunc
class TorchHijackForUnet: class TorchHijackForUnet:
...@@ -28,3 +32,37 @@ class TorchHijackForUnet: ...@@ -28,3 +32,37 @@ class TorchHijackForUnet:
th = TorchHijackForUnet() th = TorchHijackForUnet()
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
if isinstance(cond, dict):
for y in cond.keys():
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
class GELUHijack(torch.nn.GELU, torch.nn.Module):
def __init__(self, *args, **kwargs):
torch.nn.GELU.__init__(self, *args, **kwargs)
def forward(self, x):
if devices.unet_needs_upcast:
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
else:
return torch.nn.GELU.forward(self, x)
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.1"):
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
import importlib
class CondFunc:
def __new__(cls, orig_func, sub_func, cond_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
for i in range(len(func_path)-1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break
except ImportError:
pass
for attr_name in func_path[i:-1]:
resolved_obj = getattr(resolved_obj, attr_name)
orig_func = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)
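CondFunc resolves a dotted path to a callable, replaces it with a wrapper, and routes each call to sub_func (which receives the original function as its first argument) whenever cond_func returns truthy, falling back to the original otherwise. A hypothetical usage against a standard-library function:

import math
from modules.sd_hijack_utils import CondFunc

def log_negative(orig_func, base, exp):
    print(f"pow called with negative base {base}")
    return orig_func(base, exp)

# Only calls with a negative base go through log_negative; others run unchanged.
CondFunc('math.pow', log_negative, lambda orig_func, base, exp: base < 0)

math.pow(-2, 3)  # prints the message, then returns -8.0
math.pow(2, 3)   # runs the original directly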
import re
import os
import torch
from modules import shared, paths, sd_disable_initialization
sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
config_default = shared.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
def is_using_v_parameterization_for_sd2(state_dict):
"""
Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
"""
import ldm.modules.diffusionmodules.openaimodel
from modules import devices
device = devices.cpu
with sd_disable_initialization.DisableInitialization():
unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
use_checkpoint=True,
use_fp16=False,
image_size=32,
in_channels=4,
out_channels=4,
model_channels=320,
attention_resolutions=[4, 2, 1],
num_res_blocks=2,
channel_mult=[1, 2, 4, 4],
num_head_channels=64,
use_spatial_transformer=True,
use_linear_in_transformer=True,
transformer_depth=1,
context_dim=1024,
legacy=False
)
unet.eval()
with torch.no_grad():
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
unet.load_state_dict(unet_sd, strict=True)
unet.to(device=device, dtype=torch.float)
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
return out < -1
def guess_model_config_from_state_dict(sd, filename):
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
return config_depth_model
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
if diffusion_model_input.shape[1] == 9:
return config_sd2_inpainting
elif is_using_v_parameterization_for_sd2(sd):
return config_sd2v
else:
return config_sd2
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
return config_inpainting
if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
return config_alt_diffusion
return config_default
def find_checkpoint_config(state_dict, info):
if info is None:
return guess_model_config_from_state_dict(state_dict, "")
config = find_checkpoint_config_near_filename(info)
if config is not None:
return config
return guess_model_config_from_state_dict(state_dict, info.filename)
def find_checkpoint_config_near_filename(info):
if info is None:
return None
config = os.path.splitext(info.filename)[0] + ".yaml"
if os.path.exists(config):
return config
return None
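Config detection works from tensor shapes in the state dict: a cond-stage projection width of 1024 marks an SD2 model, and the UNet's input-channel count distinguishes standard (4), instruct-pix2pix (8) and inpainting (9) variants. An illustrative sketch of that shape check, assuming `sd` is a loaded state dict:

def describe_unet(sd):
    w = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
    if w is None:
        return "no UNet weights found"
    channels = w.shape[1]
    return {
        4: "standard 4-channel UNet",
        8: "8 input channels: instruct-pix2pix style",
        9: "9 input channels: inpainting model",
    }.get(channels, f"unrecognized UNet with {channels} input channels")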
...@@ -454,7 +454,7 @@ class KDiffusionSampler: ...@@ -454,7 +454,7 @@ class KDiffusionSampler:
def initialize(self, p): def initialize(self, p):
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap.step = 0 self.model_wrap_cfg.step = 0
self.eta = p.eta or opts.eta_ancestral self.eta = p.eta or opts.eta_ancestral
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
......
...@@ -3,13 +3,12 @@ import safetensors.torch ...@@ -3,13 +3,12 @@ import safetensors.torch
import os import os
import collections import collections
from collections import namedtuple from collections import namedtuple
from modules import shared, devices, script_callbacks, sd_models from modules import paths, shared, devices, script_callbacks, sd_models
from modules.paths import models_path
import glob import glob
from copy import deepcopy from copy import deepcopy
vae_path = os.path.abspath(os.path.join(models_path, "VAE")) vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
vae_dict = {} vae_dict = {}
......
def realesrgan_models_names():
import modules.realesrgan_model
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
def postprocessing_scripts():
import modules.scripts
return modules.scripts.scripts_postproc.scripts
def sd_vae_items():
import modules.sd_vae
return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
def refresh_vae_list():
import modules.sd_vae
return modules.sd_vae.refresh_vae_list
...@@ -67,7 +67,7 @@ def _summarize_chunk( ...@@ -67,7 +67,7 @@ def _summarize_chunk(
max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach() max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score) exp_weights = torch.exp(attn_weights - max_score)
exp_values = torch.bmm(exp_weights, value) exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
max_score = max_score.squeeze(-1) max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
...@@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking( ...@@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking(
) )
attn_probs = attn_scores.softmax(dim=-1) attn_probs = attn_scores.softmax(dim=-1)
del attn_scores del attn_scores
hidden_states_slice = torch.bmm(attn_probs, value) hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
return hidden_states_slice return hidden_states_slice
......
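Both changes above handle the same mixed-precision case: with upcast attention the attention probabilities are float32 while value can still be float16, and torch.bmm requires matching dtypes, so value is cast up for the matmul and the result cast back (the mps branch keeps the plain bmm). A one-function sketch of that fix:

import torch

def bmm_matching_dtype(attn_probs, value):
    # Cast value to the probabilities' dtype for the batched matmul,
    # then return the result in value's original dtype.
    return torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)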
...@@ -6,8 +6,7 @@ import sys ...@@ -6,8 +6,7 @@ import sys
import tqdm import tqdm
import time import time
from modules import shared, images, deepbooru from modules import paths, shared, images, deepbooru
from modules.paths import models_path
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop from modules.textual_inversion import autocrop
...@@ -199,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ...@@ -199,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
dnn_model_path = None dnn_model_path = None
try: try:
dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv")) dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
except Exception as e: except Exception as e:
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e) print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
......
...@@ -112,6 +112,7 @@ class EmbeddingDatabase: ...@@ -112,6 +112,7 @@ class EmbeddingDatabase:
self.skipped_embeddings = {} self.skipped_embeddings = {}
self.expected_shape = -1 self.expected_shape = -1
self.embedding_dirs = {} self.embedding_dirs = {}
self.previously_displayed_embeddings = ()
def add_embedding_dir(self, path): def add_embedding_dir(self, path):
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
...@@ -194,7 +195,7 @@ class EmbeddingDatabase: ...@@ -194,7 +195,7 @@ class EmbeddingDatabase:
if not os.path.isdir(embdir.path): if not os.path.isdir(embdir.path):
return return
for root, dirs, fns in os.walk(embdir.path): for root, dirs, fns in os.walk(embdir.path, followlinks=True):
for fn in fns: for fn in fns:
try: try:
fullfn = os.path.join(root, fn) fullfn = os.path.join(root, fn)
...@@ -228,9 +229,12 @@ class EmbeddingDatabase: ...@@ -228,9 +229,12 @@ class EmbeddingDatabase:
self.load_from_dir(embdir) self.load_from_dir(embdir)
embdir.update() embdir.update()
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
if len(self.skipped_embeddings) > 0: if self.previously_displayed_embeddings != displayed_embeddings:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") self.previously_displayed_embeddings = displayed_embeddings
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset): def find_embedding_at_position(self, tokens, offset):
token = tokens[offset] token = tokens[offset]
......
import time
class Timer:
def __init__(self):
self.start = time.time()
self.records = {}
self.total = 0
def elapsed(self):
end = time.time()
res = end - self.start
self.start = end
return res
def record(self, category, extra_time=0):
e = self.elapsed()
if category not in self.records:
self.records[category] = 0
self.records[category] += e + extra_time
self.total += e + extra_time
def summary(self):
res = f"{self.total:.1f}s"
additions = [x for x in self.records.items() if x[1] >= 0.1]
if not additions:
return res
res += " ("
res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
res += ")"
return res
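A minimal usage sketch of the Timer above, assuming the file lands at modules/timer.py; records shorter than 0.1s are dropped from the summary:

import time
from modules.timer import Timer  # assumed module path

timer = Timer()
time.sleep(0.2)
timer.record("load weights")
time.sleep(0.1)
timer.record("apply optimizations")
print(timer.summary())  # e.g. "0.3s (load weights: 0.2s, apply optimizations: 0.1s)"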
...@@ -21,7 +21,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad ...@@ -21,7 +21,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path from modules.paths import script_path, data_path
from modules.shared import opts, cmd_opts, restricted_opts from modules.shared import opts, cmd_opts, restricted_opts
...@@ -91,6 +91,7 @@ save_style_symbol = '\U0001f4be' # 💾 ...@@ -91,6 +91,7 @@ save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋 apply_style_symbol = '\U0001f4cb' # 📋
clear_prompt_symbol = '\U0001F5D1' # 🗑️ clear_prompt_symbol = '\U0001F5D1' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴 extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
def plaintext_to_html(text): def plaintext_to_html(text):
...@@ -466,6 +467,7 @@ def create_ui(): ...@@ -466,6 +467,7 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
if opts.dimensions_and_batch_together: if opts.dimensions_and_batch_together:
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
with gr.Column(elem_id="txt2img_column_batch"): with gr.Column(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
...@@ -581,6 +583,8 @@ def create_ui(): ...@@ -581,6 +583,8 @@ def create_ui():
txt2img_prompt.submit(**txt2img_args) txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args) submit.click(**txt2img_args)
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
txt_prompt_img.change( txt_prompt_img.change(
fn=modules.images.image_data, fn=modules.images.image_data,
inputs=[ inputs=[
...@@ -708,9 +712,15 @@ def create_ui(): ...@@ -708,9 +712,15 @@ def create_ui():
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>") gr.HTML(
f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
f"{hidden}</p>"
)
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
def copy_image(img): def copy_image(img):
if isinstance(img, dict) and 'image' in img: if isinstance(img, dict) and 'image' in img:
...@@ -745,6 +755,7 @@ def create_ui(): ...@@ -745,6 +755,7 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
if opts.dimensions_and_batch_together: if opts.dimensions_and_batch_together:
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
with gr.Column(elem_id="img2img_column_batch"): with gr.Column(elem_id="img2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
...@@ -855,6 +866,7 @@ def create_ui(): ...@@ -855,6 +866,7 @@ def create_ui():
inpainting_mask_invert, inpainting_mask_invert,
img2img_batch_input_dir, img2img_batch_input_dir,
img2img_batch_output_dir, img2img_batch_output_dir,
img2img_batch_inpaint_mask_dir
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
img2img_gallery, img2img_gallery,
...@@ -882,6 +894,7 @@ def create_ui(): ...@@ -882,6 +894,7 @@ def create_ui():
img2img_prompt.submit(**img2img_args) img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args) submit.click(**img2img_args)
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
img2img_interrogate.click( img2img_interrogate.click(
fn=lambda *args: process_interrogate(interrogate, *args), fn=lambda *args: process_interrogate(interrogate, *args),
...@@ -1514,8 +1527,8 @@ def create_ui(): ...@@ -1514,8 +1527,8 @@ def create_ui():
with open(cssfile, "r", encoding="utf8") as file: with open(cssfile, "r", encoding="utf8") as file:
css += file.read() + "\n" css += file.read() + "\n"
if os.path.exists(os.path.join(script_path, "user.css")): if os.path.exists(os.path.join(data_path, "user.css")):
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
css += file.read() + "\n" css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding: if not cmd_opts.no_progressbar_hiding:
...@@ -1564,6 +1577,14 @@ def create_ui(): ...@@ -1564,6 +1577,14 @@ def create_ui():
outputs=[component, text_settings], outputs=[component, text_settings],
) )
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(
fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
inputs=[component_dict['sd_model_checkpoint'], dummy_component],
outputs=[component_dict['sd_model_checkpoint'], text_settings],
)
component_keys = [k for k in opts.data_labels.keys() if k in component_dict] component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
def get_settings_values(): def get_settings_values():
...@@ -1696,14 +1717,14 @@ def create_ui(): ...@@ -1696,14 +1717,14 @@ def create_ui():
def reload_javascript(): def reload_javascript():
head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}"></script>\n' head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}?{os.path.getmtime("script.js")}"></script>\n'
inline = f"{localization.localization_js(shared.opts.localization)};" inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None: if cmd_opts.theme is not None:
inline += f"set_theme('{cmd_opts.theme}');" inline += f"set_theme('{cmd_opts.theme}');"
for script in modules.scripts.list_scripts("javascript", ".js"): for script in modules.scripts.list_scripts("javascript", ".js"):
head += f'<script type="text/javascript" src="file={script.path}"></script>\n' head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
head += f'<script type="text/javascript">{inline}</script>\n' head += f'<script type="text/javascript">{inline}</script>\n'
......
...@@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): ...@@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
def get_block_name(self): def get_block_name(self):
return "colorpicker" return "colorpicker"
class DropdownMulti(gr.Dropdown):
"""Same as gr.Dropdown but always multiselect"""
def __init__(self, **kwargs):
super().__init__(multiselect=True, **kwargs)
def get_block_name(self):
return "dropdown"
...@@ -13,7 +13,7 @@ import shutil ...@@ -13,7 +13,7 @@ import shutil
import errno import errno
from modules import extensions, shared, paths from modules import extensions, shared, paths
from modules.call_queue import wrap_gradio_gpu_call
available_extensions = {"extensions": []} available_extensions = {"extensions": []}
...@@ -50,12 +50,17 @@ def apply_and_restart(disable_list, update_list): ...@@ -50,12 +50,17 @@ def apply_and_restart(disable_list, update_list):
shared.state.need_restart = True shared.state.need_restart = True
def check_updates(): def check_updates(id_task, disable_list):
check_access() check_access()
for ext in extensions.extensions: disabled = json.loads(disable_list)
if ext.remote is None: assert type(disabled) == list, f"wrong disable_list data for check_updates: {disable_list}"
continue
exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
shared.state.job_count = len(exts)
for ext in exts:
shared.state.textinfo = ext.name
try: try:
ext.check_updates() ext.check_updates()
...@@ -63,7 +68,9 @@ def check_updates(): ...@@ -63,7 +68,9 @@ def check_updates():
print(f"Error checking updates for {ext.name}:", file=sys.stderr) print(f"Error checking updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
return extension_table() shared.state.nextjob()
return extension_table(), ""
def extension_table(): def extension_table():
...@@ -132,7 +139,7 @@ def install_extension_from_url(dirname, url): ...@@ -132,7 +139,7 @@ def install_extension_from_url(dirname, url):
normalized_url = normalize_git_url(url) normalized_url = normalize_git_url(url)
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
tmpdir = os.path.join(paths.script_path, "tmp", dirname) tmpdir = os.path.join(paths.data_path, "tmp", dirname)
try: try:
shutil.rmtree(tmpdir, True) shutil.rmtree(tmpdir, True)
...@@ -273,12 +280,13 @@ def create_ui(): ...@@ -273,12 +280,13 @@ def create_ui():
with gr.Tabs(elem_id="tabs_extensions") as tabs: with gr.Tabs(elem_id="tabs_extensions") as tabs:
with gr.TabItem("Installed"): with gr.TabItem("Installed"):
with gr.Row(): with gr.Row(elem_id="extensions_installed_top"):
apply = gr.Button(value="Apply and restart UI", variant="primary") apply = gr.Button(value="Apply and restart UI", variant="primary")
check = gr.Button(value="Check for updates") check = gr.Button(value="Check for updates")
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
info = gr.HTML()
extensions_table = gr.HTML(lambda: extension_table()) extensions_table = gr.HTML(lambda: extension_table())
apply.click( apply.click(
...@@ -289,10 +297,10 @@ def create_ui(): ...@@ -289,10 +297,10 @@ def create_ui():
) )
check.click( check.click(
fn=check_updates, fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),
_js="extensions_check", _js="extensions_check",
inputs=[], inputs=[info, extensions_disabled_list],
outputs=[extensions_table], outputs=[extensions_table, info],
) )
with gr.TabItem("Available"): with gr.TabItem("Available"):
......
import glob
import os.path import os.path
import urllib.parse
from pathlib import Path
from modules import shared from modules import shared
import gradio as gr import gradio as gr
...@@ -8,12 +11,31 @@ import html ...@@ -8,12 +11,31 @@ import html
from modules.generation_parameters_copypaste import image_from_url_text from modules.generation_parameters_copypaste import image_from_url_text
extra_pages = [] extra_pages = []
allowed_dirs = set()
def register_page(page): def register_page(page):
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
extra_pages.append(page) extra_pages.append(page)
allowed_dirs.clear()
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
def add_pages_to_demo(app):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
if os.path.splitext(filename)[1].lower() != ".png":
raise ValueError(f"File cannot be fetched: {filename}. Only png.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
class ExtraNetworksPage: class ExtraNetworksPage:
...@@ -26,10 +48,44 @@ class ExtraNetworksPage: ...@@ -26,10 +48,44 @@ class ExtraNetworksPage:
def refresh(self): def refresh(self):
pass pass
def link_preview(self, filename):
return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
def search_terms_from_path(self, filename, possible_directories=None):
abspath = os.path.abspath(filename)
for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
parentdir = os.path.abspath(parentdir)
if abspath.startswith(parentdir):
return abspath[len(parentdir):].replace('\\', '/')
return ""
def create_html(self, tabname): def create_html(self, tabname):
view = shared.opts.extra_networks_default_view view = shared.opts.extra_networks_default_view
items_html = '' items_html = ''
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
if not os.path.isdir(x):
continue
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
while subdir.startswith("/"):
subdir = subdir[1:]
subdirs[subdir] = 1
if subdirs:
subdirs = {"": 1, **subdirs}
subdirs_html = "".join([f"""
<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
{html.escape(subdir if subdir!="" else "all")}
</button>
""" for subdir in subdirs])
for item in self.list_items(): for item in self.list_items():
items_html += self.create_html_for_item(item, tabname) items_html += self.create_html_for_item(item, tabname)
...@@ -38,6 +94,9 @@ class ExtraNetworksPage: ...@@ -38,6 +94,9 @@ class ExtraNetworksPage:
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
res = f""" res = f"""
<div id='{tabname}_{self.name}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
{subdirs_html}
</div>
<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'> <div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
{items_html} {items_html}
</div> </div>
...@@ -54,14 +113,19 @@ class ExtraNetworksPage: ...@@ -54,14 +113,19 @@ class ExtraNetworksPage:
def create_html_for_item(self, item, tabname): def create_html_for_item(self, item, tabname):
preview = item.get("preview", None) preview = item.get("preview", None)
onclick = item.get("onclick", None)
if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
args = { args = {
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
"prompt": item["prompt"], "prompt": item.get("prompt", None),
"tabname": json.dumps(tabname), "tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]), "local_preview": json.dumps(item["local_preview"]),
"name": item["name"], "name": item["name"],
"card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"', "card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
} }
return self.card_page.format(**args) return self.card_page.format(**args)
@@ -117,8 +181,13 @@ def create_ui(container, button, tabname):
    ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
    ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)

-    button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
-    button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
+    def toggle_visibility(is_visible):
+        is_visible = not is_visible
+        return is_visible, gr.update(visible=is_visible)
+
+    state_visible = gr.State(value=False)
+    button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
+    button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
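The removed lambda handlers could only ever show or hide; carrying a gr.State through the click is what makes the same button toggle. A minimal self-contained sketch of that Gradio pattern (illustrative, not part of the commit; component names are invented):

import gradio as gr

with gr.Blocks() as demo:
    state_visible = gr.State(value=False)  # per-session flag remembered between clicks
    with gr.Column(visible=False) as panel:
        gr.Markdown("panel contents")
    button = gr.Button("Toggle panel")

    def toggle_visibility(is_visible):
        is_visible = not is_visible
        # Return the flipped flag (written back into the State) and a visibility update.
        return is_visible, gr.update(visible=is_visible)

    # The State appears in both inputs and outputs, so the flag persists across clicks.
    button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, panel])

demo.launch()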
    def refresh():
        res = []
@@ -138,7 +207,7 @@ def path_is_parent(parent_path, child_path):
    parent_path = os.path.abspath(parent_path)
    child_path = os.path.abspath(child_path)

-    return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
+    return child_path.startswith(parent_path)
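One caveat worth noting about the new one-liner: a bare startswith also matches sibling directories that merely share a prefix (/models/foo would claim /models/foo-extra). A separator-aware variant, should that edge case ever matter (a sketch, not part of the commit):

import os

def path_is_parent_strict(parent_path, child_path):
    parent_path = os.path.abspath(parent_path)
    child_path = os.path.abspath(child_path)
    # Anchoring on os.sep stops /models/foo from matching /models/foo-extra.
    return child_path == parent_path or child_path.startswith(parent_path + os.sep)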
def setup_ui(ui, gallery):
@@ -168,7 +237,8 @@ def setup_ui(ui, gallery):
    ui.button_save_preview.click(
        fn=save_preview,
-        _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
+        _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
        inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
        outputs=[*ui.pages]
    )
import html
import json
import os
import urllib.parse

from modules import shared, ui_extra_networks, sd_models


class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('Checkpoints')

    def refresh(self):
        shared.refresh_checkpoints()

    def list_items(self):
        for name, checkpoint in sd_models.checkpoints_list.items():
            path, ext = os.path.splitext(checkpoint.filename)
            previews = [path + ".png", path + ".preview.png"]

            preview = None
            for file in previews:
                if os.path.isfile(file):
                    preview = self.link_preview(file)
                    break

            yield {
                "name": checkpoint.name_for_extra,
                "filename": path,
                "preview": preview,
                "search_term": self.search_terms_from_path(checkpoint.filename),
                "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
                "local_preview": path + ".png",
            }

    def allowed_directories_for_previews(self):
        return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
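Since the new checkpoints tab is just another ExtraNetworksPage subclass, a custom page follows the same shape. A hypothetical minimal example (class and label invented; registration mirrors the webui.py hunks further down):

from modules import ui_extra_networks

class ExtraNetworksPageMyModels(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('My models')  # tab label

    def list_items(self):
        # Yield one dict per card, with the same keys the pages above use
        # (name, filename, preview, search_term, local_preview, ...).
        return []

    def allowed_directories_for_previews(self):
        return []

# Registered the same way the built-in pages are:
ui_extra_networks.register_page(ExtraNetworksPageMyModels())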
@@ -19,13 +19,14 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
            preview = None
            for file in previews:
                if os.path.isfile(file):
-                    preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+                    preview = self.link_preview(file)
                    break

            yield {
                "name": name,
                "filename": path,
                "preview": preview,
+                "search_term": self.search_terms_from_path(path),
                "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                "local_preview": path + ".png",
            }
...
@@ -19,12 +19,13 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
            preview = None
            if os.path.isfile(preview_file):
-                preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
+                preview = self.link_preview(preview_file)

            yield {
                "name": embedding.name,
                "filename": embedding.filename,
                "preview": preview,
+                "search_term": self.search_terms_from_path(embedding.filename),
                "prompt": json.dumps(embedding.name),
                "local_preview": path + ".preview.png",
            }
...
@@ -11,7 +11,6 @@ from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
-from modules.paths import models_path

class Upscaler:
@@ -39,7 +38,7 @@ class Upscaler:
        self.mod_scale = None

        if self.model_path is None and self.name:
-            self.model_path = os.path.join(models_path, self.name)
+            self.model_path = os.path.join(shared.models_path, self.name)
        if self.model_path and create_dirs:
            os.makedirs(self.model_path, exist_ok=True)
@@ -143,4 +142,4 @@ class UpscalerNearest(Upscaler):
    def __init__(self, dirname=None):
        super().__init__(False)
        self.name = "Nearest"
-        self.scalers = [UpscalerData("Nearest", None, self)]
\ No newline at end of file
+        self.scalers = [UpscalerData("Nearest", None, self)]
blendmodes
accelerate
basicsr
-fairscale==0.4.4
fonts
font-roboto
gfpgan
@@ -17,7 +16,7 @@ pytorch_lightning==1.7.7
realesrgan
scikit-image>=0.19
timm==0.4.12
-transformers==4.19.2
+transformers==4.25.1
torch
einops
jsonmerge
...
blendmodes==2022
-transformers==4.19.2
+transformers==4.25.1
accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
@@ -14,7 +14,6 @@ scikit-image==0.19.2
fonts
font-roboto
timm==0.6.7
-fairscale==0.4.9
piexif==1.1.3
einops==0.4.1
jsonmerge==1.8.0
...
@@ -104,3 +104,28 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
    def image_changed(self):
        upscale_cache.clear()
+
+
+class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
+    name = "Simple Upscale"
+    order = 900
+
+    def ui(self):
+        with FormRow():
+            upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+            upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
+
+        return {
+            "upscale_by": upscale_by,
+            "upscaler_name": upscaler_name,
+        }
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
+        if upscaler_name is None or upscaler_name == "None":
+            return
+
+        upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
+        assert upscaler1, f'could not find upscaler named {upscaler_name}'
+
+        pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
+        pp.info[f"Postprocess upscaler"] = upscaler1.name
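If the keyword arguments in process() look magical: the dict returned by ui() is fed back into process() as kwargs, matched by key. A toy model of that plumbing (not the actual webui framework code; values are invented):

# Toy model of the ui() -> process() kwarg plumbing.
def process(pp, upscale_by=2.0, upscaler_name=None):
    return f"upscale {pp} x{upscale_by} via {upscaler_name}"

controls = {"upscale_by": 2.0, "upscaler_name": "Lanczos"}  # values gathered from the UI
print(process("image.png", **controls))  # -> upscale image.png x2.0 via Lanczos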
This diff is collapsed.
@@ -74,7 +74,12 @@
#txt2img_gallery img, #img2img_gallery img{
    object-fit: scale-down;
}
+#txt2img_actions_column, #img2img_actions_column {
+    margin: 0.35rem 0.75rem 0.35rem 0;
+}
+#script_list {
+    padding: .625rem .75rem 0 .625rem;
+}
.justify-center.overflow-x-scroll {
    justify-content: left;
}
@@ -126,6 +131,7 @@
#txt2img_actions_column, #img2img_actions_column{
    gap: 0;
+    margin-right: .75rem;
}

#txt2img_tools, #img2img_tools{
@@ -150,6 +156,7 @@
#txt2img_styles_row, #img2img_styles_row{
    gap: 0.25em;
+    margin-top: 0.3em;
}

#txt2img_styles_row > button, #img2img_styles_row > button{
@@ -164,7 +171,7 @@
    min-height: 3.2em;
}

-#txt2img_styles ul, #img2img_styles ul{
+ul.list-none{
    max-height: 35em;
    z-index: 2000;
}
@@ -311,11 +318,11 @@ input[type="range"]{
.min-h-\[6rem\] { min-height: unset !important; }

.progressDiv{
-    position: absolute;
+    position: relative;
    height: 20px;
-    top: -20px;
    background: #b4c0cc;
    border-radius: 3px !important;
+    margin-bottom: -3px;
}

.dark .progressDiv{
@@ -535,7 +542,7 @@ input[type="range"]{
}

#quicksettings {
-    gap: 0.4em;
+    width: fit-content;
}

#quicksettings > div, #quicksettings > fieldset{
@@ -545,6 +552,7 @@ input[type="range"]{
    border: none;
    box-shadow: none;
    background: none;
+    margin-right: 10px;
}

#quicksettings > div > div > div > label > span {
@@ -567,7 +575,7 @@ canvas[key="mask"] {
    right: 0.5em;
    top: -0.6em;
    z-index: 400;
-    width: 8em;
+    width: 6em;
}

#quicksettings .gr-box > div > div > input.gr-text-input {
    top: -1.12em;
@@ -665,11 +673,27 @@ canvas[key="mask"] {
#quicksettings .gr-button-tool{
    margin: 0;
+    border-color: unset;
+    background-color: unset;
}

+#modelmerger_interp_description>p {
+    margin: 0!important;
+    text-align: center;
+}
+#modelmerger_interp_description {
+    margin: 0.35rem 0.75rem 1.23rem;
+}
+
#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
    padding-top: 0.9em;
+    padding-bottom: 0.9em;
+}
+
+#txt2img_settings {
+    padding-top: 1.16em;
+    padding-bottom: 0.9em;
+}
+
+#img2img_settings {
+    padding-bottom: 0.9em;
}
#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{
@@ -714,9 +738,6 @@ footer {
    white-space: nowrap;
    min-width: auto;
}

-#txt2img_hires_fix{
-    margin-left: -0.8em;
-}
-
#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
    margin-left: 0em;
@@ -744,7 +765,7 @@ footer {
.dark .gr-compact{
    background-color: rgb(31 41 55 / var(--tw-bg-opacity));
-    margin-left: 0.8em;
+    margin-left: 0;
}

.gr-compact{
@@ -786,7 +807,13 @@ footer {
    margin: 0.3em;
}

+.extra-network-subdirs{
+    padding: 0.2em 0.35em;
+}
+
+.extra-network-subdirs button{
+    margin: 0 0.15em;
+}
+
#txt2img_extra_networks .search, #img2img_extra_networks .search{
    display: inline-block;
@@ -857,6 +884,7 @@ footer {
    white-space: nowrap;
    text-overflow: ellipsis;
    background: rgba(0,0,0,.5);
+    color: white;
}

.extra-network-thumbs .card:hover .actions .name {
@@ -928,3 +956,6 @@ footer {
    color: red;
}
+
+[id*='_prompt_container'] > div {
+    margin: 0!important;
+}
@@ -10,7 +10,7 @@ then
fi

export install_dir="$HOME"
-export COMMANDLINE_ARGS="--skip-torch-cuda-test --no-half --use-cpu interrogate"
+export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
...
@@ -3,17 +3,28 @@
if not defined PYTHON (set PYTHON=python)
if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")

set ERROR_REPORTING=FALSE

mkdir tmp 2>NUL

%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :start_venv
+if %ERRORLEVEL% == 0 goto :check_pip
echo Couldn't launch python
goto :show_stdout_stderr

+:check_pip
+%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
+%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+echo Couldn't install pip
+goto :show_stdout_stderr
+
:start_venv
if ["%VENV_DIR%"] == ["-"] goto :skip_venv
+if ["%SKIP_VENV%"] == ["1"] goto :skip_venv

dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :activate_venv
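For readers who don't speak batch, the new :check_pip block translates roughly to the following Python (a sketch, assuming PIP_INSTALLER_LOCATION points at a downloaded installer script such as get-pip.py):

import os
import subprocess
import sys

def ensure_pip(python=sys.executable):
    # Mirror ":check_pip": succeed immediately if pip already answers.
    if subprocess.run([python, "-m", "pip", "--help"], capture_output=True).returncode == 0:
        return True
    installer = os.environ.get("PIP_INSTALLER_LOCATION", "")
    if not installer:
        return False  # corresponds to "goto :show_stdout_stderr"
    # Run the installer script and report whether it succeeded.
    return subprocess.run([python, installer], capture_output=True).returncode == 0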
@@ -28,13 +39,13 @@ goto :show_stdout_stderr
:activate_venv
set PYTHON="%VENV_DIR%\Scripts\Python.exe"
echo venv %PYTHON%
-if [%ACCELERATE%] == ["True"] goto :accelerate
-goto :launch

:skip_venv
+if [%ACCELERATE%] == ["True"] goto :accelerate
+goto :launch

:accelerate
-echo "Checking for accelerate"
+echo Checking for accelerate
set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
if EXIST %ACCELERATE% goto :accelerate_launch
@@ -44,7 +55,7 @@ pause
exit /b

:accelerate_launch
-echo "Accelerating"
+echo Accelerating
%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
pause
exit /b
...
@@ -12,10 +12,9 @@ from packaging import version
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

-from modules import import_hook, errors, extra_networks
+from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
-from modules.paths import script_path

import torch
@@ -120,6 +119,7 @@ def initialize():
    ui_extra_networks.intialize()
    ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
    ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+    ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())

    extra_networks.initialize()
    extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
@@ -228,6 +228,8 @@ def webui():
        if launch_api:
            create_api(app)

+        ui_extra_networks.add_pages_to_demo(app)
+
        modules.script_callbacks.app_started_callback(shared.demo, app)

        wait_on_server(shared.demo)
@@ -255,6 +257,7 @@ def webui():
        ui_extra_networks.intialize()
        ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
        ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+        ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())

        extra_networks.initialize()
        extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
...