Commit b43d6ea1 authored by kurumuz

k-diffusion samplers

parent 56787cb0
@@ -14,6 +14,7 @@ from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 import time
 from PIL import Image
+import k_diffusion as K
 
 def pil_upscale(image, scale=1):
     device = image.device
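
Note: the new import assumes Katherine Crowson's k-diffusion package (https://github.com/crowsonkb/k-diffusion), which provides the Karras et al. (2022) samplers used below; it is a new dependency of this file, installable e.g. with:

    pip install k-diffusion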
@@ -58,6 +59,52 @@ def prompt_mixing(model, prompt_body, batch_size):
     else:
         return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size)
 
+@torch.no_grad()
+#@torch.autocast("cuda", enabled=True, dtype=torch.float16)
+def encode_image(image, model):
+    if isinstance(image, Image.Image):
+        image = np.asarray(image)
+        image = torch.from_numpy(image).clone()
+
+    if isinstance(image, np.ndarray):
+        image = torch.from_numpy(image)
+
+    # rescales an HWC uint8 image tensor from [0, 255] to [-1, 1]
+    def preprocess_vqgan(x):
+        x = x / 255.0
+        x = 2.*x - 1.
+        return x
+
+    image = image.permute(2, 0, 1).unsqueeze(0).float().cuda()
+    image = preprocess_vqgan(image)
+    image = model.encode(image).sample()
+    return image
+
+@torch.no_grad()
+def decode_image(image, model):
+    def custom_to_pil(x):
+        x = x.detach().float().cpu()
+        x = torch.clamp(x, -1., 1.)
+        x = (x + 1.)/2.
+        x = x.permute(1,2,0).numpy()
+        x = (255*x).astype(np.uint8)
+        x = Image.fromarray(x)
+        if not x.mode == "RGB":
+            x = x.convert("RGB")
+        return x
+
+    image = model.decode(image)
+    image = image.squeeze(0)
+    image = custom_to_pil(image)
+    return image
+
+def sanitize_image(image):
+    # open the image with PIL and get rid of the alpha channel
+    image = Image.open(image)
+    image = image.convert('RGB')
+    return image
+
 class StableDiffusionModel(nn.Module):
     def __init__(self, config):
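
Note: the three helpers added above move images between pixel space and the latent space of the first-stage VAE. A minimal round-trip sketch using only names from this diff (the input path and the loaded model are assumed):

    image = sanitize_image("input.png")                        # RGB PIL image, alpha stripped
    latents = encode_image(image, model.first_stage_model)     # 1 x C x H/f x W/f latent sample
    restored = decode_image(latents, model.first_stage_model)  # back to a PIL image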
@@ -68,11 +115,22 @@ class StableDiffusionModel(nn.Module):
             typex = torch.float16
         else:
             typex = torch.float32
         self.model = model.to(config.device).to(typex)
+        self.k_model = K.external.CompVisDenoiser(model)
+        self.k_model = K.external.StableInterface(self.k_model)
         self.device = config.device
         self.model_config = model_config
         self.plms = PLMSSampler(model)
         self.ddim = DDIMSampler(model)
+        self.sampler_map = {
+            'plms': self.plms.sample,
+            'ddim': self.ddim.sample,
+            'k_euler': K.sampling.sample_euler,
+            'k_euler_ancestral': K.sampling.sample_euler_ancestral,
+            'k_heun': K.sampling.sample_heun,
+            'k_dpm_2': K.sampling.sample_dpm_2,
+            'k_dpm_2_ancestral': K.sampling.sample_dpm_2_ancestral,
+            'k_lms': K.sampling.sample_lms,
+        }
 
     def from_folder(self, folder):
         folder = Path(folder)
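
Note: every k_* entry in sampler_map shares the k-diffusion signature sampler(model, x, sigmas, extra_args=None, ...), where model maps a noisy latent and a sigma to a denoised prediction; that shared interface is what makes a single dispatch table possible. K.external.CompVisDenoiser adapts the CompVis eps-prediction UNet to it. StableInterface is not part of upstream k-diffusion; a classifier-free-guidance wrapper consuming the cond/uncond/cond_scale keys passed via extra_args later in this diff would look roughly like this (a sketch, not the committed implementation):

    import torch
    from torch import nn

    class CFGDenoiser(nn.Module):
        # hypothetical stand-in for the StableInterface wrapper used above
        def __init__(self, inner):
            super().__init__()
            self.inner = inner  # e.g. K.external.CompVisDenoiser(model)

        def forward(self, x, sigma, cond, uncond, cond_scale):
            # run the conditional and unconditional passes as one batch
            x_in = torch.cat([x] * 2)
            sigma_in = torch.cat([sigma] * 2)
            cond_in = torch.cat([uncond, cond])
            uncond_out, cond_out = self.inner(x_in, sigma_in, cond=cond_in).chunk(2)
            # classifier-free guidance: move the prediction away from the unconditional one
            return uncond_out + (cond_out - uncond_out) * cond_scale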
@@ -99,25 +157,41 @@ class StableDiffusionModel(nn.Module):
         return model
 
     @torch.no_grad()
+    @torch.autocast("cuda", enabled=True, dtype=torch.float16)
     def sample(self, request):
-        request = DotMap(request)
+        if request.image is not None:
+            request.sampler = "ddim_img2img" #enforce ddim for now
+            self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
+            image = sanitize_image(request.image)
+            image = image.resize((request.width, request.height), resample=Image.Resampling.LANCZOS)
+            start_code = encode_image(image, self.model.first_stage_model).to(self.device)
+            start_code = self.model.get_first_stage_encoding(start_code)
+            print(start_code.shape)
+            start_code = start_code + (torch.randn_like(start_code) * request.noise)
+            t_enc = int(request.strength * request.steps)
+
         if request.seed is not None:
             torch.manual_seed(request.seed)
             np.random.seed(request.seed)
-        if request.plms:
-            sampler = self.plms
-        else:
-            sampler = self.ddim
-
-        start_code = None
-        if request.fixed_code:
-            start_code = torch.randn([
-                request.n_samples,
-                request.latent_channels,
-                request.height // request.downsampling_factor,
-                request.width // request.downsampling_factor,
-            ], device=self.device)
+        if request.sampler.startswith("k_"):
+            sampler = "k-diffusion"
+        elif request.sampler == 'ddim_img2img':
+            sampler = 'img2img'
+        else:
+            sampler = "normal"
+
+        if request.image is None:
+            start_code = None
+            if request.fixed_code or sampler == "k-diffusion":
+                start_code = torch.randn([
+                    request.n_samples,
+                    request.latent_channels,
+                    request.height // request.downsampling_factor,
+                    request.width // request.downsampling_factor,
+                ], device=self.device)
 
         prompt = [request.prompt] * request.n_samples
         prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
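
Note: in the img2img path above, t_enc = int(request.strength * request.steps) decides how far back up the noise schedule the encoded image is pushed before being denoised again, so strength trades input fidelity for prompt freedom. For example:

    t_enc = int(0.75 * 50)  # strength 0.75 at 50 steps -> noised to step 37,
                            # then 37 DDIM steps are run back toward the prompt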
@@ -131,9 +205,9 @@ class StableDiffusionModel(nn.Module):
             request.height // request.downsampling_factor,
             request.width // request.downsampling_factor
         ]
-        with torch.autocast("cuda", enabled=self.config.amp):
+        if sampler == "normal":
             with self.model.ema_scope():
-                samples, _ = sampler.sample(
+                samples, _ = self.sampler_map[request.sampler](
                     S=request.steps,
                     conditioning=prompt_condition,
                     batch_size=request.n_samples,
@@ -146,6 +220,18 @@ class StableDiffusionModel(nn.Module):
                     x_T=start_code,
                 )
+        elif sampler == 'img2img':
+            with self.model.ema_scope():
+                start_code = self.ddim.stochastic_encode(start_code, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=None)
+                samples = self.ddim.decode(start_code, prompt_condition, t_enc, unconditional_guidance_scale=request.scale, unconditional_conditioning=uc)
+        elif sampler == "k-diffusion":
+            with self.model.ema_scope():
+                sigmas = self.k_model.get_sigmas(request.steps)
+                start_code = start_code * sigmas[0]
+                extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale}
+                samples = self.sampler_map[request.sampler](self.k_model, start_code, sigmas, extra_args=extra_args)
 
         x_samples_ddim = self.model.decode_first_stage(samples)
         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
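
Note: the k-diffusion branch multiplies start_code by sigmas[0] because these samplers expect the initial latent to have standard deviation sigma_max rather than 1; get_sigmas returns the schedule from largest to smallest sigma. A sketch of the values involved (numbers illustrative, from the default Stable Diffusion schedule):

    sigmas = k_model.get_sigmas(50)        # 51 values, descending, roughly 14.6 ... 0.0
    x = torch.randn(shape, device=device)  # unit-variance Gaussian latent
    x = x * sigmas[0]                      # scaled to sigma_max before the first step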
...