Commit df570640 authored by AUTOMATIC's avatar AUTOMATIC

do not load aesthetic clip model until it's needed

add refresh button for aesthetic embeddings
add aesthetic params to images' infotext
parent 7d6b388d
...@@ -40,6 +40,8 @@ def iter_to_batched(iterable, n=1): ...@@ -40,6 +40,8 @@ def iter_to_batched(iterable, n=1):
def create_ui(): def create_ui():
import modules.ui
with gr.Group(): with gr.Group():
with gr.Accordion("Open for Clip Aesthetic!", open=False): with gr.Accordion("Open for Clip Aesthetic!", open=False):
with gr.Row(): with gr.Row():
...@@ -55,6 +57,8 @@ def create_ui(): ...@@ -55,6 +57,8 @@ def create_ui():
label="Aesthetic imgs embedding", label="Aesthetic imgs embedding",
value="None") value="None")
modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")
with gr.Row(): with gr.Row():
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
placeholder="This text is used to rotate the feature space of the imgs embs", placeholder="This text is used to rotate the feature space of the imgs embs",
...@@ -66,11 +70,21 @@ def create_ui(): ...@@ -66,11 +70,21 @@ def create_ui():
return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
aesthetic_clip_model = None
def aesthetic_clip():
global aesthetic_clip_model
if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
aesthetic_clip_model.cpu()
return aesthetic_clip_model
def generate_imgs_embd(name, folder, batch_size): def generate_imgs_embd(name, folder, batch_size):
# clipModel = CLIPModel.from_pretrained( model = aesthetic_clip().to(device)
# shared.sd_model.cond_stage_model.clipModel.name_or_path
# )
model = shared.clip_model.to(device)
processor = CLIPProcessor.from_pretrained(model.name_or_path) processor = CLIPProcessor.from_pretrained(model.name_or_path)
with torch.no_grad(): with torch.no_grad():
...@@ -91,7 +105,7 @@ def generate_imgs_embd(name, folder, batch_size): ...@@ -91,7 +105,7 @@ def generate_imgs_embd(name, folder, batch_size):
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
torch.save(embs, path) torch.save(embs, path)
model = model.cpu() model.cpu()
del processor del processor
del embs del embs
gc.collect() gc.collect()
...@@ -132,7 +146,7 @@ class AestheticCLIP: ...@@ -132,7 +146,7 @@ class AestheticCLIP:
self.image_embs = None self.image_embs = None
self.load_image_embs(None) self.load_image_embs(None)
def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
aesthetic_slerp=True, aesthetic_imgs_text="", aesthetic_slerp=True, aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15, aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False): aesthetic_text_negative=False):
...@@ -145,6 +159,18 @@ class AestheticCLIP: ...@@ -145,6 +159,18 @@ class AestheticCLIP:
self.aesthetic_steps = aesthetic_steps self.aesthetic_steps = aesthetic_steps
self.load_image_embs(image_embs_name) self.load_image_embs(image_embs_name)
if self.image_embs_name is not None:
p.extra_generation_params.update({
"Aesthetic LR": aesthetic_lr,
"Aesthetic weight": aesthetic_weight,
"Aesthetic steps": aesthetic_steps,
"Aesthetic embedding": self.image_embs_name,
"Aesthetic slerp": aesthetic_slerp,
"Aesthetic text": aesthetic_imgs_text,
"Aesthetic text negative": aesthetic_text_negative,
"Aesthetic slerp angle": aesthetic_slerp_angle,
})
def set_skip(self, skip): def set_skip(self, skip):
self.skip = skip self.skip = skip
...@@ -168,7 +194,7 @@ class AestheticCLIP: ...@@ -168,7 +194,7 @@ class AestheticCLIP:
tokens = torch.asarray(remade_batch_tokens).to(device) tokens = torch.asarray(remade_batch_tokens).to(device)
model = copy.deepcopy(shared.clip_model).to(device) model = copy.deepcopy(aesthetic_clip()).to(device)
model.requires_grad_(True) model.requires_grad_(True)
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
text_embs_2 = model.get_text_features( text_embs_2 = model.get_text_features(
......
...@@ -4,13 +4,22 @@ import gradio as gr ...@@ -4,13 +4,22 @@ import gradio as gr
from modules.shared import script_path from modules.shared import script_path
from modules import shared from modules import shared
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code) re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update()) type_of_gr_update = type(gr.update())
def quote(text):
if ',' not in str(text):
return text
text = str(text)
text = text.replace('\\', '\\\\')
text = text.replace('"', '\\"')
return f'"{text}"'
def parse_generation_parameters(x: str): def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI: """parses generation parameters string, the one you see in text field under the picture in UI:
``` ```
...@@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None): ...@@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
else: else:
try: try:
valtype = type(output.value) valtype = type(output.value)
if valtype == bool and v == "False":
val = False
else:
val = valtype(v) val = valtype(v)
res.append(gr.update(value=val)) res.append(gr.update(value=val))
except Exception: except Exception:
res.append(gr.update()) res.append(gr.update())
......
...@@ -109,10 +109,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -109,10 +109,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert, inpainting_mask_invert=inpainting_mask_invert,
) )
shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
aesthetic_slerp_angle,
aesthetic_text_negative)
if shared.cmd_opts.enable_console_prompts: if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
......
...@@ -12,7 +12,7 @@ from skimage import exposure ...@@ -12,7 +12,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
...@@ -318,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration ...@@ -318,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params) generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
......
...@@ -234,9 +234,6 @@ def load_model(checkpoint_info=None): ...@@ -234,9 +234,6 @@ def load_model(checkpoint_info=None):
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path:
shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path)
sd_model.eval() sd_model.eval()
print(f"Model loaded.") print(f"Model loaded.")
......
...@@ -36,9 +36,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -36,9 +36,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None, firstphase_height=firstphase_height if enable_hr else None,
) )
shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle,
aesthetic_text_negative)
if cmd_opts.enable_console_prompts: if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
......
...@@ -597,11 +597,7 @@ def apply_setting(key, value): ...@@ -597,11 +597,7 @@ def apply_setting(key, value):
return value return value
def create_ui(wrap_gradio_gpu_call): def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
import modules.img2img
import modules.txt2img
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh(): def refresh():
refresh_method() refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args args = refreshed_args() if callable(refreshed_args) else refreshed_args
...@@ -613,12 +609,18 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -613,12 +609,18 @@ def create_ui(wrap_gradio_gpu_call):
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
refresh_button.click( refresh_button.click(
fn = refresh, fn=refresh,
inputs = [], inputs=[],
outputs = [refresh_component] outputs=[refresh_component]
) )
return refresh_button return refresh_button
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
...@@ -802,6 +804,14 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -802,6 +804,14 @@ def create_ui(wrap_gradio_gpu_call):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"), (firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"), (firstphase_height, "First pass size-2"),
(aesthetic_lr, "Aesthetic LR"),
(aesthetic_weight, "Aesthetic weight"),
(aesthetic_steps, "Aesthetic steps"),
(aesthetic_imgs, "Aesthetic embedding"),
(aesthetic_slerp, "Aesthetic slerp"),
(aesthetic_imgs_text, "Aesthetic text"),
(aesthetic_text_negative, "Aesthetic text negative"),
(aesthetic_slerp_angle, "Aesthetic slerp angle"),
] ]
txt2img_preview_params = [ txt2img_preview_params = [
...@@ -1077,6 +1087,14 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1077,6 +1087,14 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"), (seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(aesthetic_lr_im, "Aesthetic LR"),
(aesthetic_weight_im, "Aesthetic weight"),
(aesthetic_steps_im, "Aesthetic steps"),
(aesthetic_imgs_im, "Aesthetic embedding"),
(aesthetic_slerp_im, "Aesthetic slerp"),
(aesthetic_imgs_text_im, "Aesthetic text"),
(aesthetic_text_negative_im, "Aesthetic text negative"),
(aesthetic_slerp_angle_im, "Aesthetic slerp angle"),
] ]
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
......
...@@ -477,7 +477,7 @@ input[type="range"]{ ...@@ -477,7 +477,7 @@ input[type="range"]{
padding: 0; padding: 0;
} }
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{ #refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_aesthetic_embeddings{
max-width: 2.5em; max-width: 2.5em;
min-width: 2.5em; min-width: 2.5em;
height: 2.4em; height: 2.4em;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment