Commit 86361329 authored by Robert Barron's avatar Robert Barron

Merge branch 'shared-hires-prompt-raw' into shared-hires-prompt-test

parents 9af5cce4 d1ba46b6
...@@ -326,12 +326,14 @@ class StableDiffusionProcessing: ...@@ -326,12 +326,14 @@ class StableDiffusionProcessing:
self.main_prompt = self.all_prompts[0] self.main_prompt = self.all_prompts[0]
self.main_negative_prompt = self.all_negative_prompts[0] self.main_negative_prompt = self.all_negative_prompts[0]
def cached_params(self, required_prompts, steps, extra_network_data): def cached_params(self, required_prompts, steps, hires_steps, extra_network_data, use_old_scheduling):
"""Returns parameters that invalidate the cond cache if changed""" """Returns parameters that invalidate the cond cache if changed"""
return ( return (
required_prompts, required_prompts,
steps, steps,
hires_steps,
use_old_scheduling,
opts.CLIP_stop_at_last_layers, opts.CLIP_stop_at_last_layers,
shared.sd_model.sd_checkpoint_info, shared.sd_model.sd_checkpoint_info,
extra_network_data, extra_network_data,
...@@ -341,7 +343,7 @@ class StableDiffusionProcessing: ...@@ -341,7 +343,7 @@ class StableDiffusionProcessing:
self.height, self.height,
) )
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data): def get_conds_with_caching(self, function, required_prompts, steps, hires_steps, caches, extra_network_data):
""" """
Returns the result of calling function(shared.sd_model, required_prompts, steps) Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before. using a cache to store the result if the same arguments have been used before.
...@@ -354,7 +356,7 @@ class StableDiffusionProcessing: ...@@ -354,7 +356,7 @@ class StableDiffusionProcessing:
caches is a list with items described above. caches is a list with items described above.
""" """
cached_params = self.cached_params(required_prompts, steps, extra_network_data) cached_params = self.cached_params(required_prompts, steps, hires_steps, extra_network_data, shared.opts.use_old_scheduling)
for cache in caches: for cache in caches:
if cache[0] is not None and cached_params == cache[0]: if cache[0] is not None and cached_params == cache[0]:
...@@ -363,7 +365,7 @@ class StableDiffusionProcessing: ...@@ -363,7 +365,7 @@ class StableDiffusionProcessing:
cache = caches[0] cache = caches[0]
with devices.autocast(): with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps) cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
cache[0] = cached_params cache[0] = cached_params
return cache[1] return cache[1]
...@@ -374,8 +376,9 @@ class StableDiffusionProcessing: ...@@ -374,8 +376,9 @@ class StableDiffusionProcessing:
sampler_config = sd_samplers.find_sampler_config(self.sampler_name) sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) self.firstpass_steps = self.steps * self.step_multiplier
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.firstpass_steps, None, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.firstpass_steps, None, [self.cached_c], self.extra_network_data)
def get_conds(self): def get_conds(self):
return self.c, self.uc return self.c, self.uc
...@@ -1188,8 +1191,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1188,8 +1191,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y) hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True) hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) hires_steps = (self.hr_second_pass_steps or self.steps) * self.step_multiplier
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
def setup_conds(self): def setup_conds(self):
super().setup_conds() super().setup_conds()
......
...@@ -26,7 +26,7 @@ plain: /([^\\\[\]():|]|\\.)+/ ...@@ -26,7 +26,7 @@ plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER %import common.SIGNED_NUMBER -> NUMBER
""") """)
def get_learned_conditioning_prompt_schedules(prompts, steps): def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
""" """
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0] >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
>>> g("test") >>> g("test")
...@@ -57,18 +57,39 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): ...@@ -57,18 +57,39 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']] [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
>>> g("[fe|||]male") >>> g("[fe|||]male")
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']] [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
>>> g("a [b:.5] c")
[[10, 'a b c']]
>>> g("a [b:1.5] c")
[[5, 'a c'], [10, 'a b c']]
""" """
if hires_steps is None or use_old_scheduling:
int_offset = 0
flt_offset = 0
steps = base_steps
else:
int_offset = base_steps
flt_offset = 1.0
steps = hires_steps
def collect_steps(steps, tree): def collect_steps(steps, tree):
res = [steps] res = [steps]
class CollectSteps(lark.Visitor): class CollectSteps(lark.Visitor):
def scheduled(self, tree): def scheduled(self, tree):
tree.children[-2] = float(tree.children[-2]) s = tree.children[-2]
if tree.children[-2] < 1: v = float(s)
tree.children[-2] *= steps if use_old_scheduling:
tree.children[-2] = min(steps, int(tree.children[-2])) v = v*steps if v<1 else v
res.append(tree.children[-2]) else:
if "." in s:
v = (v - flt_offset) * steps
else:
v = (v - int_offset)
tree.children[-2] = min(steps, int(v))
if tree.children[-2] >= 1:
res.append(tree.children[-2])
def alternate(self, tree): def alternate(self, tree):
res.extend(range(1, steps+1)) res.extend(range(1, steps+1))
...@@ -134,7 +155,7 @@ class SdConditioning(list): ...@@ -134,7 +155,7 @@ class SdConditioning(list):
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one. and the sampling step at which this condition is to be replaced by the next one.
...@@ -154,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): ...@@ -154,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
""" """
res = [] res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
cache = {} cache = {}
for prompt, prompt_schedule in zip(prompts, prompt_schedules): for prompt, prompt_schedule in zip(prompts, prompt_schedules):
...@@ -229,7 +250,7 @@ class MulticondLearnedConditioning: ...@@ -229,7 +250,7 @@ class MulticondLearnedConditioning:
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator. For each prompt, the list is obtained by splitting the prompt using the AND separator.
...@@ -238,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne ...@@ -238,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts) res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps) learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
res = [] res = []
for indexes in res_indexes: for indexes in res_indexes:
......
This diff is collapsed.
...@@ -197,6 +197,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { ...@@ -197,6 +197,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
"use_old_scheduling": OptionInfo(False, "Use old prompt where first pass and hires both used the same timeline, and < 1 meant relative and >= 1 meant absolute"),
})) }))
options_templates.update(options_section(('interrogate', "Interrogate"), { options_templates.update(options_section(('interrogate', "Interrogate"), {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment