Commit 7e5e6733 authored by AUTOMATIC1111's avatar AUTOMATIC1111

add UI for reordering callbacks

parent 0411eced
...@@ -8,7 +8,7 @@ from typing import Optional, Any ...@@ -8,7 +8,7 @@ from typing import Optional, Any
from fastapi import FastAPI from fastapi import FastAPI
from gradio import Blocks from gradio import Blocks
from modules import errors, timer, extensions from modules import errors, timer, extensions, shared
def report_exception(c, job): def report_exception(c, job):
...@@ -124,9 +124,10 @@ class ScriptCallback: ...@@ -124,9 +124,10 @@ class ScriptCallback:
name: str = None name: str = None
def add_callback(callbacks, fun, *, name=None, category='unknown'): def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None):
stack = [x for x in inspect.stack() if x.filename != __file__] if filename is None:
filename = stack[0].filename if stack else 'unknown file' stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
extension = extensions.find_extension(filename) extension = extensions.find_extension(filename)
extension_name = extension.canonical_name if extension else 'base' extension_name = extension.canonical_name if extension else 'base'
...@@ -146,6 +147,43 @@ def add_callback(callbacks, fun, *, name=None, category='unknown'): ...@@ -146,6 +147,43 @@ def add_callback(callbacks, fun, *, name=None, category='unknown'):
callbacks.append(ScriptCallback(filename, fun, unique_callback_name)) callbacks.append(ScriptCallback(filename, fun, unique_callback_name))
def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
callbacks = unordered_callbacks.copy()
if enable_user_sort:
for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None)
if index is not None:
callbacks.insert(0, callbacks.pop(index))
return callbacks
def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True):
if unordered_callbacks is None:
unordered_callbacks = callback_map.get('callbacks_' + category, [])
if not enable_user_sort:
return sort_callbacks(category, unordered_callbacks, enable_user_sort=False)
callbacks = ordered_callbacks_map.get(category)
if callbacks is not None and len(callbacks) == len(unordered_callbacks):
return callbacks
callbacks = sort_callbacks(category, unordered_callbacks)
ordered_callbacks_map[category] = callbacks
return callbacks
def enumerate_callbacks():
for category, callbacks in callback_map.items():
if category.startswith('callbacks_'):
category = category[10:]
yield category, callbacks
callback_map = dict( callback_map = dict(
callbacks_app_started=[], callbacks_app_started=[],
callbacks_model_loaded=[], callbacks_model_loaded=[],
...@@ -170,14 +208,18 @@ callback_map = dict( ...@@ -170,14 +208,18 @@ callback_map = dict(
callbacks_before_token_counter=[], callbacks_before_token_counter=[],
) )
ordered_callbacks_map = {}
def clear_callbacks(): def clear_callbacks():
for callback_list in callback_map.values(): for callback_list in callback_map.values():
callback_list.clear() callback_list.clear()
ordered_callbacks_map.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI): def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']: for c in ordered_callbacks('app_started'):
try: try:
c.callback(demo, app) c.callback(demo, app)
timer.startup_timer.record(os.path.basename(c.script)) timer.startup_timer.record(os.path.basename(c.script))
...@@ -186,7 +228,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): ...@@ -186,7 +228,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
def app_reload_callback(): def app_reload_callback():
for c in callback_map['callbacks_on_reload']: for c in ordered_callbacks('on_reload'):
try: try:
c.callback() c.callback()
except Exception: except Exception:
...@@ -194,7 +236,7 @@ def app_reload_callback(): ...@@ -194,7 +236,7 @@ def app_reload_callback():
def model_loaded_callback(sd_model): def model_loaded_callback(sd_model):
for c in callback_map['callbacks_model_loaded']: for c in ordered_callbacks('model_loaded'):
try: try:
c.callback(sd_model) c.callback(sd_model)
except Exception: except Exception:
...@@ -204,7 +246,7 @@ def model_loaded_callback(sd_model): ...@@ -204,7 +246,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback(): def ui_tabs_callback():
res = [] res = []
for c in callback_map['callbacks_ui_tabs']: for c in ordered_callbacks('ui_tabs'):
try: try:
res += c.callback() or [] res += c.callback() or []
except Exception: except Exception:
...@@ -214,7 +256,7 @@ def ui_tabs_callback(): ...@@ -214,7 +256,7 @@ def ui_tabs_callback():
def ui_train_tabs_callback(params: UiTrainTabParams): def ui_train_tabs_callback(params: UiTrainTabParams):
for c in callback_map['callbacks_ui_train_tabs']: for c in ordered_callbacks('ui_train_tabs'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
...@@ -222,7 +264,7 @@ def ui_train_tabs_callback(params: UiTrainTabParams): ...@@ -222,7 +264,7 @@ def ui_train_tabs_callback(params: UiTrainTabParams):
def ui_settings_callback(): def ui_settings_callback():
for c in callback_map['callbacks_ui_settings']: for c in ordered_callbacks('ui_settings'):
try: try:
c.callback() c.callback()
except Exception: except Exception:
...@@ -230,7 +272,7 @@ def ui_settings_callback(): ...@@ -230,7 +272,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams): def before_image_saved_callback(params: ImageSaveParams):
for c in callback_map['callbacks_before_image_saved']: for c in ordered_callbacks('before_image_saved'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
...@@ -238,7 +280,7 @@ def before_image_saved_callback(params: ImageSaveParams): ...@@ -238,7 +280,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams): def image_saved_callback(params: ImageSaveParams):
for c in callback_map['callbacks_image_saved']: for c in ordered_callbacks('image_saved'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
...@@ -246,7 +288,7 @@ def image_saved_callback(params: ImageSaveParams): ...@@ -246,7 +288,7 @@ def image_saved_callback(params: ImageSaveParams):
def extra_noise_callback(params: ExtraNoiseParams): def extra_noise_callback(params: ExtraNoiseParams):
for c in callback_map['callbacks_extra_noise']: for c in ordered_callbacks('extra_noise'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
...@@ -254,7 +296,7 @@ def extra_noise_callback(params: ExtraNoiseParams): ...@@ -254,7 +296,7 @@ def extra_noise_callback(params: ExtraNoiseParams):
def cfg_denoiser_callback(params: CFGDenoiserParams): def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callback_map['callbacks_cfg_denoiser']: for c in ordered_callbacks('cfg_denoiser'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
...@@ -262,7 +304,7 @@ def cfg_denoiser_callback(params: CFGDenoiserParams): ...@@ -262,7 +304,7 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
def cfg_denoised_callback(params: CFGDenoisedParams): def cfg_denoised_callback(params: CFGDenoisedParams):
for c in callback_map['callbacks_cfg_denoised']: for c in ordered_callbacks('cfg_denoised'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
...@@ -270,7 +312,7 @@ def cfg_denoised_callback(params: CFGDenoisedParams): ...@@ -270,7 +312,7 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
def cfg_after_cfg_callback(params: AfterCFGCallbackParams): def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
for c in callback_map['callbacks_cfg_after_cfg']: for c in ordered_callbacks('cfg_after_cfg'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
...@@ -278,7 +320,7 @@ def cfg_after_cfg_callback(params: AfterCFGCallbackParams): ...@@ -278,7 +320,7 @@ def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
def before_component_callback(component, **kwargs): def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']: for c in ordered_callbacks('before_component'):
try: try:
c.callback(component, **kwargs) c.callback(component, **kwargs)
except Exception: except Exception:
...@@ -286,7 +328,7 @@ def before_component_callback(component, **kwargs): ...@@ -286,7 +328,7 @@ def before_component_callback(component, **kwargs):
def after_component_callback(component, **kwargs): def after_component_callback(component, **kwargs):
for c in callback_map['callbacks_after_component']: for c in ordered_callbacks('after_component'):
try: try:
c.callback(component, **kwargs) c.callback(component, **kwargs)
except Exception: except Exception:
...@@ -294,7 +336,7 @@ def after_component_callback(component, **kwargs): ...@@ -294,7 +336,7 @@ def after_component_callback(component, **kwargs):
def image_grid_callback(params: ImageGridLoopParams): def image_grid_callback(params: ImageGridLoopParams):
for c in callback_map['callbacks_image_grid']: for c in ordered_callbacks('image_grid'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
...@@ -302,7 +344,7 @@ def image_grid_callback(params: ImageGridLoopParams): ...@@ -302,7 +344,7 @@ def image_grid_callback(params: ImageGridLoopParams):
def infotext_pasted_callback(infotext: str, params: dict[str, Any]): def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
for c in callback_map['callbacks_infotext_pasted']: for c in ordered_callbacks('infotext_pasted'):
try: try:
c.callback(infotext, params) c.callback(infotext, params)
except Exception: except Exception:
...@@ -310,7 +352,7 @@ def infotext_pasted_callback(infotext: str, params: dict[str, Any]): ...@@ -310,7 +352,7 @@ def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
def script_unloaded_callback(): def script_unloaded_callback():
for c in reversed(callback_map['callbacks_script_unloaded']): for c in reversed(ordered_callbacks('script_unloaded')):
try: try:
c.callback() c.callback()
except Exception: except Exception:
...@@ -318,7 +360,7 @@ def script_unloaded_callback(): ...@@ -318,7 +360,7 @@ def script_unloaded_callback():
def before_ui_callback(): def before_ui_callback():
for c in reversed(callback_map['callbacks_before_ui']): for c in reversed(ordered_callbacks('before_ui')):
try: try:
c.callback() c.callback()
except Exception: except Exception:
...@@ -328,7 +370,7 @@ def before_ui_callback(): ...@@ -328,7 +370,7 @@ def before_ui_callback():
def list_optimizers_callback(): def list_optimizers_callback():
res = [] res = []
for c in callback_map['callbacks_list_optimizers']: for c in ordered_callbacks('list_optimizers'):
try: try:
c.callback(res) c.callback(res)
except Exception: except Exception:
...@@ -340,7 +382,7 @@ def list_optimizers_callback(): ...@@ -340,7 +382,7 @@ def list_optimizers_callback():
def list_unets_callback(): def list_unets_callback():
res = [] res = []
for c in callback_map['callbacks_list_unets']: for c in ordered_callbacks('list_unets'):
try: try:
c.callback(res) c.callback(res)
except Exception: except Exception:
...@@ -350,7 +392,7 @@ def list_unets_callback(): ...@@ -350,7 +392,7 @@ def list_unets_callback():
def before_token_counter_callback(params: BeforeTokenCounterParams): def before_token_counter_callback(params: BeforeTokenCounterParams):
for c in callback_map['callbacks_before_token_counter']: for c in ordered_callbacks('before_token_counter'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
......
This diff is collapsed.
import html
import sys import sys
from modules import script_callbacks, scripts, ui_components
from modules.options import OptionHTML, OptionInfo
from modules.shared_cmd_options import cmd_opts from modules.shared_cmd_options import cmd_opts
...@@ -118,6 +121,45 @@ def ui_reorder_categories(): ...@@ -118,6 +121,45 @@ def ui_reorder_categories():
yield "scripts" yield "scripts"
def callbacks_order_settings():
options = {
"sd_vae_explanation": OptionHTML("""
For categories below, callbacks added to dropdowns happen before others, in order listed.
"""),
}
callback_options = {}
for category, _ in script_callbacks.enumerate_callbacks():
callback_options[category] = script_callbacks.ordered_callbacks(category, enable_user_sort=False)
for method_name in scripts.scripts_txt2img.callback_names:
callback_options["script_" + method_name] = scripts.scripts_txt2img.create_ordered_callbacks_list(method_name, enable_user_sort=False)
for method_name in scripts.scripts_img2img.callback_names:
callbacks = callback_options.get("script_" + method_name, [])
for addition in scripts.scripts_img2img.create_ordered_callbacks_list(method_name, enable_user_sort=False):
if any(x.name == addition.name for x in callbacks):
continue
callbacks.append(addition)
callback_options["script_" + method_name] = callbacks
for category, callbacks in callback_options.items():
if not callbacks:
continue
option_info = OptionInfo([], f"{category} callback priority", ui_components.DropdownMulti, {"choices": [x.name for x in callbacks]})
option_info.needs_restart()
option_info.html("<div class='info'>Default order: <ol>" + "".join(f"<li>{html.escape(x.name)}</li>\n" for x in callbacks) + "</ol></div>")
options['prioritized_callbacks_' + category] = option_info
return options
class Shared(sys.modules[__name__].__class__): class Shared(sys.modules[__name__].__class__):
""" """
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
......
import gradio as gr import gradio as gr
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items
from modules.call_queue import wrap_gradio_call from modules.call_queue import wrap_gradio_call
from modules.options import options_section
from modules.shared import opts from modules.shared import opts
from modules.ui_components import FormRow from modules.ui_components import FormRow
from modules.ui_gradio_extensions import reload_javascript from modules.ui_gradio_extensions import reload_javascript
...@@ -108,6 +109,11 @@ class UiSettings: ...@@ -108,6 +109,11 @@ class UiSettings:
shared.settings_components = self.component_dict shared.settings_components = self.component_dict
# we add this as late as possible so that scripts have already registered their callbacks
opts.data_labels.update(options_section(('callbacks', "Callbacks", "system"), {
**shared_items.callbacks_order_settings(),
}))
opts.reorder() opts.reorder()
with gr.Blocks(analytics_enabled=False) as settings_interface: with gr.Blocks(analytics_enabled=False) as settings_interface:
......
...@@ -528,6 +528,10 @@ table.popup-table .link{ ...@@ -528,6 +528,10 @@ table.popup-table .link{
opacity: 0.75; opacity: 0.75;
} }
.settings-comment .info ol{
margin: 0.4em 0 0.8em 1em;
}
#sysinfo_download a.sysinfo_big_link{ #sysinfo_download a.sysinfo_big_link{
font-size: 24pt; font-size: 24pt;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment