Commit 4f5f7865 authored by Connum's avatar Connum

make UI restraints (currently sampling method only) more flexible and reusable across scripts

parent c1493632
...@@ -185,38 +185,46 @@ onUiUpdate(function(){ ...@@ -185,38 +185,46 @@ onUiUpdate(function(){
}) })
/** /**
* force Euler method for the "img2img alternative test" script * Implement script-dependent UI restraints, e.g. forcing a specific sampling method
*/ */
let prev_sampling_method; let prev_ui_states = {};
onUiTabChange(function() { function updateScriptRestraints() {
const currentTab = get_uiCurrentTab(); const currentTab = get_uiCurrentTab()?.textContent.trim();
if ( ! currentTab || currentTab?.textContent.trim() !== 'img2img' ) { const restraintsField = Array.from(gradioApp().querySelectorAll(`#${currentTab}_script_restraints_json textarea`))
.filter(el => uiElementIsVisible(el.closest('.gr-form')))?.[0];
if ( ! restraintsField ) {
return; return;
} }
const altScriptName = 'img2img alternative test'; if ( typeof prev_ui_states[currentTab] === 'undefined' ) {
const scriptSelect = gradioApp().querySelector('#component-223 select'); prev_ui_states[currentTab] = {};
const methodRadios = gradioApp().querySelectorAll('[name="radio-component-182"]'); }
scriptSelect.addEventListener( 'change', function() {
if( scriptSelect.value === altScriptName) { window.requestAnimationFrame(() => {
prev_sampling_method = gradioApp().querySelector('[name="radio-component-182"]:checked'); const restraints = JSON.parse(restraintsField.value);
// const scriptSelect = gradioApp().querySelector(`#${currentTab}_scripts select`);
const methodRadios = gradioApp().querySelectorAll(`[name="radio-${currentTab}_sampling"]`);
if( restraints?.methods?.length ) {
prev_ui_states[currentTab].sampling_method = gradioApp().querySelector(`[name="radio-${currentTab}_sampling"]:checked`);
methodRadios.forEach(radio => { methodRadios.forEach(radio => {
const isEuler = radio.value === 'Euler'; const isAllowed = restraints.methods.includes(radio.value);
const label = radio.closest('label'); const label = radio.closest('label');
radio.disabled = !isEuler; radio.disabled = !isAllowed;
radio.checked = isEuler; radio.checked = isAllowed;
label.classList[isEuler ? 'remove' : 'add']('!cursor-not-allowed'); label.classList[isAllowed ? 'remove' : 'add']('!cursor-not-allowed','disabled');
label.title = !isEuler ? `${altScriptName} only works with the Euler method` : ''; label.title = !isAllowed ? `The selected script does not work with this method` : '';
}); });
} else { } else {
// reset to previous method // reset to previously selected method
methodRadios.forEach(radio => { methodRadios.forEach(radio => {
const label = radio.closest('label'); const label = radio.closest('label');
radio.disabled = false; radio.disabled = false;
radio.checked = radio === prev_sampling_method; radio.checked = radio === prev_ui_states[currentTab].sampling_method;
label.classList.remove('!cursor-not-allowed'); label.classList.remove('!cursor-not-allowed','disabled');
label.title = ''; label.title = '';
}); });
} }
}); })
}) }
\ No newline at end of file \ No newline at end of file
import os import os
import sys import sys
import traceback import traceback
import json
import modules.ui as ui import modules.ui as ui
import gradio as gr import gradio as gr
...@@ -24,6 +25,14 @@ class Script: ...@@ -24,6 +25,14 @@ class Script:
def ui(self, is_img2img): def ui(self, is_img2img):
pass pass
# Put restraints on UI elements when this script is selected.
# Restricting the available sampling methods:
# {
# "methods": [ "Euler", "DDIM" ]
# }
def ui_restraints(self):
return {}
# Determines when the script should be shown in the dropdown menu via the # Determines when the script should be shown in the dropdown menu via the
# returned value. As an example: # returned value. As an example:
# is_img2img is True if the current tab is img2img, and False if it is txt2img. # is_img2img is True if the current tab is img2img, and False if it is txt2img.
...@@ -106,7 +115,9 @@ class ScriptRunner: ...@@ -106,7 +115,9 @@ class ScriptRunner:
titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index") id_prefix = "img2img_" if is_img2img else "txt2img_"
dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index", elem_id=id_prefix+"scripts")
inputs = [dropdown] inputs = [dropdown]
for script in self.scripts: for script in self.scripts:
...@@ -125,16 +136,23 @@ class ScriptRunner: ...@@ -125,16 +136,23 @@ class ScriptRunner:
inputs += controls inputs += controls
script.args_to = len(inputs) script.args_to = len(inputs)
script_restraints_json = gr.Textbox(value="{}", elem_id=id_prefix+"script_restraints_json", show_label=False, visible=False)
inputs += [script_restraints_json];
def select_script(script_index): def select_script(script_index):
if 0 < script_index <= len(self.scripts): if 0 < script_index <= len(self.scripts):
script = self.scripts[script_index-1] script = self.scripts[script_index-1]
args_from = script.args_from args_from = script.args_from
args_to = script.args_to args_to = script.args_to
else: else:
script = None
args_from = 0 args_from = 0
args_to = 0 args_to = 0
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] return (
[ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs)-1)]
+ [gr.Textbox.update(value=json.dumps(script.ui_restraints() if script is not None else {}), visible=False)]
)
dropdown.change( dropdown.change(
fn=select_script, fn=select_script,
...@@ -142,6 +160,13 @@ class ScriptRunner: ...@@ -142,6 +160,13 @@ class ScriptRunner:
outputs=inputs outputs=inputs
) )
script_restraints_json.change(
_js="updateScriptRestraints",
fn=lambda: None,
inputs=[],
outputs=[]
)
return inputs return inputs
def run(self, p: StableDiffusionProcessing, *args): def run(self, p: StableDiffusionProcessing, *args):
......
...@@ -409,7 +409,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -409,7 +409,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20, elem_id="txt2img_steps")
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
with gr.Group(): with gr.Group():
...@@ -588,8 +588,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -588,8 +588,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Row(): with gr.Row():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize") resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20, elem_id="img2img_steps")
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") sampler_index = gr.Radio(label='Sampling method', elem_id="img2img_sampling", choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group(): with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
......
...@@ -129,6 +129,12 @@ class Script(scripts.Script): ...@@ -129,6 +129,12 @@ class Script(scripts.Script):
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False) sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False)
return [original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment] return [original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment]
def ui_restraints(self):
restraints = {
"methods": ["Euler"]
}
return restraints
def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment): def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment):
p.batch_size = 1 p.batch_size = 1
p.batch_count = 1 p.batch_count = 1
......
...@@ -222,6 +222,10 @@ input[type="range"]{ ...@@ -222,6 +222,10 @@ input[type="range"]{
margin: 0.5em 0 -0.3em 0; margin: 0.5em 0 -0.3em 0;
} }
.gr-input-label.disabled {
opacity: 0.48;
}
#txt2img_sampling label{ #txt2img_sampling label{
padding-left: 0.6em; padding-left: 0.6em;
padding-right: 0.6em; padding-right: 0.6em;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment