Commit 501d4e9c authored by unknown

Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui into gamepad

parents 5e1f4f74 ea9bd9fc
javascript/hints.js

@@ -66,8 +66,8 @@ titles = {
     "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
-    "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
-    "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
+    "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
+    "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
     "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
     "Loopback": "Process an image, use it as an input, repeat.",
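Note: the new [prompt_hash] tag is implemented in the modules/images.py hunk further down: it is the first 8 hex characters of the SHA-256 of the prompt, so a given prompt always maps to the same short, filename-safe token. A minimal sketch of the expansion (prompt_hash is an illustrative helper name, not part of the codebase):

    import hashlib

    def prompt_hash(prompt: str) -> str:
        # first 8 hex chars of the prompt's SHA-256, as [prompt_hash] expands to
        return hashlib.sha256(prompt.encode()).hexdigest()[0:8]

    prompt_hash("a photo of a cat")  # 8-char hex digest, stable across runs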
modules/devices.py

-import sys, os, shlex
+import sys
 import contextlib
 import torch
 from modules import errors
-from modules.sd_hijack_utils import CondFunc
-from packaging import version
+
+if sys.platform == "darwin":
+    from modules import mac_specific


-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
 def has_mps() -> bool:
-    if not getattr(torch, 'has_mps', False):
-        return False
-    try:
-        torch.zeros(1).to(torch.device("mps"))
-        return True
-    except Exception:
+    if sys.platform != "darwin":
         return False
+    else:
+        return mac_specific.has_mps


 def extract_device_id(args, name):
     for x in range(len(args)):
...
@@ -155,36 +150,3 @@ def test_for_nans(x, where):
         message += " Use --disable-nan-check commandline argument to disable this check."

     raise NansException(message)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
-def cumsum_fix(input, cumsum_func, *args, **kwargs):
-    if input.device.type == 'mps':
-        output_dtype = kwargs.get('dtype', input.dtype)
-        if output_dtype == torch.int64:
-            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
-        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
-            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
-    return cumsum_func(input, *args, **kwargs)
-
-
-if has_mps():
-    if version.parse(torch.__version__) < version.parse("1.13"):
-        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
-                 lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
-        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
-                 lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
-        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
-    elif version.parse(torch.__version__) > version.parse("1.13.1"):
-        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
-        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
-        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
-        CondFunc('torch.cumsum', cumsum_fix_func, None)
-        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
-        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
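Note: everything removed here reappears in the new modules/mac_specific.py below; devices.py now only gates the import on sys.platform, so non-macOS machines never load the MPS workarounds. Those workarounds lean on the CondFunc monkeypatch helper; the following is a rough stand-in for how that pattern works (a simplified sketch, not the real modules.sd_hijack_utils.CondFunc, and it handles only plain module attributes):

    import importlib

    def cond_patch(path, sub_func, cond_func):
        # replace the attribute named by a dotted path with a conditional wrapper
        module_path, attr = path.rsplit('.', 1)
        module = importlib.import_module(module_path)
        orig = getattr(module, attr)

        def wrapper(*args, **kwargs):
            # route through the substitute when the condition matches, else fall through
            if cond_func is None or cond_func(orig, *args, **kwargs):
                return sub_func(orig, *args, **kwargs)
            return orig(*args, **kwargs)

        setattr(module, attr, wrapper)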
modules/hashes.py

@@ -4,6 +4,7 @@ import os.path
 import filelock

+from modules import shared
 from modules.paths import data_path
@@ -68,6 +69,9 @@ def sha256(filename, title):
     if sha256_value is not None:
         return sha256_value

+    if shared.cmd_opts.no_hashing:
+        return None
+
     print(f"Calculating sha256 for {filename}: ", end='')
     sha256_value = calculate_sha256(filename)
     print(f"{sha256_value}")
modules/hypernetworks/hypernetwork.py

@@ -307,7 +307,7 @@ class Hypernetwork:
     def shorthash(self):
         sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')

-        return sha256[0:10]
+        return sha256[0:10] if sha256 else None


 def list_hypernetworks(path):
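Note: with the new --no-hashing flag, hashes.sha256() can return None, so every caller has to tolerate a missing hash the way the hypernetwork shorthash above now does. A minimal illustration of the guard (shorthash_of is a hypothetical helper):

    def shorthash_of(sha256_value):
        # first 10 hex chars, or None when hashing is disabled
        return sha256_value[0:10] if sha256_value else None

    assert shorthash_of(None) is None
    assert shorthash_of("deadbeefcafe0123") == "deadbeefca"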
modules/images.py

@@ -16,6 +16,7 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
 from fonts.ttf import Roboto
 import string
 import json
+import hashlib

 from modules import sd_samplers, shared, script_callbacks
 from modules.shared import opts, cmd_opts
@@ -130,7 +131,7 @@ class GridAnnotation:
         self.size = None


-def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
+def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
     def wrap(drawing, text, font, line_length):
         lines = ['']

         for word in text.split():
@@ -194,32 +195,35 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
             line.allowed_width = allowed_width

     hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
-    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
-                        ver_texts]
+    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]

     pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2

-    result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
-    result.paste(im, (pad_left, pad_top))
+    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
+
+    for row in range(rows):
+        for col in range(cols):
+            cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
+            result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))

     d = ImageDraw.Draw(result)

     for col in range(cols):
-        x = pad_left + width * col + width / 2
+        x = pad_left + (width + margin) * col + width / 2
         y = pad_top / 2 - hor_text_heights[col] / 2

         draw_texts(d, x, y, hor_texts[col], fnt, fontsize)

     for row in range(rows):
         x = pad_left / 2
-        y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
+        y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2

         draw_texts(d, x, y, ver_texts[row], fnt, fontsize)

     return result


-def draw_prompt_matrix(im, width, height, all_prompts):
+def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
     prompts = all_prompts[1:]
     boundary = math.ceil(len(prompts) / 2)
@@ -229,7 +233,7 @@ def draw_prompt_matrix(im, width, height, all_prompts):
     hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
     ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]

-    return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
+    return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)


 def resize_image(resize_mode, im, width, height, upscaler_name=None):
@@ -340,6 +344,7 @@ class FilenameGenerator:
     'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
     'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
     'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
+    'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
     'prompt': lambda self: sanitize_filename_part(self.prompt),
     'prompt_no_styles': lambda self: self.prompt_no_style(),
     'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
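Note: draw_grid_annotations used to paste the pre-assembled grid in one piece; to support margins it now crops each cell back out and re-pastes it at an offset that grows by margin per column and row, so only the gaps between cells widen, not the outer border. The resulting canvas geometry, as a sketch (grid_canvas_size is illustrative):

    def grid_canvas_size(cols, rows, width, height, pad_left, pad_top, margin):
        # matches the Image.new(...) call above: margin appears (cols-1) and (rows-1) times
        total_w = pad_left + width * cols + margin * (cols - 1)
        total_h = pad_top + height * rows + margin * (rows - 1)
        return total_w, total_h

    grid_canvas_size(3, 2, 512, 512, 0, 0, 10)  # (1556, 1034)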
modules/img2img.py

@@ -76,7 +76,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
                 processed_image.save(os.path.join(output_dir, filename))


-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
     override_settings = create_override_settings_dict(override_settings_texts)

     is_batch = mode == 5
@@ -142,6 +142,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         inpainting_fill=inpainting_fill,
         resize_mode=resize_mode,
         denoising_strength=denoising_strength,
+        image_cfg_scale=image_cfg_scale,
         inpaint_full_res=inpaint_full_res,
         inpaint_full_res_padding=inpaint_full_res_padding,
         inpainting_mask_invert=inpainting_mask_invert,
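Note: the new image_cfg_scale parameter is positional, and Gradio passes the inputs list from ui.py to img2img() in order, so it has to sit immediately after cfg_scale in both the signature above and the inputs list in the ui.py hunk below. A toy equivalent of that coupling (img2img_like is a hypothetical stand-in):

    def img2img_like(prompt, cfg_scale, image_cfg_scale, denoising_strength):
        return dict(prompt=prompt, cfg=cfg_scale, img_cfg=image_cfg_scale, denoise=denoising_strength)

    inputs = ["a cat", 7.0, 1.5, 0.75]  # order mirrors the Gradio inputs list
    img2img_like(*inputs)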
modules/mac_specific.py (new file)

+import torch
+
+from modules import paths
+from modules.sd_hijack_utils import CondFunc
+from packaging import version
+
+
+# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
+# check `getattr` and try it for compatibility
+def check_for_mps() -> bool:
+    if not getattr(torch, 'has_mps', False):
+        return False
+    try:
+        torch.zeros(1).to(torch.device("mps"))
+        return True
+    except Exception:
+        return False
+has_mps = check_for_mps()
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
+def cumsum_fix(input, cumsum_func, *args, **kwargs):
+    if input.device.type == 'mps':
+        output_dtype = kwargs.get('dtype', input.dtype)
+        if output_dtype == torch.int64:
+            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
+    return cumsum_func(input, *args, **kwargs)
+
+
+if has_mps:
+    # MPS fix for randn in torchsde
+    CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
+
+    if version.parse(torch.__version__) < version.parse("1.13"):
+        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
+                 lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
+                 lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
+    elif version.parse(torch.__version__) > version.parse("1.13.1"):
+        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
+        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
+        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
+        CondFunc('torch.cumsum', cumsum_fix_func, None)
+        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
+        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
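Note: on torch releases newer than 1.13.1 the module probes once, at import time, whether int16 and bool cumsum still misbehave on MPS, and only routes affected dtypes through the fix. The routing itself as a standalone sketch (cumsum_on_mps is illustrative, assumes the input already lives on the MPS device, and mirrors cumsum_fix above):

    import torch

    def cumsum_on_mps(t, needs_int_fix=True, needs_bool_fix=True, **kwargs):
        out_dtype = kwargs.get('dtype', t.dtype)
        if out_dtype == torch.int64:
            # int64 cumsum is broken on MPS: compute on CPU, then move back
            return t.cpu().cumsum(0, **kwargs).to(t.device)
        if (needs_bool_fix and out_dtype == torch.bool) or (needs_int_fix and out_dtype in (torch.int8, torch.int16)):
            # upcast to int32 for the scan, return an int64 result
            return t.to(torch.int32).cumsum(0, **kwargs).to(torch.int64)
        return t.cumsum(0, **kwargs)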
modules/modelloader.py

@@ -45,6 +45,9 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
             full_path = file
             if os.path.isdir(full_path):
                 continue
+            if os.path.islink(full_path) and not os.path.exists(full_path):
+                print(f"Skipping broken symlink: {full_path}")
+                continue
             if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
                 continue
             if len(ext_filter) != 0:
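Note: os.path.exists() follows symlinks, so a link whose target has been deleted reports False while os.path.islink() still reports True; that combination is exactly what the new guard keys on. Condensed (is_broken_symlink is an illustrative helper):

    import os

    def is_broken_symlink(path: str) -> bool:
        # islink() sees the link itself; exists() sees (the absence of) its target
        return os.path.islink(path) and not os.path.exists(path)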
modules/processing.py

@@ -186,7 +186,7 @@ class StableDiffusionProcessing:
         return conditioning

     def edit_image_conditioning(self, source_image):
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+        conditioning_image = self.sd_model.encode_first_stage(source_image).mode()

         return conditioning_image
@@ -268,6 +268,7 @@ class Processed:
         self.height = p.height
         self.sampler_name = p.sampler_name
         self.cfg_scale = p.cfg_scale
+        self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
         self.steps = p.steps
         self.batch_size = p.batch_size
         self.restore_faces = p.restore_faces
@@ -445,6 +446,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Steps": p.steps,
         "Sampler": p.sampler_name,
         "CFG scale": p.cfg_scale,
+        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
         "Seed": all_seeds[index],
         "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
         "Size": f"{p.width}x{p.height}",
@@ -901,12 +903,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     sampler = None

-    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
+    def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
         super().__init__(**kwargs)

         self.init_images = init_images
         self.resize_mode: int = resize_mode
         self.denoising_strength: float = denoising_strength
+        self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
         self.init_latent = None
         self.image_mask = mask
         self.latent_mask = None
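Note: image_cfg_scale is only retained when the loaded checkpoint is an InstructPix2Pix-style edit model (cond_stage_key == "edit"); for ordinary checkpoints it is forced to None so the rest of the pipeline ignores it. The conditioning change above swaps a sampled VAE latent for the distribution's mode: encode_first_stage() yields a diagonal Gaussian, and .mode() takes its deterministic mean with no scale factor applied. A hedged sketch of that distinction (DiagonalGaussianLike is a stand-in for ldm's distribution class):

    import torch

    class DiagonalGaussianLike:
        def __init__(self, mean, std):
            self.mean, self.std = mean, std

        def sample(self):
            # stochastic: different conditioning every call
            return self.mean + self.std * torch.randn_like(self.std)

        def mode(self):
            # deterministic: stable image conditioning for edits
            return self.mean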
modules/sd_disable_initialization.py

@@ -20,8 +20,9 @@ class DisableInitialization:
     ```
     """

-    def __init__(self):
+    def __init__(self, disable_clip=True):
         self.replaced = []
+        self.disable_clip = disable_clip

     def replace(self, obj, field, func):
         original = getattr(obj, field, None)
@@ -75,6 +76,8 @@ class DisableInitialization:
         self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
         self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
         self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
-        self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
-        self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
-        self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
+
+        if self.disable_clip:
+            self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
+            self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
+            self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
modules/sd_models.py

@@ -59,13 +59,17 @@ class CheckpointInfo:
     def calculate_shorthash(self):
         self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
+        if self.sha256 is None:
+            return
+
         self.shorthash = self.sha256[0:10]

         if self.shorthash not in self.ids:
-            self.ids += [self.shorthash, self.sha256]
-            self.register()
+            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']

+        checkpoints_list.pop(self.title)
         self.title = f'{self.name} [{self.shorthash}]'
+        self.register()

         return self.shorthash
@@ -158,7 +162,7 @@ def select_checkpoint():
         print(f" - directory {model_path}", file=sys.stderr)
         if shared.cmd_opts.ckpt_dir is not None:
             print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
-        print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
+        print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
         exit(1)

     checkpoint_info = next(iter(checkpoints_list.values()))
@@ -350,6 +354,9 @@ def repair_config(sd_config):
         sd_config.model.params.unet_config.params.use_fp16 = True


+sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
+sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -370,6 +377,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict

     timer.record("find config")
@@ -382,7 +390,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
     sd_model = None
     try:
-        with sd_disable_initialization.DisableInitialization():
+        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
             sd_model = instantiate_from_config(sd_config.model)
     except Exception as e:
         pass
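Note: the two probe keys distinguish checkpoints that ship their own text-encoder weights; SD1-style models store CLIP under cond_stage_model.transformer..., SD2-style under cond_stage_model.model.transformer.... When either key is present, load_model() tells DisableInitialization it is safe to skip its CLIP shortcuts. As a condensed check (clip_included is illustrative):

    def clip_included(state_dict) -> bool:
        sd1 = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
        sd2 = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
        return sd1 in state_dict or sd2 in state_dict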
modules/sd_samplers_common.py

@@ -2,7 +2,6 @@ from collections import namedtuple
 import numpy as np
 import torch
 from PIL import Image
-import torchsde._brownian.brownian_interval
 from modules import devices, processing, images, sd_vae_approx

 from modules.shared import opts, state
@@ -61,18 +60,3 @@ def store_latent(decoded):

 class InterruptedException(BaseException):
     pass
-
-
-# MPS fix for randn in torchsde
-# XXX move this to separate file for MPS
-def torchsde_randn(size, dtype, device, seed):
-    if device.type == 'mps':
-        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
-    else:
-        generator = torch.Generator(device).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=device, generator=generator)
-
-
-torchsde._brownian.brownian_interval._randn = torchsde_randn
modules/sd_samplers_kdiffusion.py

 from collections import deque
 import torch
 import inspect
+import einops
 import k_diffusion.sampling
 from modules import prompt_parser, devices, sd_samplers_common
@@ -56,6 +57,7 @@ class CFGDenoiser(torch.nn.Module):
         self.nmask = None
         self.init_latent = None
         self.step = 0
+        self.image_cfg_scale = None

     def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
         denoised_uncond = x_out[-uncond.shape[0]:]
@@ -67,19 +69,36 @@ class CFGDenoiser(torch.nn.Module):
         return denoised

+    def combine_denoised_for_edit_model(self, x_out, cond_scale):
+        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
+        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
+
+        return denoised
+
     def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
         if state.interrupted or state.skipped:
             raise sd_samplers_common.InterruptedException

+        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
+        # so is_edit_model is set to False to support AND composition.
+        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+
         conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

+        assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+
         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]

-        x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
-        image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
-        sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+        if not is_edit_model:
+            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
+        else:
+            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])

         denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
         cfg_denoiser_callback(denoiser_params)
@@ -88,7 +107,10 @@ class CFGDenoiser(torch.nn.Module):
         sigma_in = denoiser_params.sigma

         if tensor.shape[1] == uncond.shape[1]:
-            cond_in = torch.cat([tensor, uncond])
+            if not is_edit_model:
+                cond_in = torch.cat([tensor, uncond])
+            else:
+                cond_in = torch.cat([tensor, uncond, uncond])

             if shared.batch_cond_uncond:
                 x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
@@ -104,7 +126,13 @@ class CFGDenoiser(torch.nn.Module):
                 for batch_offset in range(0, tensor.shape[0], batch_size):
                     a = batch_offset
                     b = min(a + batch_size, tensor.shape[0])
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
+
+                    if not is_edit_model:
+                        c_crossattn = [tensor[a:b]]
+                    else:
+                        c_crossattn = [torch.cat([tensor[a:b], uncond])]
+
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
                 x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
@@ -115,7 +143,10 @@ class CFGDenoiser(torch.nn.Module):
             elif opts.live_preview_content == "Negative prompt":
                 sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])

-        denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+        if not is_edit_model:
+            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+        else:
+            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)

         if self.mask is not None:
             denoised = self.init_latent * self.mask + self.nmask * denoised
@@ -198,6 +229,7 @@ class KDiffusionSampler:
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
        self.model_wrap_cfg.step = 0
+        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
         self.eta = p.eta if p.eta is not None else opts.eta_ancestral

         k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@@ -260,13 +292,14 @@ class KDiffusionSampler:
         self.model_wrap_cfg.init_latent = x
         self.last_latent = x

-        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+        extra_args={
             'cond': conditioning,
             'image_cond': image_conditioning,
             'uncond': unconditional_conditioning,
-            'cond_scale': p.cfg_scale
-        }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+            'cond_scale': p.cfg_scale,
+        }
+
+        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

         return samples
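Note: for edit models the batch carries three chunks (text-conditioned, image-conditioned, unconditioned) and combine_denoised_for_edit_model applies the two-scale guidance used by InstructPix2Pix: out_uncond + cond_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond). With both scales at 1.0 the expression collapses to the plain conditional prediction, which is why image_cfg_scale == 1.0 is treated as normal sampling. A quick sanity sketch (combine_for_edit is illustrative):

    import torch

    def combine_for_edit(out_cond, out_img_cond, out_uncond, cond_scale, image_cfg_scale):
        return out_uncond + cond_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)

    c, ic, u = torch.randn(3, 4, 64, 64).unbind(0)
    assert torch.allclose(combine_for_edit(c, ic, u, 1.0, 1.0), c)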
modules/shared.py

@@ -106,7 +106,8 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
 parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
 parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
 parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
+parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)

 script_loading.preload_extensions(extensions.extensions_dir, parser)
modules/ui.py

@@ -765,7 +765,9 @@ def create_ui():
                 elif category == "cfg":
                     with FormGroup():
-                        cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+                        with FormRow():
+                            cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+                            image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
                         denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")

                 elif category == "seed":
@@ -861,6 +863,7 @@ def create_ui():
                     batch_count,
                     batch_size,
                     cfg_scale,
+                    image_cfg_scale,
                     denoising_strength,
                     seed,
                     subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
@@ -947,6 +950,7 @@ def create_ui():
                     (sampler_index, "Sampler"),
                     (restore_faces, "Face restoration"),
                     (cfg_scale, "CFG scale"),
+                    (image_cfg_scale, "Image CFG scale"),
                     (seed, "Seed"),
                     (width, "Size-1"),
                     (height, "Size-2"),
@@ -1591,6 +1595,12 @@ def create_ui():
                 outputs=[component, text_settings],
             )

+            text_settings.change(
+                fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
+                inputs=[],
+                outputs=[image_cfg_scale],
+            )
+
             button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
             button_set_checkpoint.click(
                 fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
modules/ui_extra_networks.py

@@ -26,11 +26,12 @@ def add_pages_to_demo(app):
 def fetch_file(filename: str = ""):
     from starlette.responses import FileResponse

-    if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]):
+    if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
         raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")

-    if os.path.splitext(filename)[1].lower() != ".png":
-        raise ValueError(f"File cannot be fetched: {filename}. Only png.")
+    ext = os.path.splitext(filename)[1].lower()
+    if ext not in (".png", ".jpg"):
+        raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")

     # would profit from returning 304
     return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
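Note: Path.absolute() prepends the working directory without dereferencing symlinks and without collapsing '..' segments, whereas resolve() does both; the check asks whether any allowed directory is an ancestor of the requested file. A condensed sketch of the containment test (is_allowed is illustrative):

    from pathlib import Path

    def is_allowed(filename, allowed_dirs):
        # ancestor test over absolute (not resolved) paths
        return any(Path(d).absolute() in Path(filename).absolute().parents for d in allowed_dirs)

    is_allowed("extensions/foo/preview.png", ["extensions"])  # True
    is_allowed("/etc/passwd", ["extensions"])                 # False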
scripts/img2imgalt.py

@@ -6,7 +6,7 @@ from tqdm import trange
 import modules.scripts as scripts
 import gradio as gr

-from modules import processing, shared, sd_samplers, prompt_parser
+from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
 from modules.processing import Processed
 from modules.shared import opts, cmd_opts, state
@@ -50,7 +50,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
         x = x + d * dt

-        sd_samplers.store_latent(x)
+        sd_samplers_common.store_latent(x)

         # This shouldn't be necessary, but solved some VRAM issues
         del x_in, sigma_in, cond_in, c_out, c_in, t,
@@ -104,7 +104,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
         dt = sigmas[i] - sigmas[i - 1]
         x = x + d * dt

-        sd_samplers.store_latent(x)
+        sd_samplers_common.store_latent(x)

         # This shouldn't be necessary, but solved some VRAM issues
         del x_in, sigma_in, cond_in, c_out, c_in, t,
scripts/prompt_matrix.py

@@ -48,23 +48,17 @@ class Script(scripts.Script):
         gr.HTML('<br />')
         with gr.Row():
             with gr.Column():
-                put_at_start = gr.Checkbox(label='Put variable parts at start of prompt',
-                                           value=False, elem_id=self.elem_id("put_at_start"))
+                put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
+                different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
             with gr.Column():
-                # Radio buttons for selecting the prompt between positive and negative
-                prompt_type = gr.Radio(["positive", "negative"], label="Select prompt",
-                                       elem_id=self.elem_id("prompt_type"), value="positive")
-
-        with gr.Row():
-            with gr.Column():
-                different_seeds = gr.Checkbox(
-                    label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
+                prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive")
+                variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma")
             with gr.Column():
-                # Radio buttons for selecting the delimiter to use in the resulting prompt
-                variations_delimiter = gr.Radio(["comma", "space"], label="Select delimiter", elem_id=self.elem_id(
-                    "variations_delimiter"), value="comma")
-        return [put_at_start, different_seeds, prompt_type, variations_delimiter]
+                margin_size = gr.Slider(label="Grid margins (px)", min=0, max=500, value=0, step=2, elem_id=self.elem_id("margin_size"))

-    def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter):
+        return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size]
+
+    def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size):
         modules.processing.fix_seed(p)
         # Raise error if promp type is not positive or negative
         if prompt_type not in ["positive", "negative"]:
@@ -106,7 +100,7 @@ class Script(scripts.Script):
         processed = process_images(p)

         grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
-        grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
+        grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts, margin_size)
         processed.images.insert(0, grid)
         processed.index_of_first_image = 1
         processed.infotexts.insert(0, processed.infotexts[0])
scripts/xyz_grid.py

@@ -205,7 +205,7 @@ axis_options = [
 ]


-def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed):
+def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
     hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
     ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
     title_texts = [[images.GridAnnotation(z)] for z in z_labels]
@@ -292,7 +292,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
             end_index = start_index + len(xs) * len(ys)
             grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
             if draw_legend:
-                grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
+                grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
             sub_grids[i] = grid
             if include_sub_grids and len(zs) > 1:
                 processed_result.images.insert(i+1, grid)
@@ -351,10 +351,16 @@ class Script(scripts.Script):
                 fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)

         with gr.Row(variant="compact", elem_id="axis_options"):
-            draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
-            include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
-            include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
-            no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+            with gr.Column():
+                draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
+                no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+            with gr.Column():
+                include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
+                include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
+            with gr.Column():
+                margin_size = gr.Slider(label="Grid margins (px)", min=0, max=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
+
+        with gr.Row(variant="compact", elem_id="swap_axes"):
             swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
             swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
             swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
@@ -393,9 +399,9 @@ class Script(scripts.Script):
             (z_values, "Z Values"),
         )

-        return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds]
+        return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]

-    def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds):
+    def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
         if not no_fixed_seeds:
             modules.processing.fix_seed(p)
@@ -590,7 +596,8 @@ class Script(scripts.Script):
             include_lone_images=include_lone_images,
             include_sub_grids=include_sub_grids,
             first_axes_processed=first_axes_processed,
-            second_axes_processed=second_axes_processed
+            second_axes_processed=second_axes_processed,
+            margin_size=margin_size
         )

         if opts.grid_save and len(sub_grids) > 1: