Commit 38864816 authored by CodeHatchling

Merge remote-tracking branch 'origin2/dev' into soft-inpainting

# Conflicts:
#	modules/processing.py
parents 49bbf114 22e23dbf
@@ -64,11 +64,14 @@ class ExtraOptionsSection(scripts.Script):
             p.override_settings[name] = value
 
-shared.options_templates.update(shared.options_section(('ui', "User interface"), {
-    "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
-    "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
-    "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
-    "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
+shared.options_templates.update(shared.options_section(('settings_in_ui', "Settings in UI", "ui"), {
+    "settings_in_ui": shared.OptionHTML("""
+This page allows you to add some settings to the main interface of txt2img and img2img tabs.
+"""),
+    "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
+    "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
+    "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(),
+    "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
 }))
@@ -6,7 +6,6 @@ Original author: @tfernd Github: https://github.com/tfernd/HyperTile
 from __future__ import annotations
 
-import functools
 from dataclasses import dataclass
 from typing import Callable
 
@@ -189,20 +188,27 @@ DEPTH_LAYERS_XL = {
 RNG_INSTANCE = random.Random()
 
-def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
+@cache
+def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]:
     """
-    Returns a random divisor of value that
+    Returns divisors of value that
         x * min_value <= value
-    if max_options is 1, the behavior is deterministic
+    in big -> small order, amount of divisors is limited by max_options
     """
+    max_options = max(1, max_options)  # at least 1 option should be returned
     min_value = min(min_value, value)
-
-    # All big divisors of value (inclusive)
     divisors = [i for i in range(min_value, value + 1) if value % i == 0]  # divisors in small -> big order
-
     ns = [value // i for i in divisors[:max_options]]  # has at least 1 element  # big -> small order
+    return ns
+
+
+def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
+    """
+    Returns a random divisor of value that
+        x * min_value <= value
+    if max_options is 1, the behavior is deterministic
+    """
+    ns = get_divisors(value, min_value, max_options=max_options)  # get cached divisors
 
     idx = RNG_INSTANCE.randint(0, len(ns) - 1)
 
     return ns[idx]
 
@@ -212,7 +218,7 @@ def set_hypertile_seed(seed: int) -> None:
     RNG_INSTANCE.seed(seed)
 
-@functools.cache
+@cache
 def largest_tile_size_available(width: int, height: int) -> int:
     """
     Calculates the largest tile size available for a given width and height
...
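A note on the refactor above: splitting the divisor search out of random_divisor lets functools.cache memoize the list construction per (value, min_value, max_options) argument tuple, while the random pick itself stays uncached so the seeded RNG still matters. A minimal self-contained sketch of that pattern (assuming the module imports cache from functools, as the bare @cache decorator implies):

    from functools import cache
    import random

    RNG_INSTANCE = random.Random()

    @cache
    def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]:
        max_options = max(1, max_options)
        min_value = min(min_value, value)
        divisors = [i for i in range(min_value, value + 1) if value % i == 0]  # small -> big
        return [value // i for i in divisors[:max_options]]  # big -> small

    def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
        ns = get_divisors(value, min_value, max_options=max_options)  # cached across calls
        return ns[RNG_INSTANCE.randint(0, len(ns) - 1)]

    # get_divisors(768, 128, max_options=3) -> [6, 4, 3]; repeat calls hit the cache.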
 import hypertile
 from modules import scripts, script_callbacks, shared
+from scripts.hypertile_xyz import add_axis_options
 
 
 class ScriptHypertile(scripts.Script):
 
@@ -16,8 +17,42 @@ class ScriptHypertile(scripts.Script):
         configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
 
+        self.add_infotext(p)
+
     def before_hr(self, p, *args):
-        configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet)
+
+        enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet
+
+        # exclusive hypertile seed for the second pass
+        if enable:
+            hypertile.set_hypertile_seed(p.all_seeds[0])
+
+        configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable)
+
+        if enable and not shared.opts.hypertile_enable_unet:
+            p.extra_generation_params["Hypertile U-Net second pass"] = True
+
+            self.add_infotext(p, add_unet_params=True)
+
+    def add_infotext(self, p, add_unet_params=False):
+        def option(name):
+            value = getattr(shared.opts, name)
+            default_value = shared.opts.get_default(name)
+            return None if value == default_value else value
+
+        if shared.opts.hypertile_enable_unet:
+            p.extra_generation_params["Hypertile U-Net"] = True
+
+        if shared.opts.hypertile_enable_unet or add_unet_params:
+            p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet')
+            p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet')
+            p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet')
+
+        if shared.opts.hypertile_enable_vae:
+            p.extra_generation_params["Hypertile VAE"] = True
+            p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae')
+            p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae')
+            p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae')
 
 def configure_hypertile(width, height, enable_unet=True):
 
@@ -53,16 +88,16 @@ def on_ui_settings():
     benefit.
 """),
 
-    "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
-    "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
-    "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
-    "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
-    "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
-    "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
-    "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
-    "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
-    "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
+    "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"),
+    "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"),
+    "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"),
+    "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"),
+    "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"),
+    "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"),
+    "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"),
+    "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"),
+    "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"),
 }
 
 for name, opt in options.items():
 
@@ -71,3 +106,4 @@ def on_ui_settings():
 
 script_callbacks.on_ui_settings(on_ui_settings)
+script_callbacks.on_before_ui(add_axis_options)
+from modules import scripts
+from modules.shared import opts
+
+xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module
+
+def int_applier(value_name:str, min_range:int = -1, max_range:int = -1):
+    """
+    Returns a function that applies the given value to the given value_name in opts.data.
+    """
+    def validate(value_name:str, value:str):
+        value = int(value)
+        # validate value
+        if not min_range == -1:
+            assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
+        if not max_range == -1:
+            assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
+    def apply_int(p, x, xs):
+        validate(value_name, x)
+        opts.data[value_name] = int(x)
+    return apply_int
+
+def bool_applier(value_name:str):
+    """
+    Returns a function that applies the given value to the given value_name in opts.data.
+    """
+    def validate(value_name:str, value:str):
+        assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false"
+    def apply_bool(p, x, xs):
+        validate(value_name, x)
+        value_boolean = x.lower() == "true"
+        opts.data[value_name] = value_boolean
+    return apply_bool
+
+def add_axis_options():
+    extra_axis_options = [
+        xyz_grid.AxisOption("[Hypertile] Unet First pass Enabled", str, bool_applier("hypertile_enable_unet"), choices=xyz_grid.boolean_choice(reverse=True)),
+        xyz_grid.AxisOption("[Hypertile] Unet Second pass Enabled", str, bool_applier("hypertile_enable_unet_secondpass"), choices=xyz_grid.boolean_choice(reverse=True)),
+        xyz_grid.AxisOption("[Hypertile] Unet Max Depth", int, int_applier("hypertile_max_depth_unet", 0, 3), choices=lambda: [str(x) for x in range(4)]),
+        xyz_grid.AxisOption("[Hypertile] Unet Max Tile Size", int, int_applier("hypertile_max_tile_unet", 0, 512)),
+        xyz_grid.AxisOption("[Hypertile] Unet Swap Size", int, int_applier("hypertile_swap_size_unet", 0, 64)),
+        xyz_grid.AxisOption("[Hypertile] VAE Enabled", str, bool_applier("hypertile_enable_vae"), choices=xyz_grid.boolean_choice(reverse=True)),
+        xyz_grid.AxisOption("[Hypertile] VAE Max Depth", int, int_applier("hypertile_max_depth_vae", 0, 3), choices=lambda: [str(x) for x in range(4)]),
+        xyz_grid.AxisOption("[Hypertile] VAE Max Tile Size", int, int_applier("hypertile_max_tile_vae", 0, 512)),
+        xyz_grid.AxisOption("[Hypertile] VAE Swap Size", int, int_applier("hypertile_swap_size_vae", 0, 64)),
+    ]
+    set_a = {opt.label for opt in xyz_grid.axis_options}
+    set_b = {opt.label for opt in extra_axis_options}
+    if set_a.intersection(set_b):
+        return
+
+    xyz_grid.axis_options.extend(extra_axis_options)
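For context, xyz_grid invokes each axis applier as apply(p, x, xs), with x being the cell value; the factories above close over the option name. A hedged illustration of the resulting behavior (p and xs are unused by these particular appliers, so None stands in for both):

    apply_depth = int_applier("hypertile_max_depth_unet", 0, 3)
    apply_depth(None, "2", None)      # asserts 0 <= 2 <= 3, then opts.data["hypertile_max_depth_unet"] = 2

    apply_enable = bool_applier("hypertile_enable_unet")
    apply_enable(None, "True", None)  # case-insensitive; opts.data["hypertile_enable_unet"] = True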
@@ -392,3 +392,9 @@ function extraNetworksRefreshSingleCard(page, tabname, name) {
         }
     });
 }
+
+window.addEventListener("keydown", function(event) {
+    if (event.key == "Escape") {
+        closePopup();
+    }
+});
@@ -170,6 +170,23 @@ function submit_img2img() {
     return res;
 }
 
+function submit_extras() {
+    showSubmitButtons('extras', false);
+
+    var id = randomId();
+
+    requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() {
+        showSubmitButtons('extras', true);
+    });
+
+    var res = create_submit_args(arguments);
+
+    res[0] = id;
+
+    console.log(res);
+
+    return res;
+}
+
 function restoreProgressTxt2img() {
     showRestoreProgressButton("txt2img", false);
     var id = localGet("txt2img_task_id");
...
@@ -22,7 +22,6 @@ from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
-from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin, Image
 from modules.sd_models_config import find_checkpoint_config_near_filename
 
@@ -235,7 +234,6 @@ class Api:
         self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
         self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
-        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
         self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
         self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
         self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
 
@@ -675,19 +673,6 @@ class Api:
         finally:
             shared.state.end()
 
-    def preprocess(self, args: dict):
-        try:
-            shared.state.begin(job="preprocess")
-            preprocess(**args)  # quick operation unless blip/booru interrogation is enabled
-            shared.state.end()
-            return models.PreprocessResponse(info='preprocess complete')
-        except KeyError as e:
-            return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
-        except Exception as e:
-            return models.PreprocessResponse(info=f"preprocess error: {e}")
-        finally:
-            shared.state.end()
-
     def train_embedding(self, args: dict):
         try:
             shared.state.begin(job="train_embedding")
...
@@ -202,9 +202,6 @@ class TrainResponse(BaseModel):
 class CreateResponse(BaseModel):
     info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
 
-class PreprocessResponse(BaseModel):
-    info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
-
 fields = {}
 for key, metadata in opts.data_labels.items():
     value = opts.data.get(key)
...
@@ -70,6 +70,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
 parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
 parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
 parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
 parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
...
@@ -8,6 +8,13 @@ from modules import errors, shared
 if sys.platform == "darwin":
     from modules import mac_specific
 
+if shared.cmd_opts.use_ipex:
+    from modules import xpu_specific
+
+
+def has_xpu() -> bool:
+    return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
+
 
 def has_mps() -> bool:
     if sys.platform != "darwin":
 
@@ -30,6 +37,9 @@ def get_optimal_device_name():
     if has_mps():
         return "mps"
 
+    if has_xpu():
+        return xpu_specific.get_xpu_device_string()
+
     return "cpu"
 
@@ -38,7 +48,7 @@ def get_optimal_device():
 def get_device_for(task):
-    if task in shared.cmd_opts.use_cpu:
+    if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
         return cpu
 
     return get_optimal_device()
 
@@ -54,6 +64,9 @@ def torch_gc():
     if has_mps():
         mac_specific.torch_mps_gc()
 
+    if has_xpu():
+        xpu_specific.torch_xpu_gc()
+
 
 def enable_tf32():
     if torch.cuda.is_available():
...
+from __future__ import annotations
 import base64
 import io
 import json
 
@@ -15,9 +16,6 @@ re_imagesize = re.compile(r"^(\d+)x(\d+)$")
 re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
 type_of_gr_update = type(gr.update())
 
-paste_fields = {}
-registered_param_bindings = []
-
 
 class ParamBinding:
     def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
 
@@ -30,6 +28,10 @@ class ParamBinding:
         self.paste_field_names = paste_field_names or []
 
 
+paste_fields: dict[str, dict] = {}
+registered_param_bindings: list[ParamBinding] = []
+
+
 def reset():
     paste_fields.clear()
     registered_param_bindings.clear()
 
@@ -113,7 +115,6 @@ def register_paste_params_button(binding: ParamBinding):
 def connect_paste_params_buttons():
-    binding: ParamBinding
     for binding in registered_param_bindings:
         destination_image_component = paste_fields[binding.tabname]["init_img"]
         fields = paste_fields[binding.tabname]["fields"]
 
@@ -313,6 +314,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "VAE Decoder" not in res:
         res["VAE Decoder"] = "Full"
 
+    skip = set(shared.opts.infotext_skip_pasting)
+    res = {k: v for k, v in res.items() if k not in skip}
+
     return res
 
@@ -443,3 +447,4 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
         outputs=[],
         show_progress=False,
     )
...
@@ -47,10 +47,20 @@ def Block_get_config(self):
 
 def BlockContext_init(self, *args, **kwargs):
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.before_component(self, **kwargs)
+
+    scripts.script_callbacks.before_component_callback(self, **kwargs)
+
     res = original_BlockContext_init(self, *args, **kwargs)
 
     add_classes_to_gradio_component(self)
 
+    scripts.script_callbacks.after_component_callback(self, **kwargs)
+
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.after_component(self, **kwargs)
+
     return res
...
@@ -3,3 +3,14 @@ import sys
 # this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
 if "--xformers" not in "".join(sys.argv):
     sys.modules["xformers"] = None
+
+# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
+# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
+try:
+    import torchvision.transforms.functional_tensor  # noqa: F401
+except ImportError:
+    try:
+        import torchvision.transforms.functional as functional
+        sys.modules["torchvision.transforms.functional_tensor"] = functional
+    except ImportError:
+        pass  # shrug...
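This hack works because Python consults sys.modules before running any of the import machinery, so registering an object under the old name makes later imports of that name resolve to it directly. A tiny self-contained demonstration of the mechanism, with hypothetical module names:

    import sys
    import types

    real = types.ModuleType("real_module")
    real.value = 42

    sys.modules["legacy_alias"] = real  # register under the removed/renamed name

    import legacy_alias  # the import system returns the registered object as-is
    assert legacy_alias.value == 42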
@@ -6,6 +6,7 @@ import os
 import shutil
 import sys
 import importlib.util
+import importlib.metadata
 import platform
 import json
 from functools import lru_cache
 
@@ -119,11 +120,16 @@ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_
 def is_installed(package):
     try:
-        spec = importlib.util.find_spec(package)
-    except ModuleNotFoundError:
-        return False
+        dist = importlib.metadata.distribution(package)
+    except importlib.metadata.PackageNotFoundError:
+        try:
+            spec = importlib.util.find_spec(package)
+        except ModuleNotFoundError:
+            return False
 
-    return spec is not None
+        return spec is not None
+
+    return dist is not None
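The two lookups answer different questions: importlib.metadata matches pip distribution names, while importlib.util.find_spec matches importable module names, and the two can differ. A hedged illustration, assuming Pillow is installed in the environment:

    import importlib.metadata
    import importlib.util

    importlib.metadata.distribution("Pillow")  # found: "Pillow" is the distribution name
    importlib.util.find_spec("PIL")            # found: "PIL" is the importable module name

    # importlib.metadata.distribution("PIL") raises PackageNotFoundError, and
    # importlib.util.find_spec("Pillow") returns None - hence the two-step fallback above.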
 def repo_dir(name):
@@ -310,6 +316,26 @@ def requirements_met(requirements_file):
 
 def prepare_environment():
     torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
     torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+    if args.use_ipex:
+        if platform.system() == "Windows":
+            # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
+            # This is NOT an Intel official release so please use it at your own risk!!
+            # See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
+            #
+            # Strengths (over official IPEX 2.0.110 windows release):
+            #   - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399
+            #   - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.
+            #   - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465
+            # Limitation:
+            #   - Only works for python 3.10
+            url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle"
+            torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl")
+        else:
+            # Using official IPEX release for linux since it's already an AOT build.
+            # However, users still have to install oneAPI toolkit and activate oneAPI environment manually.
+            # See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
+            torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
+            torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
     requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
     xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
 
@@ -352,6 +378,8 @@ def prepare_environment():
         run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
         startup_timer.record("install torch")
 
+    if args.use_ipex:
+        args.skip_torch_cuda_test = True
     if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
         raise RuntimeError(
             'Torch is not able to use GPU; '
...
 import logging
 
 import torch
+from torch import Tensor
 import platform
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
 
@@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
         return cumsum_func(input, *args, **kwargs)
 
 
+# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
+    try:
+        return orig_func(*args, **kwargs)
+    except RuntimeError as e:
+        if "not implemented for" in str(e) and "Half" in str(e):
+            input_tensor = args[0]
+            return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
+        else:
+            print(f"An unexpected RuntimeError occurred: {str(e)}")
+
 if has_mps:
     if platform.mac_ver()[0].startswith("13.2."):
         # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
 
@@ -77,6 +89,9 @@ if has_mps:
         # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
         CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
 
+        # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+        CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
+
         # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
         if platform.processor() == 'i386':
             for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
...
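The same fallback pattern can be exercised off-MPS, since many torch builds also lack fp16 kernels for some interpolation modes on CPU. A standalone sketch of the idea (the error-string substring check is an assumption; wording varies between torch versions):

    import torch
    import torch.nn.functional as F

    def interpolate_fp32_fallback(x: torch.Tensor, **kwargs) -> torch.Tensor:
        try:
            return F.interpolate(x, **kwargs)
        except RuntimeError as e:
            if "not implemented for" in str(e) and "Half" in str(e):
                # retry in fp32, then cast the result back to the input dtype
                return F.interpolate(x.to(torch.float32), **kwargs).to(x.dtype)
            raise

    x = torch.randn(1, 3, 8, 8, dtype=torch.float16)
    y = interpolate_fp32_fallback(x, scale_factor=2, mode="bicubic")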
@@ -24,10 +24,15 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
 from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
 from ldm.modules.ema import LitEma
 from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
 from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
 from ldm.models.diffusion.ddim import DDIMSampler
 
+try:
+    from ldm.models.autoencoder import VQModelInterface
+except Exception:
+    class VQModelInterface:
+        pass
+
 __conditioning_keys__ = {'concat': 'c_concat',
                          'crossattn': 'c_crossattn',
...
@@ -29,11 +29,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
             image_list = shared.listfiles(input_dir)
             for filename in image_list:
-                try:
-                    image = Image.open(filename)
-                except Exception:
-                    continue
-                yield image, filename
+                yield filename, filename
         else:
             assert image, 'image not selected'
             yield image, None
 
@@ -45,35 +41,85 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
     infotext = ''
 
-    for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
+    data_to_process = list(get_images(extras_mode, image, image_folder, input_dir))
+    shared.state.job_count = len(data_to_process)
+
+    for image_placeholder, name in data_to_process:
         image_data: Image.Image
 
+        shared.state.nextjob()
         shared.state.textinfo = name
+        shared.state.skipped = False
+
+        if shared.state.interrupted:
+            break
+
+        if isinstance(image_placeholder, str):
+            try:
+                image_data = Image.open(image_placeholder)
+            except Exception:
+                continue
+        else:
+            image_data = image_placeholder
+
+        shared.state.assign_current_image(image_data)
 
         parameters, existing_pnginfo = images.read_info_from_image(image_data)
         if parameters:
             existing_pnginfo["parameters"] = parameters
 
-        pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
+        initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
 
-        scripts.scripts_postproc.run(pp, args)
+        scripts.scripts_postproc.run(initial_pp, args)
 
-        if opts.use_original_name_batch and name is not None:
-            basename = os.path.splitext(os.path.basename(name))[0]
-        else:
-            basename = ''
+        if shared.state.skipped:
+            continue
 
-        infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+        used_suffixes = {}
+        for pp in [initial_pp, *initial_pp.extra_images]:
+            suffix = pp.get_suffix(used_suffixes)
 
-        if opts.enable_pnginfo:
-            pp.image.info = existing_pnginfo
-            pp.image.info["postprocessing"] = infotext
+            if opts.use_original_name_batch and name is not None:
+                basename = os.path.splitext(os.path.basename(name))[0]
+                forced_filename = basename + suffix
+            else:
+                basename = ''
+                forced_filename = None
 
-        if save_output:
-            images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
+            infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
 
-        if extras_mode != 2 or show_extras_results:
-            outputs.append(pp.image)
+            if opts.enable_pnginfo:
+                pp.image.info = existing_pnginfo
+                pp.image.info["postprocessing"] = infotext
+
+            if save_output:
+                fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
+
+                if pp.caption:
+                    caption_filename = os.path.splitext(fullfn)[0] + ".txt"
+                    if os.path.isfile(caption_filename):
+                        with open(caption_filename, encoding="utf8") as file:
+                            existing_caption = file.read().strip()
+                    else:
+                        existing_caption = ""
+
+                    action = shared.opts.postprocessing_existing_caption_action
+                    if action == 'Prepend' and existing_caption:
+                        caption = f"{existing_caption} {pp.caption}"
+                    elif action == 'Append' and existing_caption:
+                        caption = f"{pp.caption} {existing_caption}"
+                    elif action == 'Keep' and existing_caption:
+                        caption = existing_caption
+                    else:
+                        caption = pp.caption
+
+                    caption = caption.strip()
+                    if caption:
+                        with open(caption_filename, "w", encoding="utf8") as file:
+                            file.write(caption)
+
+            if extras_mode != 2 or show_extras_results:
+                outputs.append(pp.image)
 
         image_data.close()
 
@@ -82,6 +128,10 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
     return outputs, ui_common.plaintext_to_html(infotext), ''
 
+
+def run_postprocessing_webui(id_task, *args, **kwargs):
+    return run_postprocessing(*args, **kwargs)
+
+
 def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
     """old handler for API"""
 
@@ -97,9 +147,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
         "upscaler_2_visibility": extras_upscaler_2_visibility,
     },
     "GFPGAN": {
+        "enable": True,
         "gfpgan_visibility": gfpgan_visibility,
     },
     "CodeFormer": {
+        "enable": True,
        "codeformer_visibility": codeformer_visibility,
         "codeformer_weight": codeformer_weight,
     },
...
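The caption-merging branch above reduces to a small pure function; a hedged restatement (merge_caption is a hypothetical helper, not part of the commit) makes the Prepend/Append/Keep semantics easier to see:

    def merge_caption(existing: str, new: str, action: str) -> str:
        if action == 'Prepend' and existing:
            return f"{existing} {new}".strip()
        if action == 'Append' and existing:
            return f"{new} {existing}".strip()
        if action == 'Keep' and existing:
            return existing
        return new.strip()

    assert merge_caption("old caption", "new caption", "Prepend") == "old caption new caption"
    assert merge_caption("", "new caption", "Keep") == "new caption"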
@@ -692,8 +692,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Size": f"{p.width}x{p.height}",
         "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
         "Model": p.sd_model_name if opts.add_model_name_to_info else None,
-        "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
-        "VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
+        "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
+        "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
         "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
 
@@ -980,27 +980,26 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 if opts.enable_pnginfo:
                     image.info["parameters"] = text
                 output_images.append(image)
-                if save_samples and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
-                    if hasattr(p, 'masks_for_overlay') and p.masks_for_overlay:
-                        image_mask = p.masks_for_overlay[i].convert('RGB')
-                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA')
-                    elif hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
-                        image_mask = p.mask_for_overlay.convert('RGB')
-                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
-                    else:
-                        image_mask = None
-                        image_mask_composite = None
-
-                    if image_mask is not None and image_mask_composite is not None:
-                        if opts.save_mask:
-                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
-                        if opts.save_mask_composite:
-                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
-                        if opts.return_mask:
-                            output_images.append(image_mask)
-                        if opts.return_mask_composite:
-                            output_images.append(image_mask_composite)
+
+                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+                    mask_for_overlay = p.mask_for_overlay
+                elif hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and p.masks_for_overlay[i]:
+                    mask_for_overlay = p.masks_for_overlay[i]
+                else:
+                    mask_for_overlay = None
+
+                if mask_for_overlay is not None:
+                    if opts.return_mask or opts.save_mask:
+                        image_mask = mask_for_overlay.convert('RGB')
+                        if save_samples and opts.save_mask:
+                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
+
+                        if opts.return_mask:
+                            output_images.append(image_mask)
+
+                    if opts.return_mask_composite or opts.save_mask_composite:
+                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+                        if save_samples and opts.save_mask_composite:
+                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
+
+                        if opts.return_mask_composite:
+                            output_images.append(image_mask_composite)
...
@@ -560,17 +560,25 @@ class ScriptRunner:
             on_after.clear()
 
     def create_script_ui(self, script):
-        import modules.api.models as api_models
-
         script.args_from = len(self.inputs)
         script.args_to = len(self.inputs)
 
+        try:
+            self.create_script_ui_inner(script)
+        except Exception:
+            errors.report(f"Error creating UI for {script.name}: ", exc_info=True)
+
+    def create_script_ui_inner(self, script):
+        import modules.api.models as api_models
+
         controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
 
         if controls is None:
             return
 
         script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
 
         api_args = []
 
         for control in controls:
...
+import dataclasses
 import os
 import gradio as gr
 
 from modules import errors, shared
 
+
+@dataclasses.dataclass
+class PostprocessedImageSharedInfo:
+    target_width: int = None
+    target_height: int = None
+
 
 class PostprocessedImage:
     def __init__(self, image):
         self.image = image
         self.info = {}
+        self.shared = PostprocessedImageSharedInfo()
+        self.extra_images = []
+        self.nametags = []
+        self.disable_processing = False
+        self.caption = None
+
+    def get_suffix(self, used_suffixes=None):
+        used_suffixes = {} if used_suffixes is None else used_suffixes
+        suffix = "-".join(self.nametags)
+        if suffix:
+            suffix = "-" + suffix
+
+        if suffix not in used_suffixes:
+            used_suffixes[suffix] = 1
+            return suffix
+
+        for i in range(1, 100):
+            proposed_suffix = suffix + "-" + str(i)
+
+            if proposed_suffix not in used_suffixes:
+                used_suffixes[proposed_suffix] = 1
+                return proposed_suffix
+
+        return suffix
+
+    def create_copy(self, new_image, *, nametags=None, disable_processing=False):
+        pp = PostprocessedImage(new_image)
+        pp.shared = self.shared
+        pp.nametags = self.nametags.copy()
+        pp.info = self.info.copy()
+        pp.disable_processing = disable_processing
+        if nametags is not None:
+            pp.nametags += nametags
+        return pp
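A quick hedged walk-through of the suffix deduplication (the image payload is irrelevant to get_suffix, so None stands in for a PIL image):

    a = PostprocessedImage(None)
    b = PostprocessedImage(None)
    a.nametags = ["upscaled"]
    b.nametags = ["upscaled"]

    used = {}
    a.get_suffix(used)  # -> "-upscaled"
    b.get_suffix(used)  # -> "-upscaled-1" (a collision gets a numeric counter appended)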
 class ScriptPostprocessing:
 
@@ -42,10 +85,17 @@ class ScriptPostprocessing:
         pass
 
+    def process_firstpass(self, pp: PostprocessedImage, **args):
+        """
+        Called for all scripts before calling process(). Scripts can examine the image here and set fields
+        of the pp object to communicate things to other scripts.
+        args contains a dictionary with all values returned by components from ui()
+        """
+        pass
+
     def image_changed(self):
         pass
 
 
 def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
 
@@ -118,16 +168,42 @@ class ScriptPostprocessingRunner:
         return inputs
 
     def run(self, pp: PostprocessedImage, args):
-        for script in self.scripts_in_preferred_order():
-            shared.state.job = script.name
+        scripts = []
 
+        for script in self.scripts_in_preferred_order():
             script_args = args[script.args_from:script.args_to]
 
             process_args = {}
             for (name, _component), value in zip(script.controls.items(), script_args):
                 process_args[name] = value
 
-            script.process(pp, **process_args)
+            scripts.append((script, process_args))
+
+        for script, process_args in scripts:
+            script.process_firstpass(pp, **process_args)
+
+        all_images = [pp]
+
+        for script, process_args in scripts:
+            if shared.state.skipped:
+                break
+
+            shared.state.job = script.name
+
+            for single_image in all_images.copy():
+                if not single_image.disable_processing:
+                    script.process(single_image, **process_args)
+
+                for extra_image in single_image.extra_images:
+                    if not isinstance(extra_image, PostprocessedImage):
+                        extra_image = single_image.create_copy(extra_image)
+
+                    all_images.append(extra_image)
+
+                single_image.extra_images.clear()
+
+        pp.extra_images = all_images[1:]
 
     def create_args_for_run(self, scripts_args):
         if not self.ui_created:
...
@@ -38,8 +38,12 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
 optimizers = []
 current_optimizer: sd_hijack_optimizations.SdOptimization = None
 
-ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
-sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
+ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
+ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
+
+sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
+sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
+
 
 def list_optimizers():
     new_optimizers = script_callbacks.list_optimizers_callback()
 
@@ -303,8 +307,6 @@ class StableDiffusionModelHijack:
         self.layers = None
         self.clip = None
 
-        sd_unet.original_forward = None
-
     def apply_circular(self, enable):
         if self.circular_enabled == enable:
...
@@ -230,15 +230,19 @@ def select_checkpoint():
     return checkpoint_info
 
-checkpoint_dict_replacements = {
+checkpoint_dict_replacements_sd1 = {
     'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
     'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
     'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
 }
 
+checkpoint_dict_replacements_sd2_turbo = {  # Converts SD 2.1 Turbo from SGM to LDM format.
+    'conditioner.embedders.0.': 'cond_stage_model.',
+}
+
 
-def transform_checkpoint_dict_key(k):
-    for text, replacement in checkpoint_dict_replacements.items():
+def transform_checkpoint_dict_key(k, replacements):
+    for text, replacement in replacements.items():
         if k.startswith(text):
             k = replacement + k[len(text):]
 
@@ -249,9 +253,14 @@ def get_state_dict_from_checkpoint(pl_sd):
     pl_sd = pl_sd.pop("state_dict", pl_sd)
     pl_sd.pop("state_dict", None)
 
+    is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024
+
     sd = {}
     for k, v in pl_sd.items():
-        new_key = transform_checkpoint_dict_key(k)
+        if is_sd2_turbo:
+            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
+        else:
+            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)
 
         if new_key is not None:
             sd[new_key] = v
...
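The remapping itself is a plain prefix rewrite; a self-contained check of the loop shown above, using the SD 2.1 Turbo table:

    replacements = {'conditioner.embedders.0.': 'cond_stage_model.'}

    def transform_checkpoint_dict_key(k, replacements):
        for text, replacement in replacements.items():
            if k.startswith(text):
                k = replacement + k[len(text):]
        return k

    key = transform_checkpoint_dict_key('conditioner.embedders.0.model.ln_final.weight', replacements)
    assert key == 'cond_stage_model.model.ln_final.weight'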
@@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc
 
 def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
     sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
 
@@ -43,7 +43,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
 
     extra_args = {} if extra_args is None else extra_args
...
@@ -5,8 +5,7 @@ from modules import script_callbacks, shared, devices
 unet_options = []
 current_unet_option = None
 current_unet = None
-original_forward = None
+original_forward = None  # not used, only left temporarily for compatibility
 
 def list_unets():
     new_unets = script_callbacks.list_unets_callback()
 
@@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module):
         pass
 
-def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
-    if current_unet is not None:
-        return current_unet.forward(x, timesteps, context, *args, **kwargs)
+def create_unet_forward(original_forward):
+    def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+        if current_unet is not None:
+            return current_unet.forward(x, timesteps, context, *args, **kwargs)
+
+        return original_forward(self, x, timesteps, context, *args, **kwargs)
 
-    return original_forward(self, x, timesteps, context, *args, **kwargs)
+    return UNetModel_forward
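Why the closure matters here: a single module-level original_forward global can only remember one original, but both the ldm and sgm UNet classes get patched. A toy demonstration of the pattern, with hypothetical classes:

    class LdmUNet:
        def forward(self):
            return "ldm forward"

    class SgmUNet:
        def forward(self):
            return "sgm forward"

    def create_forward(original_forward):
        def patched(self):
            return "patched -> " + original_forward(self)  # each patch keeps its own original
        return patched

    LdmUNet.forward = create_forward(LdmUNet.forward)
    SgmUNet.forward = create_forward(SgmUNet.forward)

    assert LdmUNet().forward() == "patched -> ldm forward"
    assert SgmUNet().forward() == "patched -> sgm forward"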
@@ -66,6 +66,22 @@ def reload_hypernetworks():
    shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)

+def get_infotext_names():
+    from modules import generation_parameters_copypaste, shared
+    res = {}
+
+    for info in shared.opts.data_labels.values():
+        if info.infotext:
+            res[info.infotext] = 1
+
+    for tab_data in generation_parameters_copypaste.paste_fields.values():
+        for _, name in tab_data.get("fields") or []:
+            if isinstance(name, str):
+                res[name] = 1
+
+    return list(res)

ui_reorder_categories_builtin_items = [
    "prompt",
    "image",

...
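get_infotext_names above uses a plain dict as an insertion-ordered set: writing res[name] = 1 deduplicates names, and list(res) preserves first-seen order (guaranteed since Python 3.7). A minimal illustration:

res = {}
for name in ["Sampler", "Steps", "Sampler", "CFG scale"]:
    res[name] = 1  # duplicate keys overwrite; first-insertion order is kept

print(list(res))  # ['Sampler', 'Steps', 'CFG scale']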
@@ -3,7 +3,6 @@ import html
import gradio as gr

import modules.textual_inversion.textual_inversion
-import modules.textual_inversion.preprocess
from modules import sd_hijack, shared

@@ -15,12 +14,6 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old):
    return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""

-def preprocess(*args):
-    modules.textual_inversion.preprocess.preprocess(*args)
-
-    return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""

def train_embedding(*args):
    assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'

...
@@ -919,71 +919,6 @@ def create_ui():
                with gr.Column():
                    create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")

-            with gr.Tab(label="Preprocess images", id="preprocess_images"):
-                process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
-                process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
-                process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
-                process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
-                preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
-
-                with gr.Row():
-                    process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size")
-                    process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
-                    process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
-                    process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
-                    process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
-                    process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
-                    process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
-
-                with gr.Row(visible=False) as process_split_extra_row:
-                    process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
-                    process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
-
-                with gr.Row(visible=False) as process_focal_crop_row:
-                    process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
-                    process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
-                    process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
-                    process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
-
-                with gr.Column(visible=False) as process_multicrop_col:
-                    gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
-                    with gr.Row():
-                        process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
-                        process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
-                    with gr.Row():
-                        process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
-                        process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
-                    with gr.Row():
-                        process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
-                        process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
-
-                with gr.Row():
-                    with gr.Column(scale=3):
-                        gr.HTML(value="")
-
-                    with gr.Column():
-                        with gr.Row():
-                            interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
-                            run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
-
-                process_split.change(
-                    fn=lambda show: gr_show(show),
-                    inputs=[process_split],
-                    outputs=[process_split_extra_row],
-                )
-
-                process_focal_crop.change(
-                    fn=lambda show: gr_show(show),
-                    inputs=[process_focal_crop],
-                    outputs=[process_focal_crop_row],
-                )
-
-                process_multicrop.change(
-                    fn=lambda show: gr_show(show),
-                    inputs=[process_multicrop],
-                    outputs=[process_multicrop_col],
-                )

        def get_textual_inversion_template_names():
            return sorted(textual_inversion.textual_inversion_templates)

@@ -1084,42 +1019,6 @@ def create_ui():
                ]
            )

-            run_preprocess.click(
-                fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
-                _js="start_training_textual_inversion",
-                inputs=[
-                    dummy_component,
-                    process_src,
-                    process_dst,
-                    process_width,
-                    process_height,
-                    preprocess_txt_action,
-                    process_keep_original_size,
-                    process_flip,
-                    process_split,
-                    process_caption,
-                    process_caption_deepbooru,
-                    process_split_threshold,
-                    process_overlap_ratio,
-                    process_focal_crop,
-                    process_focal_crop_face_weight,
-                    process_focal_crop_entropy_weight,
-                    process_focal_crop_edges_weight,
-                    process_focal_crop_debug,
-                    process_multicrop,
-                    process_multicrop_mindim,
-                    process_multicrop_maxdim,
-                    process_multicrop_minarea,
-                    process_multicrop_maxarea,
-                    process_multicrop_objective,
-                    process_multicrop_threshold,
-                ],
-                outputs=[
-                    ti_output,
-                    ti_outcome,
-                ],
-            )

            train_embedding.click(
                fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
                _js="start_training_textual_inversion",

@@ -1193,12 +1092,6 @@ def create_ui():
                outputs=[],
            )

-            interrupt_preprocessing.click(
-                fn=lambda: shared.state.interrupt(),
-                inputs=[],
-                outputs=[],
-            )

    loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
    settings = ui_settings.UiSettings()

...
@@ -335,6 +335,11 @@ def normalize_git_url(url):
    return url

+def get_extension_dirname_from_url(url):
+    *parts, last_part = url.split('/')
+    return normalize_git_url(last_part)

def install_extension_from_url(dirname, url, branch_name=None):
    check_access()

@@ -346,10 +351,7 @@ def install_extension_from_url(dirname, url, branch_name=None):
    assert url, 'No URL specified'

    if dirname is None or dirname == "":
-        *parts, last_part = url.split('/')
-        last_part = normalize_git_url(last_part)
-
-        dirname = last_part
+        dirname = get_extension_dirname_from_url(url)

    target_dir = os.path.join(extensions.extensions_dir, dirname)
    assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'

@@ -449,7 +451,8 @@ def get_date(info: dict, key):
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
    extlist = available_extensions["extensions"]
-    installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
+    installed_extensions = {extension.name for extension in extensions.extensions}
+    installed_extension_urls = {normalize_git_url(extension.remote) for extension in extensions.extensions if extension.remote is not None}

    tags = available_extensions.get("tags", {})
    tags_to_hide = set(hide_tags)

@@ -482,7 +485,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
        if url is None:
            continue

-        existing = installed_extension_urls.get(normalize_git_url(url), None)
+        existing = get_extension_dirname_from_url(url) in installed_extensions or normalize_git_url(url) in installed_extension_urls
        extension_tags = extension_tags + ["installed"] if existing else extension_tags

        if any(x for x in extension_tags if x in tags_to_hide):

...
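The extracted helper makes the "installed" check two-sided: an available extension now counts as installed if either its derived directory name or its normalized remote URL matches an installed one. Assuming normalize_git_url strips a trailing ".git" (consistent with how it is used above), the helper behaves roughly like this sketch:

def normalize_git_url(url):
    # assumed behaviour, for illustration: treat "repo" and "repo.git" as equal
    return url.removesuffix(".git")

def get_extension_dirname_from_url(url):
    *parts, last_part = url.split('/')
    return normalize_git_url(last_part)

print(get_extension_dirname_from_url("https://github.com/user/my-extension.git"))
# -> "my-extension"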
@@ -151,8 +151,13 @@ class ExtraNetworksPage:
                continue

            subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
-            while subdir.startswith("/"):
-                subdir = subdir[1:]
+
+            if shared.opts.extra_networks_dir_button_function:
+                if not subdir.startswith("/"):
+                    subdir = "/" + subdir
+            else:
+                while subdir.startswith("/"):
+                    subdir = subdir[1:]

            is_empty = len(os.listdir(x)) == 0
            if not is_empty and not subdir.endswith("/"):

...
import gradio as gr
-from modules import scripts, shared, ui_common, postprocessing, call_queue
+from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow
import modules.generation_parameters_copypaste as parameters_copypaste

def create_ui():
+    dummy_component = gr.Label(visible=False)
    tab_index = gr.State(value=0)

    with gr.Row(equal_height=False, variant='compact'):

@@ -20,11 +21,13 @@ def create_ui():
                extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
                show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")

-            submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
            script_inputs = scripts.scripts_postproc.setup_ui()

        with gr.Column():
+            toprow = ui_toprow.Toprow(is_compact=True, is_img2img=False, id_part="extras")
+            toprow.create_inline_toprow_image()
+            submit = toprow.submit
+
            result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)

    tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])

@@ -32,8 +35,10 @@ def create_ui():
    tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])

    submit.click(
-        fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
+        fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']),
+        _js="submit_extras",
        inputs=[
+            dummy_component,
            tab_index,
            extras_image,
            image_batch,

@@ -45,8 +50,9 @@ def create_ui():
        outputs=[
            result_images,
            html_info_x,
-            html_info,
-        ]
+            html_log,
+        ],
+        show_progress=False,
    )

    parameters_copypaste.add_paste_fields("extras", extras_image, None)

...
@@ -34,8 +34,10 @@ class Toprow:
    submit_box = None

-    def __init__(self, is_img2img, is_compact=False):
-        id_part = "img2img" if is_img2img else "txt2img"
+    def __init__(self, is_img2img, is_compact=False, id_part=None):
+        if id_part is None:
+            id_part = "img2img" if is_img2img else "txt2img"
+
        self.id_part = id_part
        self.is_img2img = is_img2img
        self.is_compact = is_compact

...
@@ -57,6 +57,9 @@ class Upscaler:
        dest_h = int((img.height * scale) // 8 * 8)

        for _ in range(3):
+            if img.width >= dest_w and img.height >= dest_h:
+                break
+
            shape = (img.width, img.height)

            img = self.do_upscale(img, selected_model)

@@ -64,9 +67,6 @@ class Upscaler:
            if shape == (img.width, img.height):
                break

-            if img.width >= dest_w and img.height >= dest_h:
-                break
-
        if img.width != dest_w or img.height != dest_h:
            img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)

...
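Moving the size check to the top of the loop means it is evaluated before do_upscale is called, so an image that already meets the destination size no longer goes through one extra upscale pass. The resulting control flow, sketched with the upscaler as a plain callable (a PIL-style image with width/height is assumed):

def upscale(img, do_upscale, dest_w, dest_h):
    for _ in range(3):
        if img.width >= dest_w and img.height >= dest_h:
            break  # checked first now: skip the model entirely if already big enough

        shape = (img.width, img.height)
        img = do_upscale(img)

        if shape == (img.width, img.height):
            break  # upscaler made no progress; stop rather than loop uselessly
    return img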
from modules import shared
from modules.sd_hijack_utils import CondFunc

has_ipex = False
try:
    import torch
    import intel_extension_for_pytorch as ipex # noqa: F401
    has_ipex = True
except Exception:
    pass


def check_for_xpu():
    return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()


def get_xpu_device_string():
    if shared.cmd_opts.device_id is not None:
        return f"xpu:{shared.cmd_opts.device_id}"
    return "xpu"


def torch_xpu_gc():
    with torch.xpu.device(get_xpu_device_string()):
        torch.xpu.empty_cache()


has_xpu = check_for_xpu()

if has_xpu:
    # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
    CondFunc('torch.Generator',
        lambda orig_func, device=None: torch.xpu.Generator(device),
        lambda orig_func, device=None: device is not None and device.type == "xpu")

    # W/A for some OPs that could not handle different input dtypes
    CondFunc('torch.nn.functional.layer_norm',
        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
        orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
        weight is not None and input.dtype != weight.data.dtype)
    CondFunc('torch.nn.modules.GroupNorm.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
    CondFunc('torch.nn.modules.linear.Linear.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
    CondFunc('torch.nn.modules.conv.Conv2d.forward',
        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
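xpu_specific.py leans on CondFunc from modules.sd_hijack_utils, which patches a dotted function path so a substitute runs only when a predicate matches and the original runs otherwise. A rough self-contained approximation of that dispatch (not the actual implementation; the real CondFunc also resolves the dotted path and installs the wrapper itself):

import torch

def cond_func(orig_func, sub_func, pred):
    # call sub_func only when the predicate holds, else fall through to the original
    def wrapper(*args, **kwargs):
        if pred(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)
    return wrapper

# e.g. forcing float32 inputs through a float16 layer, mirroring the dtype workarounds above
linear = torch.nn.Linear(4, 4).to(torch.float16)
patched = cond_func(
    torch.nn.Linear.forward,
    lambda orig, self, x: orig(self, x.to(self.weight.dtype)),
    lambda orig, self, x: x.dtype != self.weight.dtype,
)
out = patched(linear, torch.randn(2, 4))  # float32 input is cast to float16 first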
from modules import scripts_postprocessing, ui_components, deepbooru, shared
import gradio as gr


class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing):
    name = "Caption"
    order = 4000

    def ui(self):
        with ui_components.InputAccordion(False, label="Caption") as enable:
            option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False)

        return {
            "enable": enable,
            "option": option,
        }

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
        if not enable:
            return

        captions = [pp.caption]

        if "Deepbooru" in option:
            captions.append(deepbooru.model.tag(pp.image))

        if "BLIP" in option:
            captions.append(shared.interrogator.generate_caption(pp.image))

        pp.caption = ", ".join([x for x in captions if x])
from PIL import Image
import numpy as np

-from modules import scripts_postprocessing, codeformer_model
+from modules import scripts_postprocessing, codeformer_model, ui_components
import gradio as gr

-from modules.ui_components import FormRow

class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
    name = "CodeFormer"
    order = 3000

    def ui(self):
-        with FormRow():
-            codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
-            codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
+        with ui_components.InputAccordion(False, label="CodeFormer") as enable:
+            with gr.Row():
+                codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_codeformer_visibility")
+                codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")

        return {
+            "enable": enable,
            "codeformer_visibility": codeformer_visibility,
            "codeformer_weight": codeformer_weight,
        }

-    def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
-        if codeformer_visibility == 0:
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, codeformer_visibility, codeformer_weight):
+        if codeformer_visibility == 0 or not enable:
            return

        restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)

...
from PIL import ImageOps, Image

from modules import scripts_postprocessing, ui_components
import gradio as gr


class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing):
    name = "Create flipped copies"
    order = 4000

    def ui(self):
        with ui_components.InputAccordion(False, label="Create flipped copies") as enable:
            with gr.Row():
                option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False)

        return {
            "enable": enable,
            "option": option,
        }

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
        if not enable:
            return

        if "Horizontal" in option:
            pp.extra_images.append(ImageOps.mirror(pp.image))

        if "Vertical" in option:
            pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM))

        if "Both" in option:
            pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT))
from modules import scripts_postprocessing, ui_components, errors
import gradio as gr

from modules.textual_inversion import autocrop


class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing):
    name = "Auto focal point crop"
    order = 4000

    def ui(self):
        with ui_components.InputAccordion(False, label="Auto focal point crop") as enable:
            face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight")
            entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight")
            edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight")
            debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")

        return {
            "enable": enable,
            "face_weight": face_weight,
            "entropy_weight": entropy_weight,
            "edges_weight": edges_weight,
            "debug": debug,
        }

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug):
        if not enable:
            return

        if not pp.shared.target_width or not pp.shared.target_height:
            return

        dnn_model_path = None
        try:
            dnn_model_path = autocrop.download_and_cache_models()
        except Exception:
            errors.report("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", exc_info=True)

        autocrop_settings = autocrop.Settings(
            crop_width=pp.shared.target_width,
            crop_height=pp.shared.target_height,
            face_points_weight=face_weight,
            entropy_points_weight=entropy_weight,
            corner_points_weight=edges_weight,
            annotate_image=debug,
            dnn_model_path=dnn_model_path,
        )

        result, *others = autocrop.crop_image(pp.image, autocrop_settings)

        pp.image = result
        pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others]
from PIL import Image
import numpy as np

-from modules import scripts_postprocessing, gfpgan_model
+from modules import scripts_postprocessing, gfpgan_model, ui_components
import gradio as gr

-from modules.ui_components import FormRow

class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
    name = "GFPGAN"
    order = 2000

    def ui(self):
-        with FormRow():
-            gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
+        with ui_components.InputAccordion(False, label="GFPGAN") as enable:
+            gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_gfpgan_visibility")

        return {
+            "enable": enable,
            "gfpgan_visibility": gfpgan_visibility,
        }

-    def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
-        if gfpgan_visibility == 0:
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, gfpgan_visibility):
+        if gfpgan_visibility == 0 or not enable:
            return

        restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))

...
import math

from modules import scripts_postprocessing, ui_components
import gradio as gr


def split_pic(image, inverse_xy, width, height, overlap_ratio):
    if inverse_xy:
        from_w, from_h = image.height, image.width
        to_w, to_h = height, width
    else:
        from_w, from_h = image.width, image.height
        to_w, to_h = width, height

    h = from_h * to_w // from_w
    if inverse_xy:
        image = image.resize((h, to_w))
    else:
        image = image.resize((to_w, h))

    split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
    y_step = (h - to_h) / (split_count - 1)
    for i in range(split_count):
        y = int(y_step * i)
        if inverse_xy:
            splitted = image.crop((y, 0, y + to_h, to_w))
        else:
            splitted = image.crop((0, y, to_w, y + to_h))
        yield splitted


class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing):
    name = "Split oversized images"
    order = 4000

    def ui(self):
        with ui_components.InputAccordion(False, label="Split oversized images") as enable:
            with gr.Row():
                split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold")
                overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio")

        return {
            "enable": enable,
            "split_threshold": split_threshold,
            "overlap_ratio": overlap_ratio,
        }

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio):
        if not enable:
            return

        width = pp.shared.target_width
        height = pp.shared.target_height

        if not width or not height:
            return

        if pp.image.height > pp.image.width:
            ratio = (pp.image.width * height) / (pp.image.height * width)
            inverse_xy = False
        else:
            ratio = (pp.image.height * width) / (pp.image.width * height)
            inverse_xy = True

        if ratio >= 1.0 and ratio > split_threshold:
            return

        result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio)

        pp.image = result
        pp.extra_images = [pp.create_copy(x) for x in others]
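In split_pic, the tile count follows from covering the resized height h with tiles of height to_h that overlap by overlap_ratio. A quick worked check of the arithmetic (values chosen only for illustration):

import math

h, to_h, overlap_ratio = 1024, 512, 0.2
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
y_step = (h - to_h) / (split_count - 1)
print(split_count, y_step)  # 3 tiles, stepping 256px: crops start at y = 0, 256, 512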
@@ -81,6 +81,14 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):

        return image

+    def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+        if upscale_mode == 1:
+            pp.shared.target_width = upscale_to_width
+            pp.shared.target_height = upscale_to_height
+        else:
+            pp.shared.target_width = int(pp.image.width * upscale_by)
+            pp.shared.target_height = int(pp.image.height * upscale_by)
+
    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
        if upscaler_1_name == "None":
            upscaler_1_name = None

@@ -126,6 +134,10 @@ class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
            "upscaler_name": upscaler_name,
        }

+    def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
+        pp.shared.target_width = int(pp.image.width * upscale_by)
+        pp.shared.target_height = int(pp.image.height * upscale_by)
+
    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
        if upscaler_name is None or upscaler_name == "None":
            return

...
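The new process_firstpass hooks give postprocessing a two-pass shape: first every script's process_firstpass runs, letting the upscale scripts publish the final target size into pp.shared, and only then do the process calls run, so size-aware scripts such as focal crop and split can read target_width/target_height. A schematic of that ordering (runner shape assumed from the diff, not the actual modules/postprocessing.py code):

class PPShared:
    target_width = None
    target_height = None

class PP:
    def __init__(self):
        self.shared = PPShared()

def run_all(scripts, pp):
    for script in scripts:
        script.process_firstpass(pp)   # pass 1: publish target sizes into pp.shared
    for script in scripts:
        script.process(pp)             # pass 2: crop/split scripts can now read them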
from PIL import Image

from modules import scripts_postprocessing, ui_components
import gradio as gr


def center_crop(image: Image, w: int, h: int):
    iw, ih = image.size
    if ih / h < iw / w:
        sw = w * ih / h
        box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
    else:
        sh = h * iw / w
        box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
    return image.resize((w, h), Image.Resampling.LANCZOS, box)


def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
    iw, ih = image.size
    err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h))
    wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64)
              if minarea <= w * h <= maxarea and err(w, h) <= threshold),
             key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1],
             default=None
             )
    return wh and center_crop(image, *wh)


class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing):
    name = "Auto-sized crop"
    order = 4000

    def ui(self):
        with ui_components.InputAccordion(False, label="Auto-sized crop") as enable:
            gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
            with gr.Row():
                mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="postprocess_multicrop_mindim")
                maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="postprocess_multicrop_maxdim")
            with gr.Row():
                minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area lower bound", value=64 * 64, elem_id="postprocess_multicrop_minarea")
                maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area upper bound", value=640 * 640, elem_id="postprocess_multicrop_maxarea")
            with gr.Row():
                objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="postprocess_multicrop_objective")
                threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="postprocess_multicrop_threshold")

        return {
            "enable": enable,
            "mindim": mindim,
            "maxdim": maxdim,
            "minarea": minarea,
            "maxarea": maxarea,
            "objective": objective,
            "threshold": threshold,
        }

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold):
        if not enable:
            return

        cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold)
        if cropped is not None:
            pp.image = cropped
        else:
            print(f"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)")
@@ -646,6 +646,8 @@ table.popup-table .link{
    margin: auto;
    padding: 2em;
    z-index: 1001;
+    max-height: 90%;
+    max-width: 90%;
}

/* fullpage image viewer */

...