Commit df02498d authored by AUTOMATIC's avatar AUTOMATIC

add an option to show selected setting in main txt2img/img2img UI

split some code from ui.py into ui_settings.py ui_gradio_edxtensions.py
add before_process callback for scripts
add ability for alwayson scripts to specify section and let user reorder those sections
parent 583fb9f0
import gradio as gr
from modules import scripts, shared, ui_components, ui_settings
from modules.ui_components import FormColumn
class ExtraOptionsSection(scripts.Script):
section = "extra_options"
def __init__(self):
self.comps = None
self.setting_names = None
def title(self):
return "Extra options"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
self.comps = []
self.setting_names = []
with gr.Blocks() as interface:
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and len(shared.opts.extra_options) > 0 else gr.Group(), gr.Row():
for setting_name in shared.opts.extra_options:
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)
self.comps.append(comp)
self.setting_names.append(setting_name)
def get_settings_values():
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
return self.comps
def before_process(self, p, *args):
for name, value in zip(self.setting_names, args):
if name not in p.override_settings:
p.override_settings[name] = value
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
}))
...@@ -588,11 +588,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter ...@@ -588,11 +588,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
def process_images(p: StableDiffusionProcessing) -> Processed: def process_images(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.before_process(p)
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try: try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None: override_checkpoint = p.override_settings.get('sd_model_checkpoint')
if override_checkpoint is not None and sd_models.checkpoint_alisases.get(override_checkpoint) is None:
p.override_settings.pop('sd_model_checkpoint', None) p.override_settings.pop('sd_model_checkpoint', None)
sd_models.reload_model_weights() sd_models.reload_model_weights()
......
...@@ -19,6 +19,9 @@ class Script: ...@@ -19,6 +19,9 @@ class Script:
name = None name = None
"""script's internal name derived from title""" """script's internal name derived from title"""
section = None
"""name of UI section that the script's controls will be placed into"""
filename = None filename = None
args_from = None args_from = None
args_to = None args_to = None
...@@ -81,6 +84,15 @@ class Script: ...@@ -81,6 +84,15 @@ class Script:
pass pass
def before_process(self, p, *args):
"""
This function is called very early before processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
args contains all values returned by components from ui()
"""
pass
def process(self, p, *args): def process(self, p, *args):
""" """
This function is called before processing begins for AlwaysVisible scripts. This function is called before processing begins for AlwaysVisible scripts.
...@@ -293,6 +305,7 @@ class ScriptRunner: ...@@ -293,6 +305,7 @@ class ScriptRunner:
self.titles = [] self.titles = []
self.infotext_fields = [] self.infotext_fields = []
self.paste_field_names = [] self.paste_field_names = []
self.inputs = [None]
def initialize_scripts(self, is_img2img): def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing from modules import scripts_auto_postprocessing
...@@ -320,17 +333,11 @@ class ScriptRunner: ...@@ -320,17 +333,11 @@ class ScriptRunner:
self.scripts.append(script) self.scripts.append(script)
self.selectable_scripts.append(script) self.selectable_scripts.append(script)
def setup_ui(self): def create_script_ui(self, script):
import modules.api.models as api_models import modules.api.models as api_models
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] script.args_from = len(self.inputs)
script.args_to = len(self.inputs)
inputs = [None]
inputs_alwayson = [True]
def create_script_ui(script, inputs, inputs_alwayson):
script.args_from = len(inputs)
script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
...@@ -365,24 +372,34 @@ class ScriptRunner: ...@@ -365,24 +372,34 @@ class ScriptRunner:
if script.paste_field_names is not None: if script.paste_field_names is not None:
self.paste_field_names += script.paste_field_names self.paste_field_names += script.paste_field_names
inputs += controls self.inputs += controls
inputs_alwayson += [script.alwayson for _ in controls] script.args_to = len(self.inputs)
script.args_to = len(inputs)
for script in self.alwayson_scripts: def setup_ui_for_section(self, section, scriptlist=None):
with gr.Group() as group: if scriptlist is None:
create_script_ui(script, inputs, inputs_alwayson) scriptlist = self.alwayson_scripts
for script in scriptlist:
if script.alwayson and script.section != section:
continue
with gr.Group(visible=script.alwayson) as group:
self.create_script_ui(script)
script.group = group script.group = group
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") def prepare_ui(self):
inputs[0] = dropdown self.inputs = [None]
for script in self.selectable_scripts: def setup_ui(self):
with gr.Group(visible=False) as group: self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
create_script_ui(script, inputs, inputs_alwayson)
script.group = group self.setup_ui_for_section(None)
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
self.inputs[0] = dropdown
self.setup_ui_for_section(None, self.selectable_scripts)
def select_script(script_index): def select_script(script_index):
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
...@@ -407,6 +424,7 @@ class ScriptRunner: ...@@ -407,6 +424,7 @@ class ScriptRunner:
) )
self.script_load_ctr = 0 self.script_load_ctr = 0
def onload_script_visibility(params): def onload_script_visibility(params):
title = params.get('Script', None) title = params.get('Script', None)
if title: if title:
...@@ -417,10 +435,10 @@ class ScriptRunner: ...@@ -417,10 +435,10 @@ class ScriptRunner:
else: else:
return gr.update(visible=False) return gr.update(visible=False)
self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) ) self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] ) self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
return inputs return self.inputs
def run(self, p, *args): def run(self, p, *args):
script_index = args[0] script_index = args[0]
...@@ -440,6 +458,14 @@ class ScriptRunner: ...@@ -440,6 +458,14 @@ class ScriptRunner:
return processed return processed
def before_process(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process(p, *script_args)
except Exception:
errors.report(f"Error running before_process: {script.filename}", exc_info=True)
def process(self, p): def process(self, p):
for script in self.alwayson_scripts: for script in self.alwayson_scripts:
try: try:
......
...@@ -55,5 +55,15 @@ ui_reorder_categories_builtin_items = [ ...@@ -55,5 +55,15 @@ ui_reorder_categories_builtin_items = [
def ui_reorder_categories(): def ui_reorder_categories():
from modules import scripts
yield from ui_reorder_categories_builtin_items yield from ui_reorder_categories_builtin_items
sections = {}
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
if isinstance(script.section, str):
sections[script.section] = 1
yield from sections
yield "scripts" yield "scripts"
This diff is collapsed.
...@@ -10,8 +10,11 @@ import subprocess as sp ...@@ -10,8 +10,11 @@ import subprocess as sp
from modules import call_queue, shared from modules import call_queue, shared
from modules.generation_parameters_copypaste import image_from_url_text from modules.generation_parameters_copypaste import image_from_url_text
import modules.images import modules.images
from modules.ui_components import ToolButton
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
def update_generation_info(generation_info, html_info, img_index): def update_generation_info(generation_info, html_info, img_index):
...@@ -216,3 +219,23 @@ Requested path was: {f} ...@@ -216,3 +219,23 @@ Requested path was: {f}
)) ))
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args
for k, v in args.items():
setattr(refresh_component, k, v)
return gr.update(**(args or {}))
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn=refresh,
inputs=[],
outputs=[refresh_component]
)
return refresh_button
import os
import gradio as gr
from modules import localization, shared, scripts
from modules.paths import script_path, data_path
def webpath(fn):
if fn.startswith(script_path):
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
else:
web_path = os.path.abspath(fn)
return f'file={web_path}?{os.path.getmtime(fn)}'
def javascript_html():
# Ensure localization is in `window` before scripts
head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
script_js = os.path.join(script_path, "script.js")
head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
for script in scripts.list_scripts("javascript", ".js"):
head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
for script in scripts.list_scripts("javascript", ".mjs"):
head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
if shared.cmd_opts.theme:
head += f'<script type="text/javascript">set_theme(\"{shared.cmd_opts.theme}\");</script>\n'
return head
def css_html():
head = ""
def stylesheet(fn):
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
for cssfile in scripts.list_files_with_name("style.css"):
if not os.path.isfile(cssfile):
continue
head += stylesheet(cssfile)
if os.path.exists(os.path.join(data_path, "user.css")):
head += stylesheet(os.path.join(data_path, "user.css"))
return head
def reload_javascript():
js = javascript_html()
css = css_html()
def template_response(*args, **kwargs):
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
res.init_headers()
return res
gr.routes.templates.TemplateResponse = template_response
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment