Commit 1ca5e76f authored by AUTOMATIC's avatar AUTOMATIC

fix for conds of second hires fox pass being calculated using first pass's...

fix for conds of second hires fox pass being calculated using first pass's networks, and add an option to revert to old behavior
parent 1c6dca93
...@@ -15,6 +15,8 @@ def send_everything_to_cpu(): ...@@ -15,6 +15,8 @@ def send_everything_to_cpu():
def setup_for_low_vram(sd_model, use_medvram): def setup_for_low_vram(sd_model, use_medvram):
sd_model.lowvram = True
parents = {} parents = {}
def send_me_to_gpu(module, _): def send_me_to_gpu(module, _):
...@@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram):
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks: for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu) block.register_forward_pre_hook(send_me_to_gpu)
def is_enabled(sd_model):
return getattr(sd_model, 'lowvram', False)
...@@ -739,7 +739,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -739,7 +739,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
del samples_ddim del samples_ddim
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if lowvram.is_enabled(shared.sd_model):
lowvram.send_everything_to_cpu() lowvram.send_everything_to_cpu()
devices.torch_gc() devices.torch_gc()
...@@ -894,6 +894,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -894,6 +894,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_negative_prompts = None self.hr_negative_prompts = None
self.hr_extra_network_data = None self.hr_extra_network_data = None
self.cached_hr_uc = [None, None]
self.cached_hr_c = [None, None]
self.hr_c = None self.hr_c = None
self.hr_uc = None self.hr_uc = None
...@@ -1056,6 +1058,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1056,6 +1058,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with devices.autocast(): with devices.autocast():
extra_networks.activate(self, self.hr_extra_network_data) extra_networks.activate(self, self.hr_extra_network_data)
with devices.autocast():
self.calculate_hr_conds()
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True)) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
...@@ -1067,6 +1072,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1067,6 +1072,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples return samples
def close(self): def close(self):
self.cached_hr_uc = [None, None]
self.cached_hr_c = [None, None]
self.hr_c = None self.hr_c = None
self.hr_uc = None self.hr_uc = None
...@@ -1095,12 +1102,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1095,12 +1102,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts] self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts] self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
def calculate_hr_conds(self):
if self.hr_c is not None:
return
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_hr_uc, self.hr_extra_network_data)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_hr_c, self.hr_extra_network_data)
def setup_conds(self): def setup_conds(self):
super().setup_conds() super().setup_conds()
self.hr_uc = None
self.hr_c = None
if self.enable_hr: if self.enable_hr:
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data) if shared.opts.hires_fix_use_firstpass_conds:
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data) self.calculate_hr_conds()
elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
with devices.autocast():
extra_networks.activate(self, self.hr_extra_network_data)
self.calculate_hr_conds()
with devices.autocast():
extra_networks.activate(self, self.extra_network_data)
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts() res = super().parse_extra_network_prompts()
......
...@@ -429,6 +429,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { ...@@ -429,6 +429,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
})) }))
options_templates.update(options_section(('interrogate', "Interrogate Options"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment