Commit 0bc7867c authored by hako-mikan, committed by GitHub

Merge branch 'AUTOMATIC1111:master' into master

parents 816096e6 cf2772fa
@@ -20,7 +20,7 @@ jobs:
       # not to have GHA download an (at the time of writing) 4 GB cache
       # of PyTorch and other dependencies.
       - name: Install Ruff
-        run: pip install ruff==0.0.272
+        run: pip install ruff==0.1.6
       - name: Run Ruff
         run: ruff .
   lint-js:
@@ -121,7 +121,9 @@ Alternatively, use online services (like Google Colab):
# Debian-based:
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
# Red Hat-based:
-sudo dnf install wget git python3
+sudo dnf install wget git python3 gperftools-libs libglvnd-glx
+# openSUSE-based:
+sudo zypper install wget git python3 libtcmalloc4 libglvnd
# Arch-based:
sudo pacman -S wget git python3
```
@@ -147,7 +149,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
-- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
+- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
- CodeFormer - https://github.com/sczhou/CodeFormer
@@ -174,5 +176,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- LyCORIS - KohakuBlueleaf
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
+- Hypertile - tfernd - https://github.com/tfernd/HyperTile
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
@@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
+
+
+# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
+def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
+    '''
+    Return a tuple of two values that decompose the input dimension, using the number closest to factor.
+    The second value is greater than or equal to the first.
+
+    In LoRA with Kronecker product, the first value is the weight scale and the second value is the weight.
+    Because the Kronecker product is non-commutative, A⊗B ≠ B⊗A, so the meaning of the two matrices is slightly different.
+
+    examples)
+    factor
+        -1               2                4               8               16              ...
+    127 -> 1, 127   127 -> 1, 127    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
+    128 -> 8, 16    128 -> 2, 64     128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
+    250 -> 10, 25   250 -> 2, 125    250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
+    360 -> 8, 45    360 -> 2, 180    360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
+    512 -> 16, 32   512 -> 2, 256    512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
+    1024 -> 32, 32  1024 -> 2, 512   1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
+    '''
+    if factor > 0 and (dimension % factor) == 0:
+        m = factor
+        n = dimension // factor
+        if m > n:
+            n, m = m, n
+        return m, n
+    if factor < 0:
+        factor = dimension
+    m, n = 1, dimension
+    length = m + n
+    while m < n:
+        new_m = m + 1
+        while dimension % new_m != 0:
+            new_m += 1
+        new_n = dimension // new_m
+        if new_m + new_n > length or new_m > factor:
+            break
+        else:
+            m, n = new_m, new_n
+    if m > n:
+        n, m = m, n
+    return m, n
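As an aside, a short usage sketch (not part of the commit) of what `factorization` returns — the values follow the docstring table above, and `network_oft.py` below relies on the same contract (`block_size * num_blocks == out_dim`):

```python
# Sketch: factorization() splits a dimension into two Kronecker factors.
assert factorization(128, factor=4) == (4, 32)   # factor divides the dimension
assert factorization(127) == (1, 127)            # primes only split as 1 x dim
assert factorization(1024) == (32, 32)           # factor=-1 finds the most square split

m, n = factorization(320)
assert m * n == 320                              # the factors always recover the dimension
```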
import torch

import network
from lyco_helpers import factorization
from einops import rearrange


class ModuleTypeOFT(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if "oft_blocks" in weights.w or "oft_diag" in weights.w:
            return NetworkModuleOFT(net, weights)

        return None


# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.lin_module = None
        self.org_module: list[torch.nn.Module] = [self.sd_module]

        self.scale = 1.0

        # kohya-ss
        if "oft_blocks" in weights.w.keys():
            self.is_kohya = True
            self.oft_blocks = weights.w["oft_blocks"]  # (num_blocks, block_size, block_size)
            self.alpha = weights.w["alpha"]  # alpha is constraint
            self.dim = self.oft_blocks.shape[0]  # lora dim
        # LyCORIS
        elif "oft_diag" in weights.w.keys():
            self.is_kohya = False
            self.oft_blocks = weights.w["oft_diag"]
            # self.alpha is unused
            self.dim = self.oft_blocks.shape[1]  # (num_blocks, block_size, block_size)

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]  # unsupported

        if is_linear:
            self.out_dim = self.sd_module.out_features
        elif is_conv:
            self.out_dim = self.sd_module.out_channels
        elif is_other_linear:
            self.out_dim = self.sd_module.embed_dim

        if self.is_kohya:
            self.constraint = self.alpha * self.out_dim
            self.num_blocks = self.dim
            self.block_size = self.out_dim // self.dim
        else:
            self.constraint = None
            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)

    def calc_updown(self, orig_weight):
        oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
        eye = torch.eye(self.block_size, device=oft_blocks.device)

        if self.is_kohya:
            block_Q = oft_blocks - oft_blocks.transpose(1, 2)  # ensure skew-symmetric orthogonal matrix
            norm_Q = torch.norm(block_Q.flatten())
            new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
            block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
            oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())

        R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)

        # This errors out for MultiheadAttention, might need to be handled upstream
        merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
        merged_weight = torch.einsum(
            'k n m, k n ... -> k m ...',
            R,
            merged_weight
        )
        merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')

        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
        output_shape = orig_weight.shape
        return self.finalize_updown(updown, orig_weight, output_shape)
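As a sanity check (a sketch, not part of the commit): the Cayley transform in the kohya branch of `calc_updown` turns a skew-symmetric `Q` into an orthogonal `R`, so each block rotates its slice of the weight without rescaling it:

```python
import torch

# Cayley transform: skew-symmetric Q -> orthogonal R = (I + Q)(I - Q)^-1
Q = torch.randn(4, 4)
Q = Q - Q.T                         # skew-symmetric, as in calc_updown
I = torch.eye(4)
R = (I + Q) @ torch.linalg.inv(I - Q)

assert torch.allclose(R @ R.T, I, atol=1e-4)                 # R is orthogonal
v = torch.randn(4)
assert torch.allclose((R @ v).norm(), v.norm(), atol=1e-4)   # norms are preserved
```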
@@ -11,6 +11,7 @@ import network_ia3
import network_lokr
import network_full
import network_norm
+import network_oft

import torch
from typing import Union

@@ -28,6 +29,7 @@ module_types = [
    network_full.ModuleTypeFull(),
    network_norm.ModuleTypeNorm(),
    network_glora.ModuleTypeGLora(),
+    network_oft.ModuleTypeOFT(),
]
@@ -157,7 +159,8 @@ def load_network(name, network_on_disk):
    bundle_embeddings = {}

    for key_network, weight in sd.items():
-        key_network_without_network_parts, network_part = key_network.split(".", 1)
+        key_network_without_network_parts, _, network_part = key_network.partition(".")

        if key_network_without_network_parts == "bundle_emb":
            emb_name, vec_name = network_part.split(".", 1)
            emb_dict = bundle_embeddings.get(emb_name, {})

@@ -189,6 +192,17 @@ def load_network(name, network_on_disk):
            key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

+        # kohya_ss OFT module
+        elif sd_module is None and "oft_unet" in key_network_without_network_parts:
+            key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
+        # KohakuBlueLeaf OFT module
+        if sd_module is None and "oft_diag" in key:
+            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
+            key = key.replace("lora_te1_text_model", "0_transformer_text_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue
@@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
    def create_item(self, name, index=None, enable_filter=True):
        lora_on_disk = networks.available_networks.get(name)

+        if lora_on_disk is None:
+            return
+
        path, ext = os.path.splitext(lora_on_disk.filename)

@@ -66,9 +68,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
        return item

    def list_items(self):
-        for index, name in enumerate(networks.available_networks):
+        # instantiate a list to protect against concurrent modification
+        names = list(networks.available_networks)
+        for index, name in enumerate(names):
            item = self.create_item(name, index)
            if item is not None:
                yield item
@@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script):
        self.setting_names = []
        self.infotext_fields = []
        extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
+        elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")

        mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}

        with gr.Blocks() as interface:
-            with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
+            with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):

                row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)

@@ -64,11 +65,14 @@ class ExtraOptionsSection(scripts.Script):
                p.override_settings[name] = value

-shared.options_templates.update(shared.options_section(('ui', "User interface"), {
-    "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
-    "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
-    "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
-    "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
+shared.options_templates.update(shared.options_section(('settings_in_ui', "Settings in UI", "ui"), {
+    "settings_in_ui": shared.OptionHTML("""
+This page allows you to add some settings to the main interface of txt2img and img2img tabs.
+"""),
+    "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
+    "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
+    "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(),
+    "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
}))
import hypertile
from modules import scripts, script_callbacks, shared
from scripts.hypertile_xyz import add_axis_options


class ScriptHypertile(scripts.Script):
    name = "Hypertile"

    def title(self):
        return self.name

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def process(self, p, *args):
        hypertile.set_hypertile_seed(p.all_seeds[0])

        configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)

        self.add_infotext(p)

    def before_hr(self, p, *args):
        enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet

        # exclusive hypertile seed for the second pass
        if enable:
            hypertile.set_hypertile_seed(p.all_seeds[0])

        configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable)

        if enable and not shared.opts.hypertile_enable_unet:
            p.extra_generation_params["Hypertile U-Net second pass"] = True

            self.add_infotext(p, add_unet_params=True)

    def add_infotext(self, p, add_unet_params=False):
        def option(name):
            value = getattr(shared.opts, name)
            default_value = shared.opts.get_default(name)
            return None if value == default_value else value

        if shared.opts.hypertile_enable_unet:
            p.extra_generation_params["Hypertile U-Net"] = True

        if shared.opts.hypertile_enable_unet or add_unet_params:
            p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet')
            p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet')
            p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet')

        if shared.opts.hypertile_enable_vae:
            p.extra_generation_params["Hypertile VAE"] = True
            p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae')
            p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae')
            p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae')


def configure_hypertile(width, height, enable_unet=True):
    hypertile.hypertile_hook_model(
        shared.sd_model.first_stage_model,
        width,
        height,
        swap_size=shared.opts.hypertile_swap_size_vae,
        max_depth=shared.opts.hypertile_max_depth_vae,
        tile_size_max=shared.opts.hypertile_max_tile_vae,
        enable=shared.opts.hypertile_enable_vae,
    )

    hypertile.hypertile_hook_model(
        shared.sd_model.model,
        width,
        height,
        swap_size=shared.opts.hypertile_swap_size_unet,
        max_depth=shared.opts.hypertile_max_depth_unet,
        tile_size_max=shared.opts.hypertile_max_tile_unet,
        enable=enable_unet,
        is_sdxl=shared.sd_model.is_sdxl
    )


def on_ui_settings():
    import gradio as gr

    options = {
        "hypertile_explanation": shared.OptionHTML("""
    <a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,
    reducing computation time by a factor of 1 to 4. The larger the generated image, the greater the benefit.
    """),
        "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"),
        "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"),
        "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"),
        "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"),
        "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"),
        "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"),
        "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"),
        "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"),
        "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"),
    }

    for name, opt in options.items():
        opt.section = ('hypertile', "Hypertile")
        shared.opts.add_option(name, opt)


script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_before_ui(add_axis_options)
from modules import scripts
from modules.shared import opts

xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module


def int_applier(value_name: str, min_range: int = -1, max_range: int = -1):
    """
    Returns a function that applies the given value to the given value_name in opts.data.
    """
    def validate(value_name: str, value: str):
        value = int(value)
        # validate value
        if min_range != -1:
            assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
        if max_range != -1:
            assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"

    def apply_int(p, x, xs):
        validate(value_name, x)
        opts.data[value_name] = int(x)

    return apply_int


def bool_applier(value_name: str):
    """
    Returns a function that applies the given value to the given value_name in opts.data.
    """
    def validate(value_name: str, value: str):
        assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false"

    def apply_bool(p, x, xs):
        validate(value_name, x)
        value_boolean = x.lower() == "true"
        opts.data[value_name] = value_boolean

    return apply_bool


def add_axis_options():
    extra_axis_options = [
        xyz_grid.AxisOption("[Hypertile] Unet First pass Enabled", str, bool_applier("hypertile_enable_unet"), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[Hypertile] Unet Second pass Enabled", str, bool_applier("hypertile_enable_unet_secondpass"), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[Hypertile] Unet Max Depth", int, int_applier("hypertile_max_depth_unet", 0, 3), choices=lambda: [str(x) for x in range(4)]),
        xyz_grid.AxisOption("[Hypertile] Unet Max Tile Size", int, int_applier("hypertile_max_tile_unet", 0, 512)),
        xyz_grid.AxisOption("[Hypertile] Unet Swap Size", int, int_applier("hypertile_swap_size_unet", 0, 64)),
        xyz_grid.AxisOption("[Hypertile] VAE Enabled", str, bool_applier("hypertile_enable_vae"), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[Hypertile] VAE Max Depth", int, int_applier("hypertile_max_depth_vae", 0, 3), choices=lambda: [str(x) for x in range(4)]),
        xyz_grid.AxisOption("[Hypertile] VAE Max Tile Size", int, int_applier("hypertile_max_tile_vae", 0, 512)),
        xyz_grid.AxisOption("[Hypertile] VAE Swap Size", int, int_applier("hypertile_swap_size_vae", 0, 64)),
    ]
    set_a = {opt.label for opt in xyz_grid.axis_options}
    set_b = {opt.label for opt in extra_axis_options}
    if set_a.intersection(set_b):
        return

    xyz_grid.axis_options.extend(extra_axis_options)
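For illustration (a sketch, not part of the commit), this is how the appliers behave once built; `opts` is the shared options object imported above:

```python
# Sketch: appliers returned by the factories above, as the XYZ grid calls them.
apply_depth = int_applier("hypertile_max_depth_unet", 0, 3)
apply_depth(None, "2", None)       # validates 0 <= 2 <= 3, then stores int(2)
assert opts.data["hypertile_max_depth_unet"] == 2

apply_enable = bool_applier("hypertile_enable_unet")
apply_enable(None, "True", None)   # accepts "true"/"false", case-insensitive
assert opts.data["hypertile_enable_unet"] is True
```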
@@ -130,6 +130,10 @@ function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePromp
    } else {
        promptContainer.insertBefore(prompt, promptContainer.firstChild);
    }
+
+    if (elem) {
+        elem.classList.toggle('extra-page-prompts-active', showNegativePrompt || showPrompt);
+    }
}

@@ -388,3 +392,9 @@ function extraNetworksRefreshSingleCard(page, tabname, name) {
        }
    });
}
+
+window.addEventListener("keydown", function(event) {
+    if (event.key == "Escape") {
+        closePopup();
+    }
+});
@@ -34,7 +34,7 @@ function updateOnBackgroundChange() {
    if (modalImage && modalImage.offsetParent) {
        let currentButton = selected_gallery_button();
        let preview = gradioApp().querySelectorAll('.livePreview > img');
-        if (preview.length > 0) {
+        if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {
            // show preview image if available
            modalImage.src = preview[preview.length - 1].src;
        } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {

@@ -44,3 +44,28 @@ onUiLoaded(function() {
    buttonShowAllPages.addEventListener("click", settingsShowAllTabs);
});
+
+onOptionsChanged(function() {
+    if (gradioApp().querySelector('#settings .settings-category')) return;
+
+    var sectionMap = {};
+    gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) {
+        sectionMap[x.textContent.trim()] = x;
+    });
+
+    opts._categories.forEach(function(x) {
+        var section = x[0];
+        var category = x[1];
+
+        var span = document.createElement('SPAN');
+        span.textContent = category;
+        span.className = 'settings-category';
+
+        var sectionElem = sectionMap[section];
+        if (!sectionElem) return;
+
+        sectionElem.parentElement.insertBefore(span, sectionElem);
+    });
+});
@@ -170,6 +170,23 @@ function submit_img2img() {
    return res;
}

+function submit_extras() {
+    showSubmitButtons('extras', false);
+
+    var id = randomId();
+
+    requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() {
+        showSubmitButtons('extras', true);
+    });
+
+    var res = create_submit_args(arguments);
+
+    res[0] = id;
+
+    return res;
+}
function restoreProgressTxt2img() {
    showRestoreProgressButton("txt2img", false);
    var id = localGet("txt2img_task_id");

@@ -198,9 +215,33 @@ function restoreProgressImg2img() {
}

+/**
+ * Configure the width and height elements on `tabname` to accept
+ * pasting of resolutions in the form of "width x height".
+ */
+function setupResolutionPasting(tabname) {
+    var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`);
+    var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`);
+    for (const el of [width, height]) {
+        el.addEventListener('paste', function(event) {
+            var pasteData = event.clipboardData.getData('text/plain');
+            var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/);
+            if (parsed) {
+                width.value = parsed[1];
+                height.value = parsed[2];
+                updateInput(width);
+                updateInput(height);
+                event.preventDefault();
+            }
+        });
+    }
+}
+
onUiLoaded(function() {
    showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
    showRestoreProgressButton('img2img', localGet("img2img_task_id"));
+    setupResolutionPasting('txt2img');
+    setupResolutionPasting('img2img');
});
@@ -22,7 +22,6 @@ from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
-from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin, Image
from modules.sd_models_config import find_checkpoint_config_near_filename

@@ -235,7 +234,6 @@ class Api:
        self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
-        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
@@ -675,19 +673,6 @@ class Api:
        finally:
            shared.state.end()

-    def preprocess(self, args: dict):
-        try:
-            shared.state.begin(job="preprocess")
-            preprocess(**args)  # quick operation unless blip/booru interrogation is enabled
-            shared.state.end()
-            return models.PreprocessResponse(info='preprocess complete')
-        except KeyError as e:
-            return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
-        except Exception as e:
-            return models.PreprocessResponse(info=f"preprocess error: {e}")
-        finally:
-            shared.state.end()
-
    def train_embedding(self, args: dict):
        try:
            shared.state.begin(job="train_embedding")
@@ -202,9 +202,6 @@ class TrainResponse(BaseModel):
class CreateResponse(BaseModel):
    info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")

-class PreprocessResponse(BaseModel):
-    info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
-
fields = {}
for key, metadata in opts.data_labels.items():
    value = opts.data.get(key)
@@ -32,7 +32,7 @@ def dump_cache():
    with cache_lock:
        cache_filename_tmp = cache_filename + "-"
        with open(cache_filename_tmp, "w", encoding="utf8") as file:
-            json.dump(cache_data, file, indent=4)
+            json.dump(cache_data, file, indent=4, ensure_ascii=False)

        os.replace(cache_filename_tmp, cache_filename)
@@ -70,6 +70,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
@@ -8,6 +8,13 @@ from modules import errors, shared
if sys.platform == "darwin":
    from modules import mac_specific

+if shared.cmd_opts.use_ipex:
+    from modules import xpu_specific
+
+
+def has_xpu() -> bool:
+    return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
+

def has_mps() -> bool:
    if sys.platform != "darwin":

@@ -30,6 +37,9 @@ def get_optimal_device_name():
    if has_mps():
        return "mps"

+    if has_xpu():
+        return xpu_specific.get_xpu_device_string()
+
    return "cpu"

@@ -38,7 +48,7 @@ def get_optimal_device():
def get_device_for(task):
-    if task in shared.cmd_opts.use_cpu:
+    if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
        return cpu

    return get_optimal_device()

@@ -54,6 +64,9 @@ def torch_gc():
    if has_mps():
        mac_specific.torch_mps_gc()

+    if has_xpu():
+        xpu_specific.torch_xpu_gc()
+

def enable_tf32():
    if torch.cuda.is_available():
@@ -6,6 +6,21 @@ import traceback
exception_records = []


+def format_traceback(tb):
+    return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
+
+
+def format_exception(e, tb):
+    return {"exception": str(e), "traceback": format_traceback(tb)}
+
+
+def get_exceptions():
+    try:
+        return list(reversed(exception_records))
+    except Exception as e:
+        return str(e)
+
+
def record_exception():
    _, e, tb = sys.exc_info()
    if e is None:

@@ -14,8 +29,7 @@ def record_exception():
    if exception_records and exception_records[-1] == e:
        return

-    from modules import sysinfo
-    exception_records.append(sysinfo.format_exception(e, tb))
+    exception_records.append(format_exception(e, tb))

    if len(exception_records) > 5:
        exception_records.pop(0)
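A small sketch (not part of the commit) of the record shape these helpers produce:

```python
# Sketch: what ends up in exception_records after record_exception().
try:
    raise ValueError("bad value")
except ValueError:
    record_exception()

latest = get_exceptions()[0]             # most recent first
assert latest["exception"] == "bad value"
# latest["traceback"] is a list of ["<file>, line <no>, <func>", "<source line>"] pairs
```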
+from __future__ import annotations
+
+import configparser
import os
import threading
+import re

from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path  # noqa: F401

-extensions = []

os.makedirs(extensions_dir, exist_ok=True)

@@ -19,11 +22,55 @@ def active():
    return [x for x in extensions if x.enabled]


+class ExtensionMetadata:
+    filename = "metadata.ini"
+
+    config: configparser.ConfigParser
+    canonical_name: str
+    requires: list
+
+    def __init__(self, path, canonical_name):
+        self.config = configparser.ConfigParser()
+
+        filepath = os.path.join(path, self.filename)
+        if os.path.isfile(filepath):
+            try:
+                self.config.read(filepath)
+            except Exception:
+                errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
+
+        self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
+        self.canonical_name = self.canonical_name.lower().strip()
+
+        self.requires = self.get_script_requirements("Requires", "Extension")
+
+    def get_script_requirements(self, field, section, extra_section=None):
+        """reads a list of requirements from the config; field is the name of the field in the ini file,
+        like Requires or Before, and section is the name of the [section] in the ini file; additionally,
+        reads more requirements from [extra_section] if specified."""
+
+        x = self.config.get(section, field, fallback='')
+
+        if extra_section:
+            x = x + ', ' + self.config.get(extra_section, field, fallback='')
+
+        return self.parse_list(x.lower())
+
+    def parse_list(self, text):
+        """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
+
+        if not text:
+            return []
+
+        # both "," and " " are accepted as separator
+        return [x for x in re.split(r"[,\s]+", text.strip()) if x]
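For context (a sketch, not part of the commit), here is a hypothetical `metadata.ini` an extension might ship — the section and field names are the ones the code reads; the extension names are invented:

```python
# Hypothetical metadata.ini contents:
#
#   [Extension]
#   Name = my-extension
#   Requires = sd-webui-controlnet, some-other-extension
#
meta = ExtensionMetadata("/path/to/my-extension", "my-extension")
assert meta.canonical_name == "my-extension"
assert meta.requires == ["sd-webui-controlnet", "some-other-extension"]
```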
class Extension:
    lock = threading.Lock()
    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
+    metadata: ExtensionMetadata

-    def __init__(self, name, path, enabled=True, is_builtin=False):
+    def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
        self.name = name
        self.path = path
        self.enabled = enabled

@@ -36,6 +83,8 @@ class Extension:
        self.branch = None
        self.remote = None
        self.have_info_from_repo = False
+        self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
+        self.canonical_name = self.metadata.canonical_name

    def to_dict(self):
        return {x: getattr(self, x) for x in self.cached_fields}

@@ -56,6 +105,7 @@ class Extension:
            self.do_read_info_from_repo()

            return self.to_dict()
+
        try:
            d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
            self.from_dict(d)

@@ -136,9 +186,6 @@ class Extension:
def list_extensions():
    extensions.clear()

-    if not os.path.isdir(extensions_dir):
-        return
-
    if shared.cmd_opts.disable_all_extensions:
        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
    elif shared.opts.disable_all_extensions == "all":
@@ -148,18 +195,43 @@ def list_extensions():
    elif shared.opts.disable_all_extensions == "extra":
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")

-    extension_paths = []
-    for dirname in [extensions_dir, extensions_builtin_dir]:
+    loaded_extensions = {}
+
+    # scan through extensions directory and load metadata
+    for dirname in [extensions_builtin_dir, extensions_dir]:
        if not os.path.isdir(dirname):
-            return
+            continue

        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
                continue

-            extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
+            canonical_name = extension_dirname
+            metadata = ExtensionMetadata(path, canonical_name)
+
+            # check for duplicated canonical names
+            already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
+            if already_loaded_extension is not None:
+                errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
+                continue
+
+            is_builtin = dirname == extensions_builtin_dir
+            extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
+            extensions.append(extension)
+            loaded_extensions[canonical_name] = extension
+
+    # check for requirements
+    for extension in extensions:
+        for req in extension.metadata.requires:
+            required_extension = loaded_extensions.get(req)
+            if required_extension is None:
+                errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
+                continue
+
+            if not extension.enabled:
+                errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
+                continue

-    for dirname, path, is_builtin in extension_paths:
-        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
-        extensions.append(extension)
+
+extensions: list[Extension] = []
+from __future__ import annotations
import base64
import io
import json

@@ -15,9 +16,6 @@ re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update())

-paste_fields = {}
-registered_param_bindings = []
-

class ParamBinding:
    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):

@@ -30,6 +28,10 @@ class ParamBinding:
        self.paste_field_names = paste_field_names or []


+paste_fields: dict[str, dict] = {}
+registered_param_bindings: list[ParamBinding] = []
+
+
def reset():
    paste_fields.clear()
    registered_param_bindings.clear()

@@ -113,7 +115,6 @@ def register_paste_params_button(binding: ParamBinding):
def connect_paste_params_buttons():
-    binding: ParamBinding
    for binding in registered_param_bindings:
        destination_image_component = paste_fields[binding.tabname]["init_img"]
        fields = paste_fields[binding.tabname]["fields"]

@@ -313,6 +314,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
    if "VAE Decoder" not in res:
        res["VAE Decoder"] = "Full"

+    skip = set(shared.opts.infotext_skip_pasting)
+    res = {k: v for k, v in res.items() if k not in skip}
+
    return res

@@ -443,3 +447,4 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
        outputs=[],
        show_progress=False,
    )
@@ -47,10 +47,20 @@ def Block_get_config(self):

def BlockContext_init(self, *args, **kwargs):
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.before_component(self, **kwargs)
+
+    scripts.script_callbacks.before_component_callback(self, **kwargs)
+
    res = original_BlockContext_init(self, *args, **kwargs)

    add_classes_to_gradio_component(self)

+    scripts.script_callbacks.after_component_callback(self, **kwargs)
+
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.after_component(self, **kwargs)
+
    return res
@@ -44,6 +44,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
    steps = p.steps
    override_settings = p.override_settings
    sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
+    batch_results = None
+    discard_further_results = False
    for i, image in enumerate(images):
        state.job = f"{i+1} out of {len(images)}"
        if state.skipped:

@@ -127,7 +129,21 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
        if proc is None:
            p.override_settings.pop('save_images_replace_action', None)
-            process_images(p)
+            proc = process_images(p)
+
+        if not discard_further_results and proc:
+            if batch_results:
+                batch_results.images.extend(proc.images)
+                batch_results.infotexts.extend(proc.infotexts)
+            else:
+                batch_results = proc
+
+            if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
+                discard_further_results = True
+                batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
+                batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
+
+    return batch_results
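One subtlety in the chained comparison above (a sketch, not part of the commit): `img2img_batch_show_results_limit` set to -1 disables truncation entirely, while 0 shows no results at all:

```python
def should_truncate(limit: int, shown: int) -> bool:
    # mirrors: 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images)
    return 0 <= limit < shown

assert should_truncate(-1, 100) is False   # -1: never truncate, show everything
assert should_truncate(0, 100) is True     # 0: truncate to nothing
assert should_truncate(32, 10) is False    # under the limit: keep all results
```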
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):

@@ -212,10 +228,10 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
    with closing(p):
        if is_batch:
            assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"

-            process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
-            processed = Processed(p, [], p.seed, "")
+            processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
+
+            if processed is None:
+                processed = Processed(p, [], p.seed, "")
        else:
            processed = modules.scripts.scripts_img2img.run(p, *args)
            if processed is None:
@@ -3,3 +3,14 @@ import sys

# this will break any attempt to import xformers which will prevent stable diffusion repo from trying to use it
if "--xformers" not in "".join(sys.argv):
    sys.modules["xformers"] = None
+
+# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
+# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
+try:
+    import torchvision.transforms.functional_tensor  # noqa: F401
+except ImportError:
+    try:
+        import torchvision.transforms.functional as functional
+        sys.modules["torchvision.transforms.functional_tensor"] = functional
+    except ImportError:
+        pass  # shrug...
@@ -6,6 +6,7 @@ import os
import shutil
import sys
import importlib.util
+import importlib.metadata
import platform
import json
from functools import lru_cache

@@ -119,11 +120,16 @@ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_
def is_installed(package):
    try:
-        spec = importlib.util.find_spec(package)
-    except ModuleNotFoundError:
-        return False
-
-    return spec is not None
+        dist = importlib.metadata.distribution(package)
+    except importlib.metadata.PackageNotFoundError:
+        try:
+            spec = importlib.util.find_spec(package)
+        except ModuleNotFoundError:
+            return False
+
+        return spec is not None
+
+    return dist is not None
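A hedged usage sketch (not part of the commit, and assuming pillow is present, as the webui requirements install it): the switch to `importlib.metadata` makes `is_installed` match pip distribution names first, with importable module names as a fallback:

```python
# Distribution name (what pip installs), even when the module is named differently:
assert is_installed("pillow")   # importable as PIL, found via importlib.metadata

# Stdlib modules have no distribution; the find_spec fallback still finds them:
assert is_installed("os")

assert not is_installed("definitely-not-a-real-package")
```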
def repo_dir(name):

@@ -310,6 +316,26 @@ def requirements_met(requirements_file):
def prepare_environment():
    torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+    if args.use_ipex:
+        if platform.system() == "Windows":
+            # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
+            # This is NOT an Intel official release so please use it at your own risk!!
+            # See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
+            #
+            # Strengths (over official IPEX 2.0.110 windows release):
+            #   - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399
+            #   - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.
+            #   - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465
+            # Limitation:
+            #   - Only works for python 3.10
+            url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle"
+            torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl")
+        else:
+            # Using official IPEX release for linux since it's already an AOT build.
+            # However, users still have to install oneAPI toolkit and activate oneAPI environment manually.
+            # See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
+            torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
+            torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")

    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')

@@ -352,6 +378,8 @@ def prepare_environment():
        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
        startup_timer.record("install torch")

+    if args.use_ipex:
+        args.skip_torch_cuda_test = True
    if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
        raise RuntimeError(
            'Torch is not able to use GPU; '

@@ -441,7 +469,7 @@ def dump_sysinfo():
    import datetime

    text = sysinfo.get()
-    filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
+    filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"

    with open(filename, "w", encoding="utf8") as file:
        file.write(text)
import os import os
import logging import logging
try:
from tqdm.auto import tqdm
class TqdmLoggingHandler(logging.Handler):
def __init__(self, level=logging.INFO):
super().__init__(level)
def emit(self, record):
try:
msg = self.format(record)
tqdm.write(msg)
self.flush()
except Exception:
self.handleError(record)
TQDM_IMPORTED = True
except ImportError:
# tqdm does not exist before first launch
# I will import once the UI finishes seting up the enviroment and reloads.
TQDM_IMPORTED = False
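The handler above mirrors log records through `tqdm.write()`, which prints above any active progress bar instead of tearing it. A short usage sketch (assumes tqdm is installed and the `TqdmLoggingHandler` class just defined):
```
# Usage sketch: log while a tqdm bar is active without corrupting the bar.
# Assumes tqdm is installed and TqdmLoggingHandler is defined as above.
import logging
import time
from tqdm.auto import tqdm

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(TqdmLoggingHandler())

for i in tqdm(range(3), desc="working"):
    logger.info("step %d", i)  # emitted via tqdm.write(), above the bar
    time.sleep(0.1)
```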
def setup_logging(loglevel): def setup_logging(loglevel):
if loglevel is None: if loglevel is None:
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL") loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
loghandlers = []
if TQDM_IMPORTED:
loghandlers.append(TqdmLoggingHandler())
if loglevel: if loglevel:
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
logging.basicConfig( logging.basicConfig(
level=log_level, level=log_level,
format='%(asctime)s %(levelname)s [%(name)s] %(message)s', format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S', datefmt='%Y-%m-%d %H:%M:%S',
handlers=loghandlers
) )
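Hedged examples of how `setup_logging` resolves its level (the environment variable name comes from the code above; note that `logging.basicConfig()` only configures the root logger once):
```
# Assumes setup_logging as defined above.
import logging

setup_logging("DEBUG")  # configures the root logger at DEBUG
setup_logging(None)     # would consult SD_WEBUI_LOG_LEVEL; a no-op here, since
                        # logging.basicConfig() doesn't reconfigure the root
logging.getLogger(__name__).debug("visible, because the first call set DEBUG")
```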
import logging import logging
import torch import torch
from torch import Tensor
import platform import platform
from modules.sd_hijack_utils import CondFunc from modules.sd_hijack_utils import CondFunc
from packaging import version from packaging import version
...@@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): ...@@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
return cumsum_func(input, *args, **kwargs) return cumsum_func(input, *args, **kwargs)
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
try:
return orig_func(*args, **kwargs)
except RuntimeError as e:
if "not implemented for" in str(e) and "Half" in str(e):
input_tensor = args[0]
return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
else:
print(f"An unexpected RuntimeError occurred: {str(e)}")
raise  # re-raise unrelated errors rather than silently returning None
if has_mps: if has_mps:
if platform.mac_ver()[0].startswith("13.2."): if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124) # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
...@@ -77,6 +89,9 @@ if has_mps: ...@@ -77,6 +89,9 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113 # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps') CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311 # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386': if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']: for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
......
...@@ -24,10 +24,15 @@ from pytorch_lightning.utilities.distributed import rank_zero_only ...@@ -24,10 +24,15 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
try:
from ldm.models.autoencoder import VQModelInterface
except Exception:
class VQModelInterface:
pass
__conditioning_keys__ = {'concat': 'c_concat', __conditioning_keys__ = {'concat': 'c_concat',
'crossattn': 'c_crossattn', 'crossattn': 'c_crossattn',
......
import json import json
import sys import sys
from dataclasses import dataclass
import gradio as gr import gradio as gr
...@@ -8,13 +9,14 @@ from modules.shared_cmd_options import cmd_opts ...@@ -8,13 +9,14 @@ from modules.shared_cmd_options import cmd_opts
class OptionInfo: class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False): def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None):
self.default = default self.default = default
self.label = label self.label = label
self.component = component self.component = component
self.component_args = component_args self.component_args = component_args
self.onchange = onchange self.onchange = onchange
self.section = section self.section = section
self.category_id = category_id
self.refresh = refresh self.refresh = refresh
self.do_not_save = False self.do_not_save = False
...@@ -63,7 +65,11 @@ class OptionHTML(OptionInfo): ...@@ -63,7 +65,11 @@ class OptionHTML(OptionInfo):
def options_section(section_identifier, options_dict): def options_section(section_identifier, options_dict):
for v in options_dict.values(): for v in options_dict.values():
v.section = section_identifier if len(section_identifier) == 2:
v.section = section_identifier
elif len(section_identifier) == 3:
v.section = section_identifier[0:2]
v.category_id = section_identifier[2]
return options_dict return options_dict
...@@ -76,7 +82,7 @@ class Options: ...@@ -76,7 +82,7 @@ class Options:
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts): def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
self.data_labels = data_labels self.data_labels = data_labels
self.data = {k: v.default for k, v in self.data_labels.items()} self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save}
self.restricted_opts = restricted_opts self.restricted_opts = restricted_opts
def __setattr__(self, key, value): def __setattr__(self, key, value):
...@@ -158,7 +164,7 @@ class Options: ...@@ -158,7 +164,7 @@ class Options:
assert not cmd_opts.freeze_settings, "saving settings is disabled" assert not cmd_opts.freeze_settings, "saving settings is disabled"
with open(filename, "w", encoding="utf8") as file: with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file, indent=4) json.dump(self.data, file, indent=4, ensure_ascii=False)
def same_type(self, x, y): def same_type(self, x, y):
if x is None or y is None: if x is None or y is None:
...@@ -206,23 +212,59 @@ class Options: ...@@ -206,23 +212,59 @@ class Options:
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()} d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None} d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None} d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
item_categories = {}
for item in self.data_labels.values():
category = categories.mapping.get(item.category_id)
category = "Uncategorized" if category is None else category.label
if category not in item_categories:
item_categories[category] = item.section[1]
# _categories is a list of pairs: [section, category]. Each section (a setting page) will get a special heading above it with the category as text.
d["_categories"] = [[v, k] for k, v in item_categories.items()] + [["Defaults", "Other"]]
return json.dumps(d) return json.dumps(d)
def add_option(self, key, info): def add_option(self, key, info):
self.data_labels[key] = info self.data_labels[key] = info
if key not in self.data: if key not in self.data and not info.do_not_save:
self.data[key] = info.default self.data[key] = info.default
def reorder(self): def reorder(self):
"""reorder settings so that all items related to section always go together""" """Reorder settings so that:
- all items related to section always go together
- all sections belonging to a category go together
- sections inside a category are ordered alphabetically
- categories are ordered by creation order
Category is a superset of sections: for category "postprocessing" there could be multiple sections: "face restoration", "upscaling".
This function also changes items' category_id so that all items belonging to a section have the same category_id.
"""
category_ids = {}
section_categories = {}
section_ids = {}
settings_items = self.data_labels.items() settings_items = self.data_labels.items()
for _, item in settings_items: for _, item in settings_items:
if item.section not in section_ids: if item.section not in section_categories:
section_ids[item.section] = len(section_ids) section_categories[item.section] = item.category_id
for _, item in settings_items:
item.category_id = section_categories.get(item.section)
for category_id in categories.mapping:
if category_id not in category_ids:
category_ids[category_id] = len(category_ids)
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section])) def sort_key(x):
item: OptionInfo = x[1]
category_order = category_ids.get(item.category_id, len(category_ids))
section_order = item.section[1]
return category_order, section_order
self.data_labels = dict(sorted(settings_items, key=sort_key))
def cast_value(self, key, value): def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key """casts an arbitrary to the same type as this setting's value with key
...@@ -245,3 +287,22 @@ class Options: ...@@ -245,3 +287,22 @@ class Options:
value = expected_type(value) value = expected_type(value)
return value return value
@dataclass
class OptionsCategory:
id: str
label: str
class OptionsCategories:
def __init__(self):
self.mapping = {}
def register_category(self, category_id, label):
if category_id in self.mapping:
return category_id
self.mapping[category_id] = OptionsCategory(category_id, label)
categories = OptionsCategories()
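A hedged usage sketch for the category mechanism (the option key and labels are hypothetical): register a category, then attach a section to it via the new 3-tuple form of `options_section`.
```
# Assumes categories, options_section and OptionInfo as defined above.
categories.register_category("postprocessing", "Postprocessing")

opts_dict = options_section(
    ("upscaling", "Upscaling", "postprocessing"),
    {"upscaler_visibility": OptionInfo(1.0, "Upscaler visibility")},
)
print(opts_dict["upscaler_visibility"].section)      # ('upscaling', 'Upscaling')
print(opts_dict["upscaler_visibility"].category_id)  # 'postprocessing'
```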
...@@ -29,11 +29,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, ...@@ -29,11 +29,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
image_list = shared.listfiles(input_dir) image_list = shared.listfiles(input_dir)
for filename in image_list: for filename in image_list:
try: yield filename, filename
image = Image.open(filename)
except Exception:
continue
yield image, filename
else: else:
assert image, 'image not selected' assert image, 'image not selected'
yield image, None yield image, None
...@@ -45,43 +41,97 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, ...@@ -45,43 +41,97 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
infotext = '' infotext = ''
for image_data, name in get_images(extras_mode, image, image_folder, input_dir): data_to_process = list(get_images(extras_mode, image, image_folder, input_dir))
shared.state.job_count = len(data_to_process)
for image_placeholder, name in data_to_process:
image_data: Image.Image image_data: Image.Image
shared.state.nextjob()
shared.state.textinfo = name shared.state.textinfo = name
shared.state.skipped = False
if shared.state.interrupted:
break
if isinstance(image_placeholder, str):
try:
image_data = Image.open(image_placeholder)
except Exception:
continue
else:
image_data = image_placeholder
shared.state.assign_current_image(image_data)
parameters, existing_pnginfo = images.read_info_from_image(image_data) parameters, existing_pnginfo = images.read_info_from_image(image_data)
if parameters: if parameters:
existing_pnginfo["parameters"] = parameters existing_pnginfo["parameters"] = parameters
pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB")) initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
scripts.scripts_postproc.run(pp, args) scripts.scripts_postproc.run(initial_pp, args)
if opts.use_original_name_batch and name is not None: if shared.state.skipped:
basename = os.path.splitext(os.path.basename(name))[0] continue
else:
basename = '' used_suffixes = {}
for pp in [initial_pp, *initial_pp.extra_images]:
suffix = pp.get_suffix(used_suffixes)
infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) if opts.use_original_name_batch and name is not None:
basename = os.path.splitext(os.path.basename(name))[0]
forced_filename = basename + suffix
else:
basename = ''
forced_filename = None
if opts.enable_pnginfo: infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
pp.image.info = existing_pnginfo
pp.image.info["postprocessing"] = infotext
if save_output: if opts.enable_pnginfo:
images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) pp.image.info = existing_pnginfo
pp.image.info["postprocessing"] = infotext
if extras_mode != 2 or show_extras_results: if save_output:
outputs.append(pp.image) fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
if pp.caption:
caption_filename = os.path.splitext(fullfn)[0] + ".txt"
if os.path.isfile(caption_filename):
with open(caption_filename, encoding="utf8") as file:
existing_caption = file.read().strip()
else:
existing_caption = ""
action = shared.opts.postprocessing_existing_caption_action
if action == 'Prepend' and existing_caption:
caption = f"{existing_caption} {pp.caption}"
elif action == 'Append' and existing_caption:
caption = f"{pp.caption} {existing_caption}"
elif action == 'Keep' and existing_caption:
caption = existing_caption
else:
caption = pp.caption
caption = caption.strip()
if caption:
with open(caption_filename, "w", encoding="utf8") as file:
file.write(caption)
if extras_mode != 2 or show_extras_results:
outputs.append(pp.image)
image_data.close() image_data.close()
devices.torch_gc() devices.torch_gc()
shared.state.end()
return outputs, ui_common.plaintext_to_html(infotext), '' return outputs, ui_common.plaintext_to_html(infotext), ''
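The caption branch above reduces to a small set of merge rules; a hedged sketch expressing them as a pure function (`merge_caption` is a hypothetical helper, not part of this commit):
```
# Mirrors the Prepend/Append/Keep handling of an existing caption file.
def merge_caption(existing_caption, new_caption, action):
    if action == 'Prepend' and existing_caption:
        return f"{existing_caption} {new_caption}".strip()
    if action == 'Append' and existing_caption:
        return f"{new_caption} {existing_caption}".strip()
    if action == 'Keep' and existing_caption:
        return existing_caption.strip()
    return new_caption.strip()

print(merge_caption("old", "new", "Prepend"))  # old new
print(merge_caption("", "new", "Keep"))        # new
```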
def run_postprocessing_webui(id_task, *args, **kwargs):
return run_postprocessing(*args, **kwargs)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
"""old handler for API""" """old handler for API"""
...@@ -97,9 +147,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -97,9 +147,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
"upscaler_2_visibility": extras_upscaler_2_visibility, "upscaler_2_visibility": extras_upscaler_2_visibility,
}, },
"GFPGAN": { "GFPGAN": {
"enable": True,
"gfpgan_visibility": gfpgan_visibility, "gfpgan_visibility": gfpgan_visibility,
}, },
"CodeFormer": { "CodeFormer": {
"enable": True,
"codeformer_visibility": codeformer_visibility, "codeformer_visibility": codeformer_visibility,
"codeformer_weight": codeformer_weight, "codeformer_weight": codeformer_weight,
}, },
......
...@@ -679,8 +679,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter ...@@ -679,8 +679,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
"Model": p.sd_model_name if opts.add_model_name_to_info else None, "Model": p.sd_model_name if opts.add_model_name_to_info else None,
"VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None, "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
"VAE": p.sd_vae_name if opts.add_model_name_to_info else None, "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
...@@ -799,7 +799,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -799,7 +799,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = [] infotexts = []
output_images = [] output_images = []
with torch.no_grad(), p.sd_model.ema_scope(): with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast(): with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds) p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
...@@ -873,7 +872,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -873,7 +872,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else: else:
if opts.sd_vae_decode_method != 'Full': if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.stack(x_samples_ddim).float()
...@@ -940,21 +938,20 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -940,21 +938,20 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.enable_pnginfo: if opts.enable_pnginfo:
image.info["parameters"] = text image.info["parameters"] = text
output_images.append(image) output_images.append(image)
if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
image_mask = p.mask_for_overlay.convert('RGB') if opts.return_mask or opts.save_mask:
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') image_mask = p.mask_for_overlay.convert('RGB')
if save_samples and opts.save_mask:
if opts.save_mask: images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") if opts.return_mask:
output_images.append(image_mask)
if opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") if opts.return_mask_composite or opts.save_mask_composite:
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if opts.return_mask: if save_samples and opts.save_mask_composite:
output_images.append(image_mask) images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask_composite:
if opts.return_mask_composite: output_images.append(image_mask_composite)
output_images.append(image_mask_composite)
del x_samples_ddim del x_samples_ddim
...@@ -1147,6 +1144,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1147,6 +1144,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not self.enable_hr: if not self.enable_hr:
return samples return samples
devices.torch_gc()
if self.latent_scale_mode is None: if self.latent_scale_mode is None:
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
...@@ -1156,8 +1154,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1156,8 +1154,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with sd_models.SkipWritingToConfig(): with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=self.hr_checkpoint_info) sd_models.reload_model_weights(info=self.hr_checkpoint_info)
devices.torch_gc()
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
...@@ -1165,7 +1161,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1165,7 +1161,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples return samples
self.is_hr_pass = True self.is_hr_pass = True
target_width = self.hr_upscale_to_x target_width = self.hr_upscale_to_x
target_height = self.hr_upscale_to_y target_height = self.hr_upscale_to_y
...@@ -1254,7 +1249,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1254,7 +1249,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
self.is_hr_pass = False self.is_hr_pass = False
return decoded_samples return decoded_samples
def close(self): def close(self):
......
...@@ -110,7 +110,7 @@ class ImageRNG: ...@@ -110,7 +110,7 @@ class ImageRNG:
self.is_first = True self.is_first = True
def first(self): def first(self):
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8) noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))
xs = [] xs = []
......
...@@ -311,20 +311,113 @@ scripts_data = [] ...@@ -311,20 +311,113 @@ scripts_data = []
postprocessing_scripts_data = [] postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def topological_sort(dependencies):
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
Ignores errors relating to missing or circular dependencies.
"""
visited = {}
result = []
def inner(name):
visited[name] = True
for dep in dependencies.get(name, []):
if dep in dependencies and dep not in visited:
inner(dep)
result.append(name)
for depname in dependencies:
if depname not in visited:
inner(depname)
return result
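A quick check of `topological_sort` (the values are the names each key must load *after*, so dependencies come first in the result):
```
# Assumes topological_sort as defined above.
order = topological_sort({
    "c.py": ["b.py"],
    "b.py": ["a.py"],
    "a.py": [],
    "d.py": ["missing.py"],  # unknown deps are ignored, per the docstring
})
print(order)  # ['a.py', 'b.py', 'c.py', 'd.py']
```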
@dataclass
class ScriptWithDependencies:
script_canonical_name: str
file: ScriptFile
requires: list
load_before: list
load_after: list
def list_scripts(scriptdirname, extension, *, include_extensions=True): def list_scripts(scriptdirname, extension, *, include_extensions=True):
scripts_list = [] scripts = {}
loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}
loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}
basedir = os.path.join(paths.script_path, scriptdirname) # build script dependency map
if os.path.exists(basedir): root_script_basedir = os.path.join(paths.script_path, scriptdirname)
for filename in sorted(os.listdir(basedir)): if os.path.exists(root_script_basedir):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) for filename in sorted(os.listdir(root_script_basedir)):
if not os.path.isfile(os.path.join(root_script_basedir, filename)):
continue
if os.path.splitext(filename)[1].lower() != extension:
continue
script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])
if include_extensions: if include_extensions:
for ext in extensions.active(): for ext in extensions.active():
scripts_list += ext.list_files(scriptdirname, extension) extension_scripts_list = ext.list_files(scriptdirname, extension)
for extension_script in extension_scripts_list:
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] if not os.path.isfile(extension_script.path):
continue
script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
relative_path = scriptdirname + "/" + extension_script.filename
script = ScriptWithDependencies(
script_canonical_name=script_canonical_name,
file=extension_script,
requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname),
load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname),
)
scripts[script_canonical_name] = script
loaded_extensions_scripts[ext.canonical_name].append(script)
for script_canonical_name, script in scripts.items():
# load before requires inverse dependency
# in this case, append the script name into the load_after list of the specified script
for load_before in script.load_before:
# if this requires an individual script to be loaded before
other_script = scripts.get(load_before)
if other_script:
other_script.load_after.append(script_canonical_name)
# if this requires an extension
other_extension_scripts = loaded_extensions_scripts.get(load_before)
if other_extension_scripts:
for other_script in other_extension_scripts:
other_script.load_after.append(script_canonical_name)
# if After mentions an extension, remove it and instead add all of its scripts
for load_after in list(script.load_after):
if load_after not in scripts and load_after in loaded_extensions_scripts:
script.load_after.remove(load_after)
for other_script in loaded_extensions_scripts.get(load_after, []):
script.load_after.append(other_script.script_canonical_name)
dependencies = {}
for script_canonical_name, script in scripts.items():
for required_script in script.requires:
if required_script not in scripts and required_script not in loaded_extensions:
errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False)
dependencies[script_canonical_name] = script.load_after
ordered_scripts = topological_sort(dependencies)
scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]
return scripts_list return scripts_list
...@@ -365,15 +458,9 @@ def load_scripts(): ...@@ -365,15 +458,9 @@ def load_scripts():
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
def orderby(basedir): # here the scripts_list is already ordered
# 1st webui, 2nd extensions-builtin, 3rd extensions # processing_script is not considered though
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0} for scriptfile in scripts_list:
for key in priority:
if basedir.startswith(key):
return priority[key]
return 9999
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
try: try:
if scriptfile.basedir != paths.script_path: if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path sys.path = [scriptfile.basedir] + sys.path
...@@ -473,17 +560,25 @@ class ScriptRunner: ...@@ -473,17 +560,25 @@ class ScriptRunner:
on_after.clear() on_after.clear()
def create_script_ui(self, script): def create_script_ui(self, script):
import modules.api.models as api_models
script.args_from = len(self.inputs) script.args_from = len(self.inputs)
script.args_to = len(self.inputs) script.args_to = len(self.inputs)
try:
self.create_script_ui_inner(script)
except Exception:
errors.report(f"Error creating UI for {script.name}: ", exc_info=True)
def create_script_ui_inner(self, script):
import modules.api.models as api_models
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
if controls is None: if controls is None:
return return
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
api_args = [] api_args = []
for control in controls: for control in controls:
......
import dataclasses
import os import os
import gradio as gr import gradio as gr
from modules import errors, shared from modules import errors, shared
@dataclasses.dataclass
class PostprocessedImageSharedInfo:
target_width: int = None
target_height: int = None
class PostprocessedImage: class PostprocessedImage:
def __init__(self, image): def __init__(self, image):
self.image = image self.image = image
self.info = {} self.info = {}
self.shared = PostprocessedImageSharedInfo()
self.extra_images = []
self.nametags = []
self.disable_processing = False
self.caption = None
def get_suffix(self, used_suffixes=None):
used_suffixes = {} if used_suffixes is None else used_suffixes
suffix = "-".join(self.nametags)
if suffix:
suffix = "-" + suffix
if suffix not in used_suffixes:
used_suffixes[suffix] = 1
return suffix
for i in range(1, 100):
proposed_suffix = suffix + "-" + str(i)
if proposed_suffix not in used_suffixes:
used_suffixes[proposed_suffix] = 1
return proposed_suffix
return suffix
def create_copy(self, new_image, *, nametags=None, disable_processing=False):
pp = PostprocessedImage(new_image)
pp.shared = self.shared
pp.nametags = self.nametags.copy()
pp.info = self.info.copy()
pp.disable_processing = disable_processing
if nametags is not None:
pp.nametags += nametags
return pp
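A short sketch of the suffix bookkeeping (any object works as a stand-in image): repeated nametags get numeric suffixes, so files saved for sibling images stay unique.
```
# Assumes PostprocessedImage as defined above.
base = PostprocessedImage("image-0")  # stand-in image object
copy1 = base.create_copy("image-1", nametags=["upscaled"])
copy2 = base.create_copy("image-2", nametags=["upscaled"])

used = {}
print(copy1.get_suffix(used))  # -upscaled
print(copy2.get_suffix(used))  # -upscaled-1
```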
class ScriptPostprocessing: class ScriptPostprocessing:
...@@ -42,10 +85,17 @@ class ScriptPostprocessing: ...@@ -42,10 +85,17 @@ class ScriptPostprocessing:
pass pass
def image_changed(self): def process_firstpass(self, pp: PostprocessedImage, **args):
pass """
Called for all scripts before calling process(). Scripts can examine the image here and set fields
of the pp object to communicate things to other scripts.
args contains a dictionary with all values returned by components from ui()
"""
pass
def image_changed(self):
pass
def wrap_call(func, filename, funcname, *args, default=None, **kwargs): def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
...@@ -118,16 +168,42 @@ class ScriptPostprocessingRunner: ...@@ -118,16 +168,42 @@ class ScriptPostprocessingRunner:
return inputs return inputs
def run(self, pp: PostprocessedImage, args): def run(self, pp: PostprocessedImage, args):
for script in self.scripts_in_preferred_order(): scripts = []
shared.state.job = script.name
for script in self.scripts_in_preferred_order():
script_args = args[script.args_from:script.args_to] script_args = args[script.args_from:script.args_to]
process_args = {} process_args = {}
for (name, _component), value in zip(script.controls.items(), script_args): for (name, _component), value in zip(script.controls.items(), script_args):
process_args[name] = value process_args[name] = value
script.process(pp, **process_args) scripts.append((script, process_args))
for script, process_args in scripts:
script.process_firstpass(pp, **process_args)
all_images = [pp]
for script, process_args in scripts:
if shared.state.skipped:
break
shared.state.job = script.name
for single_image in all_images.copy():
if not single_image.disable_processing:
script.process(single_image, **process_args)
for extra_image in single_image.extra_images:
if not isinstance(extra_image, PostprocessedImage):
extra_image = single_image.create_copy(extra_image)
all_images.append(extra_image)
single_image.extra_images.clear()
pp.extra_images = all_images[1:]
def create_args_for_run(self, scripts_args): def create_args_for_run(self, scripts_args):
if not self.ui_created: if not self.ui_created:
......
...@@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper): ...@@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
would be on the meta device. would be on the meta device.
""" """
if state_dict == sd: if state_dict is sd:
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
original(module, state_dict, strict=strict) original(module, state_dict, strict=strict)
......
...@@ -38,8 +38,12 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print ...@@ -38,8 +38,12 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = [] optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None current_optimizer: sd_hijack_optimizations.SdOptimization = None
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
def list_optimizers(): def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback() new_optimizers = script_callbacks.list_optimizers_callback()
...@@ -303,8 +307,6 @@ class StableDiffusionModelHijack: ...@@ -303,8 +307,6 @@ class StableDiffusionModelHijack:
self.layers = None self.layers = None
self.clip = None self.clip = None
sd_unet.original_forward = None
def apply_circular(self, enable): def apply_circular(self, enable):
if self.circular_enabled == enable: if self.circular_enabled == enable:
......
...@@ -230,15 +230,19 @@ def select_checkpoint(): ...@@ -230,15 +230,19 @@ def select_checkpoint():
return checkpoint_info return checkpoint_info
checkpoint_dict_replacements = { checkpoint_dict_replacements_sd1 = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
} }
checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format.
'conditioner.embedders.0.': 'cond_stage_model.',
}
def transform_checkpoint_dict_key(k): def transform_checkpoint_dict_key(k, replacements):
for text, replacement in checkpoint_dict_replacements.items(): for text, replacement in replacements.items():
if k.startswith(text): if k.startswith(text):
k = replacement + k[len(text):] k = replacement + k[len(text):]
...@@ -249,9 +253,14 @@ def get_state_dict_from_checkpoint(pl_sd): ...@@ -249,9 +253,14 @@ def get_state_dict_from_checkpoint(pl_sd):
pl_sd = pl_sd.pop("state_dict", pl_sd) pl_sd = pl_sd.pop("state_dict", pl_sd)
pl_sd.pop("state_dict", None) pl_sd.pop("state_dict", None)
is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024
sd = {} sd = {}
for k, v in pl_sd.items(): for k, v in pl_sd.items():
new_key = transform_checkpoint_dict_key(k) if is_sd2_turbo:
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
else:
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)
if new_key is not None: if new_key is not None:
sd[new_key] = v sd[new_key] = v
......
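A hedged sketch of the key rewriting (assuming, as in the surrounding code, that `transform_checkpoint_dict_key` returns the rewritten key): SD 2.1 Turbo checkpoints store the text encoder under the SGM-style `conditioner.embedders.0.` prefix, which is mapped back to the LDM-style `cond_stage_model.` prefix.
```
# Assumes the function and replacement table defined above.
k = 'conditioner.embedders.0.model.ln_final.weight'
print(transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo))
# cond_stage_model.model.ln_final.weight
```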
...@@ -60,7 +60,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No ...@@ -60,7 +60,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
while restart_times > 0: while restart_times > 0:
restart_times -= 1 restart_times -= 1
step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])]) step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))
last_sigma = None last_sigma = None
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable): for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
......
...@@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc ...@@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps] alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas) sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
...@@ -43,7 +43,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta= ...@@ -43,7 +43,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps] alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas) sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
......
...@@ -5,8 +5,7 @@ from modules import script_callbacks, shared, devices ...@@ -5,8 +5,7 @@ from modules import script_callbacks, shared, devices
unet_options = [] unet_options = []
current_unet_option = None current_unet_option = None
current_unet = None current_unet = None
original_forward = None original_forward = None # not used, only left temporarily for compatibility
def list_unets(): def list_unets():
new_unets = script_callbacks.list_unets_callback() new_unets = script_callbacks.list_unets_callback()
...@@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module): ...@@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module):
pass pass
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): def create_unet_forward(original_forward):
if current_unet is not None: def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
return current_unet.forward(x, timesteps, context, *args, **kwargs) if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs)
return original_forward(self, x, timesteps, context, *args, **kwargs)
return original_forward(self, x, timesteps, context, *args, **kwargs) return UNetModel_forward
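The closure pattern above keeps a reference to the original forward so the patch can delegate when no custom unet is active; a self-contained illustration with hypothetical classes (no webui imports):
```
class Original:
    def forward(self, x):
        return x + 1

def create_forward(original_forward):
    def patched_forward(self, x):
        # a custom implementation could be consulted here first
        return original_forward(self, x)
    return patched_forward

# Patch the class; the closure still holds the pre-patch function.
Original.forward = create_forward(Original.forward)
print(Original().forward(1))  # 2 -- delegates to the saved original
```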
...@@ -66,6 +66,22 @@ def reload_hypernetworks(): ...@@ -66,6 +66,22 @@ def reload_hypernetworks():
shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
def get_infotext_names():
from modules import generation_parameters_copypaste, shared
res = {}
for info in shared.opts.data_labels.values():
if info.infotext:
res[info.infotext] = 1
for tab_data in generation_parameters_copypaste.paste_fields.values():
for _, name in tab_data.get("fields") or []:
if isinstance(name, str):
res[name] = 1
return list(res)
ui_reorder_categories_builtin_items = [ ui_reorder_categories_builtin_items = [
"prompt", "prompt",
"image", "image",
......
import json import json
import os import os
import sys import sys
import traceback
import platform import platform
import hashlib import hashlib
...@@ -84,7 +83,7 @@ def get_dict(): ...@@ -84,7 +83,7 @@ def get_dict():
"Checksum": checksum_token, "Checksum": checksum_token,
"Commandline": get_argv(), "Commandline": get_argv(),
"Torch env info": get_torch_sysinfo(), "Torch env info": get_torch_sysinfo(),
"Exceptions": get_exceptions(), "Exceptions": errors.get_exceptions(),
"CPU": { "CPU": {
"model": platform.processor(), "model": platform.processor(),
"count logical": psutil.cpu_count(logical=True), "count logical": psutil.cpu_count(logical=True),
...@@ -104,21 +103,6 @@ def get_dict(): ...@@ -104,21 +103,6 @@ def get_dict():
return res return res
def format_traceback(tb):
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
def format_exception(e, tb):
return {"exception": str(e), "traceback": format_traceback(tb)}
def get_exceptions():
try:
return list(reversed(errors.exception_records))
except Exception as e:
return str(e)
def get_environment(): def get_environment():
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist} return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
......
...@@ -3,7 +3,6 @@ import html ...@@ -3,7 +3,6 @@ import html
import gradio as gr import gradio as gr
import modules.textual_inversion.textual_inversion import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
from modules import sd_hijack, shared from modules import sd_hijack, shared
...@@ -15,12 +14,6 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old): ...@@ -15,12 +14,6 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old):
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
def preprocess(*args):
modules.textual_inversion.preprocess.preprocess(*args)
return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
def train_embedding(*args): def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
......
...@@ -65,7 +65,7 @@ def save_config_state(name): ...@@ -65,7 +65,7 @@ def save_config_state(name):
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json") filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
print(f"Saving backup of webui/extension state to {filename}.") print(f"Saving backup of webui/extension state to {filename}.")
with open(filename, "w", encoding="utf-8") as f: with open(filename, "w", encoding="utf-8") as f:
json.dump(current_config_state, f, indent=4) json.dump(current_config_state, f, indent=4, ensure_ascii=False)
config_states.list_config_states() config_states.list_config_states()
new_value = next(iter(config_states.all_config_states.keys()), "Current") new_value = next(iter(config_states.all_config_states.keys()), "Current")
new_choices = ["Current"] + list(config_states.all_config_states.keys()) new_choices = ["Current"] + list(config_states.all_config_states.keys())
...@@ -335,6 +335,11 @@ def normalize_git_url(url): ...@@ -335,6 +335,11 @@ def normalize_git_url(url):
return url return url
def get_extension_dirname_from_url(url):
*parts, last_part = url.split('/')
return normalize_git_url(last_part)
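A usage sketch, assuming `normalize_git_url` strips the trailing `.git` as it does elsewhere in this module:
```
print(get_extension_dirname_from_url("https://github.com/user/my-extension.git"))
# my-extension
```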
def install_extension_from_url(dirname, url, branch_name=None): def install_extension_from_url(dirname, url, branch_name=None):
check_access() check_access()
...@@ -346,10 +351,7 @@ def install_extension_from_url(dirname, url, branch_name=None): ...@@ -346,10 +351,7 @@ def install_extension_from_url(dirname, url, branch_name=None):
assert url, 'No URL specified' assert url, 'No URL specified'
if dirname is None or dirname == "": if dirname is None or dirname == "":
*parts, last_part = url.split('/') dirname = get_extension_dirname_from_url(url)
last_part = normalize_git_url(last_part)
dirname = last_part
target_dir = os.path.join(extensions.extensions_dir, dirname) target_dir = os.path.join(extensions.extensions_dir, dirname)
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
...@@ -449,7 +451,8 @@ def get_date(info: dict, key): ...@@ -449,7 +451,8 @@ def get_date(info: dict, key):
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
extlist = available_extensions["extensions"] extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} installed_extensions = {extension.name for extension in extensions.extensions}
installed_extension_urls = {normalize_git_url(extension.remote) for extension in extensions.extensions if extension.remote is not None}
tags = available_extensions.get("tags", {}) tags = available_extensions.get("tags", {})
tags_to_hide = set(hide_tags) tags_to_hide = set(hide_tags)
...@@ -482,7 +485,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" ...@@ -482,7 +485,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
if url is None: if url is None:
continue continue
existing = installed_extension_urls.get(normalize_git_url(url), None) existing = get_extension_dirname_from_url(url) in installed_extensions or normalize_git_url(url) in installed_extension_urls
extension_tags = extension_tags + ["installed"] if existing else extension_tags extension_tags = extension_tags + ["installed"] if existing else extension_tags
if any(x for x in extension_tags if x in tags_to_hide): if any(x for x in extension_tags if x in tags_to_hide):
......
...@@ -151,8 +151,13 @@ class ExtraNetworksPage: ...@@ -151,8 +151,13 @@ class ExtraNetworksPage:
continue continue
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
while subdir.startswith("/"):
subdir = subdir[1:] if shared.opts.extra_networks_dir_button_function:
if not subdir.startswith("/"):
subdir = "/" + subdir
else:
while subdir.startswith("/"):
subdir = subdir[1:]
is_empty = len(os.listdir(x)) == 0 is_empty = len(os.listdir(x)) == 0
if not is_empty and not subdir.endswith("/"): if not is_empty and not subdir.endswith("/"):
...@@ -279,6 +284,7 @@ class ExtraNetworksPage: ...@@ -279,6 +284,7 @@ class ExtraNetworksPage:
"date_created": int(stat.st_ctime or 0), "date_created": int(stat.st_ctime or 0),
"date_modified": int(stat.st_mtime or 0), "date_modified": int(stat.st_mtime or 0),
"name": pth.name.lower(), "name": pth.name.lower(),
"path": str(pth.parent).lower(),
} }
def find_preview(self, path): def find_preview(self, path):
...@@ -369,6 +375,9 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ...@@ -369,6 +375,9 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
for page in ui.stored_extra_pages: for page in ui.stored_extra_pages:
with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab: with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab:
with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]):
pass
elem_id = f"{tabname}_{page.id_page}_cards_html" elem_id = f"{tabname}_{page.id_page}_cards_html"
page_elem = gr.HTML('Loading...', elem_id=elem_id) page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem) ui.pages.append(page_elem)
...@@ -382,7 +391,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ...@@ -382,7 +391,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
related_tabs.append(tab) related_tabs.append(tab)
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
dropdown_sort = gr.Dropdown(choices=['Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
...@@ -399,7 +408,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ...@@ -399,7 +408,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
allow_prompt = "true" if page.allow_prompt else "false" allow_prompt = "true" if page.allow_prompt else "false"
allow_negative_prompt = "true" if page.allow_negative_prompt else "false" allow_negative_prompt = "true" if page.allow_negative_prompt else "false"
jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');' jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False) tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
......
@@ -17,6 +17,9 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):

    def create_item(self, name, index=None, enable_filter=True):
        checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
+       if checkpoint is None:
+           return

        path, ext = os.path.splitext(checkpoint.filename)
        return {
            "name": checkpoint.name_for_extra,

@@ -32,9 +35,12 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
        }

    def list_items(self):
+       # instantiate a list to protect against concurrent modification
        names = list(sd_models.checkpoints_list)
        for index, name in enumerate(names):
-           yield self.create_item(name, index)
+           item = self.create_item(name, index)
+           if item is not None:
+               yield item

    def allowed_directories_for_previews(self):
        return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
...
@@ -13,7 +13,10 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
        shared.reload_hypernetworks()

    def create_item(self, name, index=None, enable_filter=True):
-       full_path = shared.hypernetworks[name]
+       full_path = shared.hypernetworks.get(name)
+       if full_path is None:
+           return

        path, ext = os.path.splitext(full_path)
        sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
        shorthash = sha256[0:10] if sha256 else None

@@ -31,8 +34,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
        }

    def list_items(self):
-       for index, name in enumerate(shared.hypernetworks):
-           yield self.create_item(name, index)
+       # instantiate a list to protect against concurrent modification
+       names = list(shared.hypernetworks)
+       for index, name in enumerate(names):
+           item = self.create_item(name, index)
+           if item is not None:
+               yield item

    def allowed_directories_for_previews(self):
        return [shared.cmd_opts.hypernetwork_dir]
...
@@ -14,6 +14,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):

    def create_item(self, name, index=None, enable_filter=True):
        embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
+       if embedding is None:
+           return

        path, ext = os.path.splitext(embedding.filename)
        return {

@@ -29,8 +31,12 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
        }

    def list_items(self):
-       for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
-           yield self.create_item(name, index)
+       # instantiate a list to protect against concurrent modification
+       names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)
+       for index, name in enumerate(names):
+           item = self.create_item(name, index)
+           if item is not None:
+               yield item

    def allowed_directories_for_previews(self):
        return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
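The `list(...)` snapshot added in the three `list_items` hunks above guards against `RuntimeError: dictionary changed size during iteration`, which is raised when another thread (for example, a model refresh) mutates the registry dict mid-iteration. A minimal standalone sketch of the failure and the fix (the `items` dict is hypothetical, not webui code):

```python
items = {"a": 1, "b": 2}

# mutating a dict while iterating it raises RuntimeError on the next step
try:
    for name in items:
        items[name + "_copy"] = items[name]
except RuntimeError as e:
    print(e)  # dictionary changed size during iteration

# snapshotting the keys first detaches the iteration from later mutations
for name in list(items):
    items.setdefault(name + "_snapshot", items[name])
```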
@@ -134,7 +134,7 @@ class UserMetadataEditor:
        basename, ext = os.path.splitext(filename)

        with open(basename + '.json', "w", encoding="utf8") as file:
-           json.dump(metadata, file, indent=4)
+           json.dump(metadata, file, indent=4, ensure_ascii=False)

    def save_user_metadata(self, name, desc, notes):
        user_metadata = self.get_user_metadata(name)
...
@@ -141,7 +141,7 @@ class UiLoadsave:
    def write_to_file(self, current_ui_settings):
        with open(self.filename, "w", encoding="utf8") as file:
-           json.dump(current_ui_settings, file, indent=4)
+           json.dump(current_ui_settings, file, indent=4, ensure_ascii=False)

    def dump_defaults(self):
        """saves default values to a file unless the file is present and there was an error loading default values at start"""
...
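The `ensure_ascii=False` additions in the two hunks above keep non-Latin characters readable in the saved JSON instead of `\uXXXX` escapes; since both files are opened with `encoding="utf8"`, the raw characters are written safely. A quick illustration:

```python
import json

metadata = {"description": "猫の絵"}  # sample non-ASCII value
print(json.dumps(metadata))                      # {"description": "\u732b\u306e\u7d75"}
print(json.dumps(metadata, ensure_ascii=False))  # {"description": "猫の絵"}
```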
import gradio as gr

-from modules import scripts, shared, ui_common, postprocessing, call_queue
+from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow
import modules.generation_parameters_copypaste as parameters_copypaste


def create_ui():
+   dummy_component = gr.Label(visible=False)
    tab_index = gr.State(value=0)

    with gr.Row(equal_height=False, variant='compact'):

@@ -20,11 +21,13 @@ def create_ui():
                    extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
                    show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")

-           submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
            script_inputs = scripts.scripts_postproc.setup_ui()

        with gr.Column():
+           toprow = ui_toprow.Toprow(is_compact=True, is_img2img=False, id_part="extras")
+           toprow.create_inline_toprow_image()
+           submit = toprow.submit

            result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)

    tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])

@@ -32,8 +35,10 @@ def create_ui():
    tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])

    submit.click(
-       fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
+       fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']),
+       _js="submit_extras",
        inputs=[
+           dummy_component,
            tab_index,
            extras_image,
            image_batch,

@@ -45,8 +50,9 @@ def create_ui():
        outputs=[
            result_images,
            html_info_x,
-           html_info,
-       ]
+           html_log,
+       ],
+       show_progress=False,
    )

    parameters_copypaste.add_paste_fields("extras", extras_image, None)
...
@@ -68,10 +68,10 @@ class UiPromptStyles:
                self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")

            with gr.Row():
-               self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+               self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"])

            with gr.Row():
-               self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+               self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3, elem_classes=["prompt"])

            with gr.Row():
                self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
...
@@ -34,8 +34,10 @@ class Toprow:
    submit_box = None

-   def __init__(self, is_img2img, is_compact=False):
-       id_part = "img2img" if is_img2img else "txt2img"
+   def __init__(self, is_img2img, is_compact=False, id_part=None):
+       if id_part is None:
+           id_part = "img2img" if is_img2img else "txt2img"

        self.id_part = id_part
        self.is_img2img = is_img2img
        self.is_compact = is_compact
...
@@ -57,6 +57,9 @@ class Upscaler:
        dest_h = int((img.height * scale) // 8 * 8)

        for _ in range(3):
+           if img.width >= dest_w and img.height >= dest_h:
+               break

            shape = (img.width, img.height)

            img = self.do_upscale(img, selected_model)

@@ -64,9 +67,6 @@ class Upscaler:
            if shape == (img.width, img.height):
                break

-           if img.width >= dest_w and img.height >= dest_h:
-               break

        if img.width != dest_w or img.height != dest_h:
            img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
...
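Moving the size check to the top of the loop means an image that already meets the target dimensions is returned without a single `do_upscale` pass, instead of being upscaled once and then checked. A runnable sketch of the loop's new control flow (`FakeImage`, `double`, and `upscale_until` are stand-ins, not webui code):

```python
from dataclasses import dataclass

@dataclass
class FakeImage:
    width: int
    height: int

def double(img):  # stand-in for Upscaler.do_upscale
    return FakeImage(img.width * 2, img.height * 2)

def upscale_until(img, dest_w, dest_h, do_upscale, max_passes=3):
    for _ in range(max_passes):
        if img.width >= dest_w and img.height >= dest_h:
            break  # new order: a big-enough image skips the pass entirely
        shape = (img.width, img.height)
        img = do_upscale(img)
        if shape == (img.width, img.height):
            break  # upscaler made no progress; stop retrying
    return img

print(upscale_until(FakeImage(512, 512), 1024, 1024, double))    # one pass
print(upscale_until(FakeImage(2048, 2048), 1024, 1024, double))  # zero passes
```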
from modules import shared
from modules.sd_hijack_utils import CondFunc

has_ipex = False
try:
    import torch
    import intel_extension_for_pytorch as ipex # noqa: F401
    has_ipex = True
except Exception:
    pass


def check_for_xpu():
    return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()


def get_xpu_device_string():
    if shared.cmd_opts.device_id is not None:
        return f"xpu:{shared.cmd_opts.device_id}"
    return "xpu"


def torch_xpu_gc():
    with torch.xpu.device(get_xpu_device_string()):
        torch.xpu.empty_cache()


has_xpu = check_for_xpu()

if has_xpu:
    # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
    CondFunc('torch.Generator',
        lambda orig_func, device=None: torch.xpu.Generator(device),
        lambda orig_func, device=None: device is not None and device.type == "xpu")

    # W/A for some OPs that could not handle different input dtypes
    CondFunc('torch.nn.functional.layer_norm',
        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
        orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
        weight is not None and input.dtype != weight.data.dtype)
    CondFunc('torch.nn.modules.GroupNorm.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
    CondFunc('torch.nn.modules.linear.Linear.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
    CondFunc('torch.nn.modules.conv.Conv2d.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
    CondFunc('torch.bmm',
        lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
        lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
    CondFunc('torch.cat',
        lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
        lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
    CondFunc('torch.nn.functional.scaled_dot_product_attention',
        lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
        lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)
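`CondFunc` (from `modules.sd_hijack_utils`) patches the callable named by a dotted path so that a substitute runs only when a predicate on the arguments matches, and the original runs otherwise; that is how each workaround above casts mismatched dtypes only when needed. A simplified standalone sketch of that wrap-with-predicate pattern (plain function wrapping, no dotted-path resolution; all names here are illustrative, not the webui API):

```python
def cond_func(orig, sub, cond):
    """Call sub(orig, ...) when cond(orig, ...) is true, else orig(...)."""
    def wrapper(*args, **kwargs):
        if cond(orig, *args, **kwargs):
            return sub(orig, *args, **kwargs)
        return orig(*args, **kwargs)
    return wrapper

def add(a, b):
    return a + b

# align mismatched operand types before calling, mirroring the dtype
# workarounds above
patched = cond_func(
    add,
    lambda orig, a, b: orig(float(a), float(b)),
    lambda orig, a, b: type(a) is not type(b),
)
print(patched(1, 2))    # types match -> original runs: 3
print(patched(1, 2.5))  # mismatch -> substitute casts first: 3.5
```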
@@ -16,6 +16,7 @@ exclude = [

ignore = [
    "E501", # Line too long
+   "E721", # Do not compare types, use `isinstance`
    "E731", # Do not assign a `lambda` expression, use a `def`
    "I001", # Import block is un-sorted or un-formatted
...
@@ -133,9 +133,18 @@ document.addEventListener('keydown', function(e) {
    if (isEnter && isModifierKey) {
        if (interruptButton.style.display === 'block') {
            interruptButton.click();
-           setTimeout(function() {
-               generateButton.click();
-           }, 500);
+           const callback = (mutationList) => {
+               for (const mutation of mutationList) {
+                   if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
+                       if (interruptButton.style.display === 'none') {
+                           generateButton.click();
+                           observer.disconnect();
+                       }
+                   }
+               }
+           };
+           const observer = new MutationObserver(callback);
+           observer.observe(interruptButton, {attributes: true});
        } else {
            generateButton.click();
        }
...
from modules import scripts_postprocessing, ui_components, deepbooru, shared
import gradio as gr


class ScriptPostprocessingCaption(scripts_postprocessing.ScriptPostprocessing):
    name = "Caption"
    order = 4000

    def ui(self):
        with ui_components.InputAccordion(False, label="Caption") as enable:
            option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False)

        return {
            "enable": enable,
            "option": option,
        }

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
        if not enable:
            return

        captions = [pp.caption]

        if "Deepbooru" in option:
            captions.append(deepbooru.model.tag(pp.image))

        if "BLIP" in option:
            captions.append(shared.interrogator.generate_caption(pp.image))

        pp.caption = ", ".join([x for x in captions if x])
from PIL import Image
import numpy as np

-from modules import scripts_postprocessing, codeformer_model
+from modules import scripts_postprocessing, codeformer_model, ui_components
import gradio as gr

-from modules.ui_components import FormRow


class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
    name = "CodeFormer"
    order = 3000

    def ui(self):
-       with FormRow():
-           codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
-           codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
+       with ui_components.InputAccordion(False, label="CodeFormer") as enable:
+           with gr.Row():
+               codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_codeformer_visibility")
+               codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")

        return {
+           "enable": enable,
            "codeformer_visibility": codeformer_visibility,
            "codeformer_weight": codeformer_weight,
        }

-   def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
-       if codeformer_visibility == 0:
+   def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, codeformer_visibility, codeformer_weight):
+       if codeformer_visibility == 0 or not enable:
            return

        restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
...
from PIL import ImageOps, Image

from modules import scripts_postprocessing, ui_components
import gradio as gr


class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing):
    name = "Create flipped copies"
    order = 4000

    def ui(self):
        with ui_components.InputAccordion(False, label="Create flipped copies") as enable:
            with gr.Row():
                option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False)

        return {
            "enable": enable,
            "option": option,
        }

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
        if not enable:
            return

        if "Horizontal" in option:
            pp.extra_images.append(ImageOps.mirror(pp.image))

        if "Vertical" in option:
            pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM))

        if "Both" in option:
            pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT))
from modules import scripts_postprocessing, ui_components, errors
import gradio as gr

from modules.textual_inversion import autocrop


class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing):
    name = "Auto focal point crop"
    order = 4000

    def ui(self):
        with ui_components.InputAccordion(False, label="Auto focal point crop") as enable:
            face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight")
            entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight")
            edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight")
            debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")

        return {
            "enable": enable,
            "face_weight": face_weight,
            "entropy_weight": entropy_weight,
            "edges_weight": edges_weight,
            "debug": debug,
        }

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug):
        if not enable:
            return

        if not pp.shared.target_width or not pp.shared.target_height:
            return

        dnn_model_path = None
        try:
            dnn_model_path = autocrop.download_and_cache_models()
        except Exception:
            errors.report("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", exc_info=True)

        autocrop_settings = autocrop.Settings(
            crop_width=pp.shared.target_width,
            crop_height=pp.shared.target_height,
            face_points_weight=face_weight,
            entropy_points_weight=entropy_weight,
            corner_points_weight=edges_weight,
            annotate_image=debug,
            dnn_model_path=dnn_model_path,
        )

        result, *others = autocrop.crop_image(pp.image, autocrop_settings)

        pp.image = result
        pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others]
from PIL import Image
import numpy as np

-from modules import scripts_postprocessing, gfpgan_model
+from modules import scripts_postprocessing, gfpgan_model, ui_components
import gradio as gr

-from modules.ui_components import FormRow


class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
    name = "GFPGAN"
    order = 2000

    def ui(self):
-       with FormRow():
-           gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
+       with ui_components.InputAccordion(False, label="GFPGAN") as enable:
+           gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_gfpgan_visibility")

        return {
+           "enable": enable,
            "gfpgan_visibility": gfpgan_visibility,
        }

-   def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
-       if gfpgan_visibility == 0:
+   def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, gfpgan_visibility):
+       if gfpgan_visibility == 0 or not enable:
            return

        restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
...
import math

from modules import scripts_postprocessing, ui_components
import gradio as gr


def split_pic(image, inverse_xy, width, height, overlap_ratio):
    if inverse_xy:
        from_w, from_h = image.height, image.width
        to_w, to_h = height, width
    else:
        from_w, from_h = image.width, image.height
        to_w, to_h = width, height
    h = from_h * to_w // from_w
    if inverse_xy:
        image = image.resize((h, to_w))
    else:
        image = image.resize((to_w, h))

    split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
    y_step = (h - to_h) / (split_count - 1)
    for i in range(split_count):
        y = int(y_step * i)
        if inverse_xy:
            splitted = image.crop((y, 0, y + to_h, to_w))
        else:
            splitted = image.crop((0, y, to_w, y + to_h))
        yield splitted


class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing):
    name = "Split oversized images"
    order = 4000

    def ui(self):
        with ui_components.InputAccordion(False, label="Split oversized images") as enable:
            with gr.Row():
                split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold")
                overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio")

        return {
            "enable": enable,
            "split_threshold": split_threshold,
            "overlap_ratio": overlap_ratio,
        }

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio):
        if not enable:
            return

        width = pp.shared.target_width
        height = pp.shared.target_height

        if not width or not height:
            return

        if pp.image.height > pp.image.width:
            ratio = (pp.image.width * height) / (pp.image.height * width)
            inverse_xy = False
        else:
            ratio = (pp.image.height * width) / (pp.image.width * height)
            inverse_xy = True

        if ratio >= 1.0 and ratio > split_threshold:
            return

        result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio)

        pp.image = result
        pp.extra_images = [pp.create_copy(x) for x in others]
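The tile count in `split_pic` comes from requiring each `to_h`-tall tile to overlap its neighbor by at least `overlap_ratio`, and `y_step` then spreads the tiles evenly from the top edge to the bottom. A worked example with arbitrary numbers (a 512x1536 portrait split into 512x512 tiles at 20% overlap):

```python
import math

from_h, to_h = 1536, 512   # resized image height and tile height
overlap_ratio = 0.2

split_count = math.ceil((from_h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
y_step = (from_h - to_h) / (split_count - 1)
print(split_count, y_step)  # 4 tiles, cropped at y = 0, 341, 682, 1024
```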
@@ -81,6 +81,14 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):

        return image

+   def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+       if upscale_mode == 1:
+           pp.shared.target_width = upscale_to_width
+           pp.shared.target_height = upscale_to_height
+       else:
+           pp.shared.target_width = int(pp.image.width * upscale_by)
+           pp.shared.target_height = int(pp.image.height * upscale_by)

    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
        if upscaler_1_name == "None":
            upscaler_1_name = None

@@ -126,6 +134,10 @@ class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
            "upscaler_name": upscaler_name,
        }

+   def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
+       pp.shared.target_width = int(pp.image.width * upscale_by)
+       pp.shared.target_height = int(pp.image.height * upscale_by)

    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
        if upscaler_name is None or upscaler_name == "None":
            return
...
from PIL import Image

from modules import scripts_postprocessing, ui_components
import gradio as gr


def center_crop(image: Image, w: int, h: int):
    iw, ih = image.size
    if ih / h < iw / w:
        sw = w * ih / h
        box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
    else:
        sh = h * iw / w
        box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
    return image.resize((w, h), Image.Resampling.LANCZOS, box)


def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
    iw, ih = image.size
    err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h))
    wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64)
              if minarea <= w * h <= maxarea and err(w, h) <= threshold),
             key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1],
             default=None
             )
    return wh and center_crop(image, *wh)


class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing):
    name = "Auto-sized crop"
    order = 4000

    def ui(self):
        with ui_components.InputAccordion(False, label="Auto-sized crop") as enable:
            gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
            with gr.Row():
                mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="postprocess_multicrop_mindim")
                maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="postprocess_multicrop_maxdim")
            with gr.Row():
                minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area lower bound", value=64 * 64, elem_id="postprocess_multicrop_minarea")
                maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area upper bound", value=640 * 640, elem_id="postprocess_multicrop_maxarea")
            with gr.Row():
                objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="postprocess_multicrop_objective")
                threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="postprocess_multicrop_threshold")

        return {
            "enable": enable,
            "mindim": mindim,
            "maxdim": maxdim,
            "minarea": minarea,
            "maxarea": maxarea,
            "objective": objective,
            "threshold": threshold,
        }

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold):
        if not enable:
            return

        cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold)
        if cropped is not None:
            pp.image = cropped
        else:
            print(f"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)")
@@ -462,6 +462,15 @@ div.toprow-compact-tools{
    padding: 4px;
}

+#settings > div.tab-nav .settings-category{
+    display: block;
+    margin: 1em 0 0.25em 0;
+    font-weight: bold;
+    text-decoration: underline;
+    cursor: default;
+    user-select: none;
+}

#settings_result{
    height: 1.4em;
    margin: 0 1.2em;

@@ -637,6 +646,8 @@ table.popup-table .link{
    margin: auto;
    padding: 2em;
    z-index: 1001;
+   max-height: 90%;
+   max-width: 90%;
}

/* fullpage image viewer */

@@ -840,8 +851,16 @@ footer {

/* extra networks UI */

-.extra-page .prompt{
-    margin: 0 0 0.5em 0;
+.extra-page > div.gap{
+    gap: 0;
+}
+
+.extra-page-prompts{
+    margin-bottom: 0;
+}
+
+.extra-page-prompts.extra-page-prompts-active{
+    margin-bottom: 1em;
}

.extra-network-cards{
...
@@ -89,7 +89,7 @@ delimiter="################################################################"

printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n"
-printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
+printf "\e[1m\e[34mTested on Debian 11 (Bullseye), Fedora 34+ and openSUSE Leap 15.4 or newer.\e[0m"
printf "\n%s\n" "${delimiter}"

# Do not run as root

@@ -223,7 +223,7 @@ fi

# Try using TCMalloc on Linux
prepare_tcmalloc() {
    if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
-       TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
+       TCMALLOC="$(PATH=/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
        if [[ ! -z "${TCMALLOC}" ]]; then
            echo "Using TCMalloc: ${TCMALLOC}"
            export LD_PRELOAD="${TCMALLOC}"
...