Commit 0fb34b57 authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf

Merge branch 'dev' into test-fp8

parents 39ebd568 aeaf1c51
......@@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule):
self.lin_module = None
self.org_module: list[torch.Module] = [self.sd_module]
self.scale = 1.0
# kohya-ss
if "oft_blocks" in weights.w.keys():
self.is_kohya = True
......@@ -53,12 +55,18 @@ class NetworkModuleOFT(network.NetworkModule):
self.constraint = None
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
def calc_updown_kb(self, orig_weight, multiplier):
def calc_updown(self, orig_weight):
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
eye = torch.eye(self.block_size, device=self.oft_blocks.device)
if self.is_kohya:
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)
# This errors out for MultiheadAttention, might need to be handled up-stream
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
......@@ -72,26 +80,3 @@ class NetworkModuleOFT(network.NetworkModule):
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
output_shape = orig_weight.shape
return self.finalize_updown(updown, orig_weight, output_shape)
def calc_updown(self, orig_weight):
# if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it
multiplier = self.multiplier()
return self.calc_updown_kb(orig_weight, multiplier)
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
updown = updown.reshape(output_shape)
if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape)
if ex_bias is not None:
ex_bias = ex_bias * self.multiplier()
return updown, ex_bias
......@@ -159,7 +159,8 @@ def load_network(name, network_on_disk):
bundle_embeddings = {}
for key_network, weight in sd.items():
key_network_without_network_parts, network_part = key_network.split(".", 1)
key_network_without_network_parts, _, network_part = key_network.partition(".")
if key_network_without_network_parts == "bundle_emb":
emb_name, vec_name = network_part.split(".", 1)
emb_dict = bundle_embeddings.get(emb_name, {})
......
......@@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script):
self.setting_names = []
self.infotext_fields = []
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
with gr.Blocks() as interface:
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
......@@ -70,7 +71,7 @@ This page allows you to add some settings to the main interface of txt2img and i
"""),
"extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
"extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
"extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(),
"extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(),
"extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
}))
......
......@@ -34,7 +34,7 @@ function updateOnBackgroundChange() {
if (modalImage && modalImage.offsetParent) {
let currentButton = selected_gallery_button();
let preview = gradioApp().querySelectorAll('.livePreview > img');
if (preview.length > 0) {
if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {
// show preview image if available
modalImage.src = preview[preview.length - 1].src;
} else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
......
......@@ -215,9 +215,33 @@ function restoreProgressImg2img() {
}
/**
* Configure the width and height elements on `tabname` to accept
* pasting of resolutions in the form of "width x height".
*/
function setupResolutionPasting(tabname) {
var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`);
var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`);
for (const el of [width, height]) {
el.addEventListener('paste', function(event) {
var pasteData = event.clipboardData.getData('text/plain');
var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/);
if (parsed) {
width.value = parsed[1];
height.value = parsed[2];
updateInput(width);
updateInput(height);
event.preventDefault();
}
});
}
}
onUiLoaded(function() {
showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
showRestoreProgressButton('img2img', localGet("img2img_task_id"));
setupResolutionPasting('txt2img');
setupResolutionPasting('img2img');
});
......
......@@ -791,3 +791,4 @@ def flatten(img, bgcolor):
img = background
return img.convert('RGB')
This diff is collapsed.
......@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
AlwaysVisible = object()
class MaskBlendArgs:
def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
self.current_latent = current_latent
self.nmask = nmask
self.init_latent = init_latent
self.mask = mask
self.blended_latent = blended_latent
self.denoiser = denoiser
self.is_final_blend = denoiser is None
self.sigma = sigma
class PostSampleArgs:
def __init__(self, samples):
self.samples = samples
class PostprocessImageArgs:
def __init__(self, image):
self.image = image
class PostProcessMaskOverlayArgs:
def __init__(self, index, mask_for_overlay, overlay_image):
self.index = index
self.mask_for_overlay = mask_for_overlay
self.overlay_image = overlay_image
class PostprocessBatchListArgs:
def __init__(self, images):
......@@ -206,6 +226,25 @@ class Script:
pass
def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
"""
Called in inpainting mode when the original content is blended with the inpainted content.
This is called at every step in the denoising process and once at the end.
If is_final_blend is true, this is called for the final blending stage.
Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
"""
pass
def post_sample(self, p, ps: PostSampleArgs, *args):
"""
Called after the samples have been generated,
but before they have been decoded by the VAE, if applicable.
Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
"""
pass
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
......@@ -213,6 +252,13 @@ class Script:
pass
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
"""
Called for every image after it has been generated.
"""
pass
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
......@@ -767,6 +813,22 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
def post_sample(self, p, ps: PostSampleArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.post_sample(p, ps, *script_args)
except Exception:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def on_mask_blend(self, p, mba: MaskBlendArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.on_mask_blend(p, mba, *script_args)
except Exception:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
......@@ -775,6 +837,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_maskoverlay(p, ppmo, *script_args)
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try:
......
......@@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
would be on the meta device.
"""
if state_dict == sd:
if state_dict is sd:
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
original(module, state_dict, strict=strict)
......
......@@ -56,6 +56,9 @@ class CFGDenoiser(torch.nn.Module):
self.sampler = sampler
self.model_wrap = None
self.p = None
# NOTE: masking before denoising can cause the original latents to be oversmoothed
# as the original latents do not have noise
self.mask_before_denoising = False
@property
......@@ -105,8 +108,21 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
# If we use masks, blending between the denoised and original latent images occurs here.
def apply_blend(current_latent):
blended_latent = current_latent * self.nmask + self.init_latent * self.mask
if self.p.scripts is not None:
from modules import scripts
mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
self.p.scripts.on_mask_blend(self.p, mba)
blended_latent = mba.blended_latent
return blended_latent
# Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
x = self.init_latent * self.mask + self.nmask * x
x = apply_blend(x)
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
......@@ -207,8 +223,9 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
# Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
denoised = apply_blend(denoised)
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
......
......@@ -258,6 +258,7 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing",
"keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
}))
......@@ -332,6 +333,7 @@ options_templates.update(options_section(('ui', "Live previews", "ui"), {
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
"js_live_preview_in_modal_lightbox": OptionInfo(True, "Show Live preview in full page image viewer"),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
......
......@@ -98,10 +98,8 @@ class StyleDatabase:
self.path = path
folder, file = os.path.split(self.path)
self.default_file = file.split("*")[0] + ".csv"
if self.default_file == ".csv":
self.default_file = "styles.csv"
self.default_path = os.path.join(folder, self.default_file)
filename, _, ext = file.partition('*')
self.default_path = os.path.join(folder, filename + ext)
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
......@@ -155,10 +153,8 @@ class StyleDatabase:
row["name"], prompt, negative_prompt, path
)
def get_style_paths(self) -> list():
"""
Returns a list of all distinct paths, including the default path, of
files that styles are loaded from."""
def get_style_paths(self) -> set:
"""Returns a set of all distinct paths of files that styles are loaded from."""
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
......@@ -172,9 +168,9 @@ class StyleDatabase:
style_paths.add(style.path)
# Remove any paths for styles that are just list dividers
style_paths.remove("do_not_save")
style_paths.discard("do_not_save")
return list(style_paths)
return style_paths
def get_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).prompt for x in styles]
......@@ -196,20 +192,7 @@ class StyleDatabase:
# The path argument is deprecated, but kept for backwards compatibility
_ = path
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
self.styles[style.name] = style._replace(path=self.default_path)
# Create a list of all distinct paths, including the default path
style_paths = set()
style_paths.add(self.default_path)
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)
# Remove any paths for styles that are just list dividers
style_paths.remove("do_not_save")
style_paths = self.get_style_paths()
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
......
......@@ -79,11 +79,11 @@ class Toprow:
def create_prompts(self):
with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6):
with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]):
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False)
with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]):
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
self.prompt_img.change(
fn=modules.images.image_data,
......
......@@ -48,3 +48,12 @@ if has_xpu:
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.bmm',
lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
CondFunc('torch.cat',
lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
CondFunc('torch.nn.functional.scaled_dot_product_attention',
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)
......@@ -121,16 +121,22 @@ document.addEventListener("DOMContentLoaded", function() {
});
/**
* Add a ctrl+enter as a shortcut to start a generation
* Add keyboard shortcuts:
* Ctrl+Enter to start/restart a generation
* Alt/Option+Enter to skip a generation
* Esc to interrupt a generation
*/
document.addEventListener('keydown', function(e) {
const isEnter = e.key === 'Enter' || e.keyCode === 13;
const isModifierKey = e.metaKey || e.ctrlKey || e.altKey;
const isCtrlKey = e.metaKey || e.ctrlKey;
const isAltKey = e.altKey;
const isEsc = e.key === 'Escape';
const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]');
const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]');
const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]');
if (isEnter && isModifierKey) {
if (isCtrlKey && isEnter) {
if (interruptButton.style.display === 'block') {
interruptButton.click();
const callback = (mutationList) => {
......@@ -150,6 +156,21 @@ document.addEventListener('keydown', function(e) {
}
e.preventDefault();
}
if (isAltKey && isEnter) {
skipButton.click();
e.preventDefault();
}
if (isEsc) {
const globalPopup = document.querySelector('.global-popup');
const lightboxModal = document.querySelector('#lightboxModal');
if (!globalPopup || globalPopup.style.display === 'none') {
if (document.activeElement === lightboxModal) return;
interruptButton.click();
e.preventDefault();
}
}
});
/**
......
This diff is collapsed.
......@@ -133,7 +133,7 @@ case "$gpu_info" in
if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]]
then
# Navi users will still use torch 1.13 because 2.0 does not seem to work.
export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2"
export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6"
else
printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m"
exit 1
......@@ -143,8 +143,7 @@ case "$gpu_info" in
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;;
*"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/test/rocm5.6"
# Navi 3 needs at least 5.5 which is only on the torch 2.1.0 release candidates right now
export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7"
;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment