Commit d3e32e79 authored by kurumuz's avatar kurumuz

StableInterface for k-diffusion

parent b43d6ea1
...@@ -105,6 +105,25 @@ def sanitize_image(image): ...@@ -105,6 +105,25 @@ def sanitize_image(image):
image = image.convert('RGB') image = image.convert('RGB')
return image return image
class StableInterface(nn.Module):
def __init__(self, model, thresholder = None):
super().__init__()
self.inner_model = model
self.sigma_to_t = model.sigma_to_t
self.thresholder = thresholder
self.get_sigmas = model.get_sigmas
@torch.no_grad()
def forward(self, x, sigma, uncond, cond, cond_scale):
x_two = torch.cat([x] * 2)
sigma_two = torch.cat([sigma] * 2)
cond_full = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_two, sigma_two, cond=cond_full).chunk(2)
x_0 = uncond + (cond - uncond) * cond_scale
if self.thresholder is not None:
x_0 = self.thresholder(x_0)
return x_0
class StableDiffusionModel(nn.Module): class StableDiffusionModel(nn.Module):
def __init__(self, config): def __init__(self, config):
...@@ -116,7 +135,7 @@ class StableDiffusionModel(nn.Module): ...@@ -116,7 +135,7 @@ class StableDiffusionModel(nn.Module):
else: else:
typex = torch.float32 typex = torch.float32
self.k_model = K.external.CompVisDenoiser(model) self.k_model = K.external.CompVisDenoiser(model)
self.k_model = K.external.StableInterface(self.k_model) self.k_model = StableInterface(self.k_model)
self.device = config.device self.device = config.device
self.model_config = model_config self.model_config = model_config
self.plms = PLMSSampler(model) self.plms = PLMSSampler(model)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment