Commit 530103b5 authored by AUTOMATIC

fixes related to merge

parent 5de80618
import glob
import os
import sys
import traceback

import torch

from ldm.util import default
from modules import devices, shared
import torch
from torch import einsum
from einops import rearrange, repeat


class HypernetworkModule(torch.nn.Module):
    def __init__(self, dim, state_dict):
        super().__init__()

        self.linear1 = torch.nn.Linear(dim, dim * 2)
        self.linear2 = torch.nn.Linear(dim * 2, dim)

        self.load_state_dict(state_dict, strict=True)
        self.to(devices.device)

    def forward(self, x):
        return x + (self.linear2(self.linear1(x)))


class Hypernetwork:
    filename = None
    name = None

    def __init__(self, filename):
        self.filename = filename
        self.name = os.path.splitext(os.path.basename(filename))[0]
        self.layers = {}

        state_dict = torch.load(filename, map_location='cpu')

        for size, sd in state_dict.items():
            self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))


def list_hypernetworks(path):
    res = {}
    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
        name = os.path.splitext(os.path.basename(filename))[0]
        res[name] = filename
    return res


def load_hypernetwork(filename):
    path = shared.hypernetworks.get(filename, None)
    if path is not None:
        print(f"Loading hypernetwork {filename}")
        try:
            shared.loaded_hypernetwork = Hypernetwork(path)
        except Exception:
            print(f"Error loading hypernetwork {path}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
    else:
        if shared.loaded_hypernetwork is not None:
            print(f"Unloading hypernetwork")

        shared.loaded_hypernetwork = None


def apply_hypernetwork(hypernetwork, context):
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is None:
        return context, context

    context_k = hypernetwork_layers[0](context)
    context_v = hypernetwork_layers[1](context)
    return context_k, context_v


def attention_CrossAttention_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    attn = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', attn, v)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)
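The file shown in full above already contains the whole mechanism: a Hypernetwork is a dict of residual two-layer MLPs keyed by context width, and apply_hypernetwork swaps the inputs to the cross-attention key/value projections. A minimal smoke test of that flow, assuming the definitions above are in scope; the toy .pt file, the 768 width, and the tensor shapes are illustrative assumptions, not part of the commit:

import torch

# Hypothetical check: save a state dict in the on-disk format that
# Hypernetwork.__init__ expects ({dim: (k_state_dict, v_state_dict)}),
# load it back, and apply it to a dummy cross-attention context.
dim = 77 * 0 + 768  # CLIP text-encoder width for SD v1; an assumption here

def make_module_sd(dim):
    # tensors shaped like HypernetworkModule's layers: dim -> dim*2 -> dim
    return {
        'linear1.weight': torch.randn(dim * 2, dim) * 0.01,
        'linear1.bias': torch.zeros(dim * 2),
        'linear2.weight': torch.randn(dim, dim * 2) * 0.01,
        'linear2.bias': torch.zeros(dim),
    }

torch.save({dim: (make_module_sd(dim), make_module_sd(dim))}, 'toy.pt')

hn = Hypernetwork('toy.pt')  # name becomes 'toy'
context = torch.randn(2, 77, dim, device=devices.device)
context_k, context_v = apply_hypernetwork(hn, context)
assert context_k.shape == context_v.shape == context.shape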
@@ -26,10 +26,11 @@ class HypernetworkModule(torch.nn.Module):
         if state_dict is not None:
             self.load_state_dict(state_dict, strict=True)
         else:
-            self.linear1.weight.data.fill_(0.0001)
-            self.linear1.bias.data.fill_(0.0001)
-            self.linear2.weight.data.fill_(0.0001)
-            self.linear2.bias.data.fill_(0.0001)
+            self.linear1.weight.data.normal_(mean=0.0, std=0.01)
+            self.linear1.bias.data.zero_()
+            self.linear2.weight.data.normal_(mean=0.0, std=0.01)
+            self.linear2.bias.data.zero_()
 
         self.to(devices.device)
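This hunk replaces the constant fill_(0.0001) initialization with small Gaussian weights and zero biases. Because HypernetworkModule is residual, both schemes start near the identity, but constant weights make every neuron compute the same function and receive the same gradient, so training can never differentiate them. A standalone sketch of that symmetry argument:

import torch

lin = torch.nn.Linear(768, 768 * 2)

# old scheme: every weight identical, so all 1536 outputs are the same
# function of the input and receive identical gradients
lin.weight.data.fill_(0.0001)
lin.bias.data.fill_(0.0001)
x = torch.randn(4, 768)
y = lin(x)
print(torch.allclose(y[:, 0], y[:, 1]))  # True: neurons indistinguishable

# new scheme: small random weights, zero bias -> outputs differ per neuron,
# while the residual module still starts close to the identity mapping
lin.weight.data.normal_(mean=0.0, std=0.01)
lin.bias.data.zero_()
y = lin(x)
print(torch.allclose(y[:, 0], y[:, 1]))  # False: symmetry is broken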
@@ -92,41 +93,54 @@ class Hypernetwork:
         self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
 
 
-def load_hypernetworks(path):
+def list_hypernetworks(path):
     res = {}
+    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+        name = os.path.splitext(os.path.basename(filename))[0]
+        res[name] = filename
+    return res
 
-    for filename in glob.iglob(path + '**/*.pt', recursive=True):
-        try:
-            hn = Hypernetwork()
-            hn.load(filename)
-            res[hn.name] = hn
-        except Exception:
-            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
 
-    return res
+def load_hypernetwork(filename):
+    path = shared.hypernetworks.get(filename, None)
+    if path is not None:
+        print(f"Loading hypernetwork {filename}")
+        try:
+            shared.loaded_hypernetwork = Hypernetwork()
+            shared.loaded_hypernetwork.load(path)
+        except Exception:
+            print(f"Error loading hypernetwork {path}", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+    else:
+        if shared.loaded_hypernetwork is not None:
+            print(f"Unloading hypernetwork")
+
+        shared.loaded_hypernetwork = None
+
+
+def apply_hypernetwork(hypernetwork, context, layer=None):
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is None:
+        return context, context
+
+    if layer is not None:
+        layer.hyper_k = hypernetwork_layers[0]
+        layer.hyper_v = hypernetwork_layers[1]
+
+    context_k = hypernetwork_layers[0](context)
+    context_v = hypernetwork_layers[1](context)
+    return context_k, context_v
 
 
 def attention_CrossAttention_forward(self, x, context=None, mask=None):
     h = self.heads
 
     q = self.to_q(x)
     context = default(context, x)
 
-    hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None)
-
-    if hypernetwork_layers is not None:
-        hypernetwork_k, hypernetwork_v = hypernetwork_layers
-
-        self.hypernetwork_k = hypernetwork_k
-        self.hypernetwork_v = hypernetwork_v
-
-        context_k = hypernetwork_k(context)
-        context_v = hypernetwork_v(context)
-    else:
-        context_k = context
-        context_v = context
-
+    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
     k = self.to_k(context_k)
     v = self.to_v(context_v)
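The new apply_hypernetwork also accepts an optional layer and caches the key/value modules on it as hyper_k/hyper_v, presumably so other attention paths can reuse them. Note the lookup is keyed on context.shape[2]: contexts whose width has no registered pair pass through untouched. A small illustration of that dispatch, assuming apply_hypernetwork is in scope; DummyHN and the widths are hypothetical:

import torch

# stand-in with the same duck-typed interface as Hypernetwork: a dict of
# (key_module, value_module) pairs keyed by context channel width
class DummyHN:
    def __init__(self):
        identity_ish = torch.nn.Identity()
        self.layers = {768: (identity_ish, identity_ish)}

hn = DummyHN()
ctx_768 = torch.randn(1, 77, 768)
ctx_1024 = torch.randn(1, 77, 1024)

k, v = apply_hypernetwork(hn, ctx_768)     # hit: routed through the pair
k2, v2 = apply_hypernetwork(hn, ctx_1024)  # miss: returned unchanged
assert k2 is ctx_1024 and v2 is ctx_1024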
@@ -151,7 +165,9 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
 def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
     assert hypernetwork_name, 'embedding not selected'
 
-    shared.hypernetwork = shared.hypernetworks[hypernetwork_name]
+    path = shared.hypernetworks.get(hypernetwork_name, None)
+    shared.loaded_hypernetwork = Hypernetwork()
+    shared.loaded_hypernetwork.load(path)
 
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps
@@ -176,9 +192,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
 
-    hypernetwork = shared.hypernetworks[hypernetwork_name]
+    hypernetwork = shared.loaded_hypernetwork
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
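Only these flagged weights are later handed to the optimizer; the Stable Diffusion checkpoint itself stays frozen, which is what keeps hypernetwork training cheap. A hedged sketch of that pattern, mirroring the variables of the surrounding function rather than quoting the exact loop (the AdamW choice here is an assumption):

import torch

# hypernetwork.weights() yields only the small MLP parameters; the frozen
# UNet/text-encoder weights never appear in the optimizer's param list
weights = hypernetwork.weights()
for weight in weights:
    weight.requires_grad = True

optimizer = torch.optim.AdamW(weights, lr=learn_rate)

# schematic step inside the training loop:
#   loss = compute_loss(shared.sd_model, x, text)   # hypothetical helper
#   optimizer.zero_grad()
#   loss.backward()   # gradients flow only into the hypernetwork MLPs
#   optimizer.step()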
@@ -194,7 +210,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     if ititial_step > steps:
         return hypernetwork, filename
 
-    pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
     for i, (x, text) in pbar:
         hypernetwork.step = i + ititial_step
...
@@ -6,24 +6,24 @@ import gradio as gr
 import modules.textual_inversion.textual_inversion
 import modules.textual_inversion.preprocess
 from modules import sd_hijack, shared
+from modules.hypernetwork import hypernetwork
 
 
 def create_hypernetwork(name):
     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
     assert not os.path.exists(fn), f"file {fn} already exists"
 
-    hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
-    hypernetwork.save(fn)
+    hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
+    hypernet.save(fn)
 
     shared.reload_hypernetworks()
-    shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)
 
     return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
 
 
 def train_hypernetwork(*args):
 
-    initial_hypernetwork = shared.hypernetwork
+    initial_hypernetwork = shared.loaded_hypernetwork
 
     try:
         sd_hijack.undo_optimizations()
@@ -38,6 +38,6 @@ Hypernetwork saved to {html.escape(filename)}
     except Exception:
         raise
     finally:
-        shared.hypernetwork = initial_hypernetwork
+        shared.loaded_hypernetwork = initial_hypernetwork
         sd_hijack.apply_optimizations()
@@ -8,7 +8,8 @@ from torch import einsum
 from ldm.util import default
 from einops import rearrange
 
-from modules import shared, hypernetwork
+from modules import shared
+from modules.hypernetwork import hypernetwork
 
 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
...
@@ -13,7 +13,8 @@ import modules.memmon
 import modules.sd_models
 import modules.styles
 import modules.devices as devices
-from modules import sd_samplers, hypernetwork
+from modules import sd_samplers
+from modules.hypernetwork import hypernetwork
 from modules.paths import models_path, script_path, sd_path
 
 sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -29,6 +30,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
 parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
 parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
 parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
 parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
 parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
 parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
@@ -82,10 +84,17 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
 xformers_available = False
 config_filename = cmd_opts.ui_settings_file
 
-hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks'))
+hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
 loaded_hypernetwork = None
 
+
+def reload_hypernetworks():
+    global hypernetworks
+
+    hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+    hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
+
 class State:
     skipped = False
     interrupted = False
...
@@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
     return fn
 
 
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
     assert embedding_name, 'embedding not selected'
 
     shared.state.textinfo = "Initializing textual inversion training..."
@@ -238,12 +238,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
         if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
             last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
 
+            preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
-                prompt=text,
+                prompt=preview_text,
                 steps=20,
                 height=training_height,
                 width=training_width,
                 do_not_save_grid=True,
                 do_not_save_samples=True,
             )
@@ -254,7 +256,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
             shared.state.current_image = image
             image.save(last_saved_image)
 
-            last_saved_image += f", prompt: {text}"
+            last_saved_image += f", prompt: {preview_text}"
 
         shared.state.job_no = embedding.step
...
@@ -1023,7 +1023,7 @@ def create_ui(wrap_gradio_gpu_call):
                 gr.HTML(value="")
 
             with gr.Column():
-                create_embedding = gr.Button(value="Create", variant='primary')
+                create_embedding = gr.Button(value="Create embedding", variant='primary')
 
         with gr.Group():
             gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
@@ -1035,7 +1035,7 @@ def create_ui(wrap_gradio_gpu_call):
                 gr.HTML(value="")
 
             with gr.Column():
-                create_hypernetwork = gr.Button(value="Create", variant='primary')
+                create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
 
         with gr.Group():
             gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
@@ -1147,6 +1147,7 @@ def create_ui(wrap_gradio_gpu_call):
                 create_image_every,
                 save_embedding_every,
                 template_file,
+                preview_image_prompt,
             ],
             outputs=[
                 ti_output,
...
@@ -10,7 +10,8 @@ import numpy as np
 import modules.scripts as scripts
 import gradio as gr
 
-from modules import images, hypernetwork
+from modules import images
+from modules.hypernetwork import hypernetwork
 from modules.processing import process_images, Processed, get_correct_sampler
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
...
@@ -29,6 +29,7 @@ from modules import devices
 from modules import modelloader
 from modules.paths import script_path
 from modules.shared import cmd_opts
+import modules.hypernetwork.hypernetwork
 
 modelloader.cleanup_models()
 modules.sd_models.setup_model()
@@ -77,22 +78,12 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
     return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
 
 
-def set_hypernetwork():
-    shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)
-
-
-shared.reload_hypernetworks()
-shared.opts.onchange("sd_hypernetwork", set_hypernetwork)
-set_hypernetwork()
-
-
 modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
 
 shared.sd_model = modules.sd_models.load_model()
 shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
-
-loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)
-shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
+shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
def webui(): def webui():
...@@ -117,7 +108,7 @@ def webui(): ...@@ -117,7 +108,7 @@ def webui():
prevent_thread_lock=True prevent_thread_lock=True
) )
app.add_middleware(GZipMiddleware,minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
while 1: while 1:
time.sleep(0.5) time.sleep(0.5)
......
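The sd_hypernetwork callback is wrapped in wrap_queued_call so a settings change cannot swap shared.loaded_hypernetwork while a generation job is mid-denoise inside attention_CrossAttention_forward. A rough sketch of the queuing idea, not a quote of webui.py's exact implementation:

import threading

queue_lock = threading.Lock()  # assumed shared with the generation job queue

def wrap_queued_call(func):
    # run the wrapped callback only once no generation job holds the lock,
    # serializing hypernetwork/model swaps against in-flight image jobs
    def f(*args, **kwargs):
        with queue_lock:
            res = func(*args, **kwargs)
        return res
    return f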