Commit 71901b3d authored by C43H66N12O12S2's avatar C43H66N12O12S2 Committed by AUTOMATIC1111

add karras scheduling variants

parent c1a068ed
......@@ -26,6 +26,17 @@ samplers_k_diffusion = [
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
]
if opts.show_karras_scheduler_variants:
k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2
k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral
k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms
samplers_k_diffusion_ka = [
('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']),
('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']),
('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']),
]
samplers_k_diffusion.extend(samplers_k_diffusion_ka)
samplers_data_k_diffusion = [
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases)
for label, funcname, aliases in samplers_k_diffusion
......@@ -345,6 +356,8 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.funcname.endswith('ka'):
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
x = x * sigmas[0]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment