Commit e14b586d authored by Sakura-Luna's avatar Sakura-Luna

Add Tiny AE live preview

parent b08500ce
...@@ -2,7 +2,7 @@ from collections import namedtuple ...@@ -2,7 +2,7 @@ from collections import namedtuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from modules import devices, processing, images, sd_vae_approx from modules import devices, processing, images, sd_vae_approx, sd_vae_taesd
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared import modules.shared as shared
...@@ -22,21 +22,26 @@ def setup_img2img_steps(p, steps=None): ...@@ -22,21 +22,26 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc return steps, t_enc
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2} approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap": 3}
def single_sample_to_image(sample, approximation=None): def single_sample_to_image(sample, approximation=None):
if approximation is None: if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0) approximation = approximation_indexes.get(opts.show_progress_type, 0)
if approximation == 2: if approximation == 1:
x_sample = sd_vae_approx.cheap_approximation(sample) x_sample = sd_vae_taesd.decode()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
elif approximation == 1: x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample)
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() x_sample = torch.clamp((x_sample * 0.25) + 0.5, 0, 1)
else: else:
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] if approximation == 3:
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 2:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
else:
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample) return Image.fromarray(x_sample)
......
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
https://github.com/madebyollin/taesd
"""
import os
import torch
import torch.nn as nn
from modules import devices, paths_internal
sd_vae_taesd = None
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
@staticmethod
def forward(x):
return torch.tanh(x / 3) * 3
class Block(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))
def decoder():
return nn.Sequential(
Clamp(), conv(4, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), conv(64, 3),
)
class TAESD(nn.Module):
latent_magnitude = 2
latent_shift = 0.5
def __init__(self, decoder_path="taesd_decoder.pth"):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.decoder = decoder()
self.decoder.load_state_dict(
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
@staticmethod
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode():
global sd_vae_taesd
if sd_vae_taesd is None:
model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth")
if os.path.exists(model_path):
sd_vae_taesd = TAESD(model_path)
sd_vae_taesd.eval()
sd_vae_taesd.to(devices.device, devices.dtype)
else:
raise FileNotFoundError('Tiny AE mdoel not found')
return sd_vae_taesd.decoder
...@@ -425,7 +425,7 @@ options_templates.update(options_section(('ui', "Live previews"), { ...@@ -425,7 +425,7 @@ options_templates.update(options_section(('ui', "Live previews"), {
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "show_progress_type": OptionInfo("Tiny AE", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Tiny AE", "Approx NN", "Approx cheap"]}),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds") "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
})) }))
......
...@@ -144,10 +144,21 @@ Use --skip-version-check commandline argument to disable this check. ...@@ -144,10 +144,21 @@ Use --skip-version-check commandline argument to disable this check.
""".strip()) """.strip())
def check_taesd():
from modules.paths_internal import models_path
model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
model_path = os.path.join(models_path, "VAE-approx", "taesd_decoder.pth")
if not os.path.exists(model_path):
print('download taesd model')
torch.hub.download_url_to_file(model_url, os.path.dirname(model_path))
def initialize(): def initialize():
fix_asyncio_event_loop_policy() fix_asyncio_event_loop_policy()
check_versions() check_versions()
check_taesd()
extensions.list_extensions() extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment