Commit d56a9cfe authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf

Merge branch 'dev' into efficient-vae-methods

parents a6b245e4 a32f270a
...@@ -2,16 +2,14 @@ function setupExtraNetworksForTab(tabname) { ...@@ -2,16 +2,14 @@ function setupExtraNetworksForTab(tabname) {
gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea'); var searchDiv = gradioApp().getElementById(tabname + '_extra_search');
var search = searchDiv.querySelector('textarea');
var sort = gradioApp().getElementById(tabname + '_extra_sort'); var sort = gradioApp().getElementById(tabname + '_extra_sort');
var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
search.classList.add('search');
sort.classList.add('sort');
sortOrder.classList.add('sortorder');
sort.dataset.sortkey = 'sortDefault'; sort.dataset.sortkey = 'sortDefault';
tabs.appendChild(search); tabs.appendChild(searchDiv);
tabs.appendChild(sort); tabs.appendChild(sort);
tabs.appendChild(sortOrder); tabs.appendChild(sortOrder);
tabs.appendChild(refresh); tabs.appendChild(refresh);
...@@ -179,7 +177,7 @@ function saveCardPreview(event, tabname, filename) { ...@@ -179,7 +177,7 @@ function saveCardPreview(event, tabname, filename) {
} }
function extraNetworksSearchButton(tabs_id, event) { function extraNetworksSearchButton(tabs_id, event) {
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea'); var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea');
var button = event.target; var button = event.target;
var text = button.classList.contains("search-all") ? "" : button.textContent.trim(); var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
......
...@@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model ...@@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Hires sampler" not in res: if "Hires sampler" not in res:
res["Hires sampler"] = "Use same sampler" res["Hires sampler"] = "Use same sampler"
if "Hires checkpoint" not in res:
res["Hires checkpoint"] = "Use same checkpoint"
if "Hires prompt" not in res: if "Hires prompt" not in res:
res["Hires prompt"] = "" res["Hires prompt"] = ""
......
...@@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): ...@@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
return res return res
invalid_filename_chars = '<>:"/\\|?*\n' invalid_filename_chars = '<>:"/\\|?*\n\r\t'
invalid_filename_prefix = ' ' invalid_filename_prefix = ' '
invalid_filename_postfix = ' .' invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
......
...@@ -15,6 +15,9 @@ def send_everything_to_cpu(): ...@@ -15,6 +15,9 @@ def send_everything_to_cpu():
def setup_for_low_vram(sd_model, use_medvram): def setup_for_low_vram(sd_model, use_medvram):
if getattr(sd_model, 'lowvram', False):
return
sd_model.lowvram = True sd_model.lowvram = True
parents = {} parents = {}
......
This diff is collapsed.
...@@ -5,7 +5,7 @@ from types import MethodType ...@@ -5,7 +5,7 @@ from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
import ldm.modules.attention import ldm.modules.attention
import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.model
...@@ -29,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros ...@@ -29,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
# silence new console spam from SD2 # silence new console spam from SD2
ldm.modules.attention.print = lambda *args: None ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = lambda *args: None ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print
sd_hijack_inpainting.do_inpainting_hijack()
optimizers = [] optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None current_optimizer: sd_hijack_optimizations.SdOptimization = None
......
...@@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F ...@@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
def do_inpainting_hijack(): def do_inpainting_hijack():
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
This diff is collapsed.
...@@ -98,10 +98,10 @@ def extend_sdxl(model): ...@@ -98,10 +98,10 @@ def extend_sdxl(model):
model.conditioner.wrapped = torch.nn.Module() model.conditioner.wrapped = torch.nn.Module()
sgm.modules.attention.print = lambda *args: None sgm.modules.attention.print = shared.ldm_print
sgm.modules.diffusionmodules.model.print = lambda *args: None sgm.modules.diffusionmodules.model.print = shared.ldm_print
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
sgm.modules.encoders.modules.print = lambda *args: None sgm.modules.encoders.modules.print = shared.ldm_print
# this gets the code to load the vanilla attention that we override # this gets the code to load the vanilla attention that we override
sgm.modules.attention.SDP_IS_AVAILABLE = True sgm.modules.attention.SDP_IS_AVAILABLE = True
......
This diff is collapsed.
...@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html ...@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import gradio as gr import gradio as gr
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
...@@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step ...@@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_second_pass_steps=hr_second_pass_steps, hr_second_pass_steps=hr_second_pass_steps,
hr_resize_x=hr_resize_x, hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y, hr_resize_y=hr_resize_y,
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
hr_prompt=hr_prompt, hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt, hr_negative_prompt=hr_negative_prompt,
......
...@@ -261,7 +261,6 @@ class Toprow: ...@@ -261,7 +261,6 @@ class Toprow:
with gr.Row(): with gr.Row():
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
self.button_interrogate = None self.button_interrogate = None
self.button_deepbooru = None self.button_deepbooru = None
if is_img2img: if is_img2img:
...@@ -290,7 +289,6 @@ class Toprow: ...@@ -290,7 +289,6 @@ class Toprow:
with gr.Row(elem_id=f"{id_part}_tools"): with gr.Row(elem_id=f"{id_part}_tools"):
self.paste = ToolButton(value=paste_symbol, elem_id="paste") self.paste = ToolButton(value=paste_symbol, elem_id="paste")
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
self.extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False) self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
...@@ -404,11 +402,10 @@ def create_ui(): ...@@ -404,11 +402,10 @@ def create_ui():
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks: extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
from modules import ui_extra_networks extra_tabs.__enter__()
extra_networks_ui = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'txt2img')
with gr.Row(equal_height=False): with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"): with gr.Column(variant='compact', elem_id="txt2img_settings"):
scripts.scripts_txt2img.prepare_ui() scripts.scripts_txt2img.prepare_ui()
...@@ -456,6 +453,10 @@ def create_ui(): ...@@ -456,6 +453,10 @@ def create_ui():
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index") hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
...@@ -533,6 +534,7 @@ def create_ui(): ...@@ -533,6 +534,7 @@ def create_ui():
hr_second_pass_steps, hr_second_pass_steps,
hr_resize_x, hr_resize_x,
hr_resize_y, hr_resize_y,
hr_checkpoint_name,
hr_sampler_index, hr_sampler_index,
hr_prompt, hr_prompt,
hr_negative_prompt, hr_negative_prompt,
...@@ -599,8 +601,9 @@ def create_ui(): ...@@ -599,8 +601,9 @@ def create_ui():
(hr_second_pass_steps, "Hires steps"), (hr_second_pass_steps, "Hires steps"),
(hr_resize_x, "Hires resize-1"), (hr_resize_x, "Hires resize-1"),
(hr_resize_y, "Hires resize-2"), (hr_resize_y, "Hires resize-2"),
(hr_checkpoint_name, "Hires checkpoint"),
(hr_sampler_index, "Hires sampler"), (hr_sampler_index, "Hires sampler"),
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()), (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
(hr_prompt, "Hires prompt"), (hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"), (hr_negative_prompt, "Hires negative prompt"),
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
...@@ -625,7 +628,11 @@ def create_ui(): ...@@ -625,7 +628,11 @@ def create_ui():
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) from modules import ui_extra_networks
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
extra_tabs.__exit__()
scripts.scripts_current = scripts.scripts_img2img scripts.scripts_current = scripts.scripts_img2img
scripts.scripts_img2img.initialize_scripts(is_img2img=True) scripts.scripts_img2img.initialize_scripts(is_img2img=True)
...@@ -633,11 +640,10 @@ def create_ui(): ...@@ -633,11 +640,10 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
toprow = Toprow(is_img2img=True) toprow = Toprow(is_img2img=True)
with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks: extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
from modules import ui_extra_networks extra_tabs.__enter__()
extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'img2img')
with FormRow(equal_height=False): with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"): with gr.Column(variant='compact', elem_id="img2img_settings"):
copy_image_buttons = [] copy_image_buttons = []
copy_image_destinations = {} copy_image_destinations = {}
...@@ -663,15 +669,15 @@ def create_ui(): ...@@ -663,15 +669,15 @@ def create_ui():
add_copy_image_controls('img2img', init_img) add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height) sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
add_copy_image_controls('sketch', sketch) add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color='#ffffff') init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
add_copy_image_controls('inpaint', init_img_with_mask) add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color='#ffffff') inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
inpaint_color_sketch_orig = gr.State(None) inpaint_color_sketch_orig = gr.State(None)
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
...@@ -959,8 +965,6 @@ def create_ui(): ...@@ -959,8 +965,6 @@ def create_ui():
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
img2img_paste_fields = [ img2img_paste_fields = [
(toprow.prompt, "Prompt"), (toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"), (toprow.negative_prompt, "Negative prompt"),
...@@ -989,6 +993,12 @@ def create_ui(): ...@@ -989,6 +993,12 @@ def create_ui():
paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None, paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
)) ))
from modules import ui_extra_networks
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
extra_tabs.__exit__()
scripts.scripts_current = None scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Blocks(analytics_enabled=False) as extras_interface:
......
...@@ -155,7 +155,7 @@ class ExtraNetworksPage: ...@@ -155,7 +155,7 @@ class ExtraNetworksPage:
subdirs = {"": 1, **subdirs} subdirs = {"": 1, **subdirs}
subdirs_html = "".join([f""" subdirs_html = "".join([f"""
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'> <button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_search", event)'>
{html.escape(subdir if subdir!="" else "all")} {html.escape(subdir if subdir!="" else "all")}
</button> </button>
""" for subdir in subdirs]) """ for subdir in subdirs])
...@@ -347,7 +347,7 @@ def pages_in_preferred_order(pages): ...@@ -347,7 +347,7 @@ def pages_in_preferred_order(pages):
return sorted(pages, key=lambda x: tab_scores[x.name]) return sorted(pages, key=lambda x: tab_scores[x.name])
def create_ui(container, button, tabname): def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
ui = ExtraNetworksUi() ui = ExtraNetworksUi()
ui.pages = [] ui.pages = []
ui.pages_contents = [] ui.pages_contents = []
...@@ -355,48 +355,41 @@ def create_ui(container, button, tabname): ...@@ -355,48 +355,41 @@ def create_ui(container, button, tabname):
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname ui.tabname = tabname
with gr.Tabs(elem_id=tabname+"_extra_tabs"): related_tabs = []
for page in ui.stored_extra_pages:
with gr.Tab(page.title, id=page.id_page):
elem_id = f"{tabname}_{page.id_page}_cards_html"
page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) for page in ui.stored_extra_pages:
with gr.Tab(page.title, id=page.id_page) as tab:
elem_id = f"{tabname}_{page.id_page}_cards_html"
page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
editor = page.create_user_metadata_editor(ui, tabname) page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
editor.create_ui()
ui.user_metadata_editors.append(editor)
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) editor = page.create_user_metadata_editor(ui, tabname)
gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True) editor.create_ui()
ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder") ui.user_metadata_editors.append(editor)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
related_tabs.append(tab)
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
def toggle_visibility(is_visible): for tab in unrelated_tabs:
is_visible = not is_visible tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, edit_search, dropdown_sort, button_sortorder, button_refresh], show_progress=False)
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")) for tab in related_tabs:
tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, edit_search, dropdown_sort, button_sortorder, button_refresh], show_progress=False)
def fill_tabs(is_empty):
"""Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""
def pages_html():
if not ui.pages_contents: if not ui.pages_contents:
refresh() return refresh()
if is_empty:
return True, *ui.pages_contents
return True, *[gr.update() for _ in ui.pages_contents] return ui.pages_contents
state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)
state_empty = gr.State(value=True)
button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)
def refresh(): def refresh():
for pg in ui.stored_extra_pages: for pg in ui.stored_extra_pages:
...@@ -406,6 +399,7 @@ def create_ui(container, button, tabname): ...@@ -406,6 +399,7 @@ def create_ui(container, button, tabname):
return ui.pages_contents return ui.pages_contents
interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages])
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
return ui return ui
......
...@@ -779,9 +779,14 @@ footer { ...@@ -779,9 +779,14 @@ footer {
/* extra networks UI */ /* extra networks UI */
.extra-network-cards{ .extra-network-cards{
height: 725px; height: calc(100vh - 24rem);
overflow: scroll; overflow: clip scroll;
resize: vertical; resize: vertical;
min-height: 52rem;
}
.extra-networks > div.tab-nav{
min-height: 3.4rem;
} }
.extra-networks > div > [id *= '_extra_']{ .extra-networks > div > [id *= '_extra_']{
...@@ -797,7 +802,6 @@ footer { ...@@ -797,7 +802,6 @@ footer {
} }
.extra-networks .tab-nav .search, .extra-networks .tab-nav .search,
.extra-networks .tab-nav .sort{ .extra-networks .tab-nav .sort{
display: inline-block;
margin: 0.3em; margin: 0.3em;
align-self: center; align-self: center;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment