Commit f098e726 authored by w-e-w's avatar w-e-w

fix conds caching with extra network

parent 30bbb8bc
...@@ -32,6 +32,9 @@ class ExtraNetworkParams: ...@@ -32,6 +32,9 @@ class ExtraNetworkParams:
else: else:
self.positional.append(item) self.positional.append(item)
def __eq__(self, other):
return self.items == other.items and self.positional == other.positional and self.named == other.named
class ExtraNetwork: class ExtraNetwork:
def __init__(self, name): def __init__(self, name):
......
...@@ -171,6 +171,7 @@ class StableDiffusionProcessing: ...@@ -171,6 +171,7 @@ class StableDiffusionProcessing:
self.prompts = None self.prompts = None
self.negative_prompts = None self.negative_prompts = None
self.extra_network_data = None
self.seeds = None self.seeds = None
self.subseeds = None self.subseeds = None
...@@ -311,7 +312,7 @@ class StableDiffusionProcessing: ...@@ -311,7 +312,7 @@ class StableDiffusionProcessing:
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
def get_conds_with_caching(self, function, required_prompts, steps, cache): def get_conds_with_caching(self, function, required_prompts, steps, cache, extra_network_data):
""" """
Returns the result of calling function(shared.sd_model, required_prompts, steps) Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before. using a cache to store the result if the same arguments have been used before.
...@@ -321,21 +322,21 @@ class StableDiffusionProcessing: ...@@ -321,21 +322,21 @@ class StableDiffusionProcessing:
have been used before. The second element is where the previously have been used before. The second element is where the previously
computed result is stored. computed result is stored.
""" """
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]: if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
return cache[1] return cache[1]
with devices.autocast(): with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps) cache[1] = function(shared.sd_model, required_prompts, steps)
cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
return cache[1] return cache[1]
def setup_conds(self): def setup_conds(self):
sampler_config = sd_samplers.find_sampler_config(self.sampler_name) sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc) self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c, self.extra_network_data)
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts) self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts)
...@@ -681,7 +682,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -681,7 +682,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.job_count == -1: if state.job_count == -1:
state.job_count = p.n_iter state.job_count = p.n_iter
extra_network_data = None
for n in range(p.n_iter): for n in range(p.n_iter):
p.iteration = n p.iteration = n
...@@ -702,11 +702,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -702,11 +702,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(p.prompts) == 0: if len(p.prompts) == 0:
break break
extra_network_data = p.parse_extra_network_prompts() p.extra_network_data = p.parse_extra_network_prompts()
if not p.disable_extra_networks: if not p.disable_extra_networks:
with devices.autocast(): with devices.autocast():
extra_networks.activate(p, extra_network_data) extra_networks.activate(p, p.extra_network_data)
if p.scripts is not None: if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
...@@ -828,8 +828,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -828,8 +828,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
if not p.disable_extra_networks and extra_network_data: if not p.disable_extra_networks and p.extra_network_data:
extra_networks.deactivate(p, extra_network_data) extra_networks.deactivate(p, p.extra_network_data)
devices.torch_gc() devices.torch_gc()
...@@ -1101,8 +1101,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1101,8 +1101,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
super().setup_conds() super().setup_conds()
if self.enable_hr: if self.enable_hr:
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc) self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c) self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data)
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts() res = super().parse_extra_network_prompts()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment