Commit c6954d2d authored by novelailab

upscaler

parent d8c57f5b
......
@@ -7,12 +7,57 @@ from dotmap import DotMap
import numpy as np
import base64
from torch import autocast
from einops import rearrange
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import time
from PIL import Image

def pil_upscale(image, scale=1):
    # Upscale a single CHW image tensor in [0, 1] with PIL's Lanczos filter and
    # return it as a 1xCxHxW tensor normalized to [-1, 1] on the original device.
    device = image.device
    image = Image.fromarray((image.cpu().permute(1, 2, 0).numpy().astype(np.float32) * 255.).astype(np.uint8))
    if scale > 1:
        image = image.resize((int(image.width * scale), int(image.height * scale)), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    image = 2. * image - 1.
    image = repeat(image, '1 ... -> b ...', b=1)
    return image.to(device)
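A minimal usage sketch (not part of this diff; the tensor name and the 2x factor are illustrative): pil_upscale expects one CHW image in [0, 1], e.g. a clamped output of decode_first_stage, and hands back a batch-of-one tensor in the [-1, 1] range expected by the first-stage encoder.

# Illustrative only: "decoded" stands for a single image tensor in [0, 1], still on the GPU.
upscaled = pil_upscale(decoded, scale=2)   # shape (1, C, 2*H, 2*W), values in [-1, 1]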

def fix_batch(tensor, bs):
    # Repeat a single conditioning tensor along the batch dimension.
    return torch.stack([tensor.squeeze(0)] * bs, dim=0)


# Mix conditioning vectors for prompts of the form "prompt a:1.0|prompt b:0.5".
def prompt_mixing(model, prompt_body, batch_size):
    if "|" in prompt_body:
        prompt_parts = prompt_body.split("|")
        prompt_total_power = 0
        prompt_sum = None
        for prompt_part in prompt_parts:
            prompt_power = 1
            if ":" in prompt_part:
                prompt_sub_parts = prompt_part.split(":")
                try:
                    prompt_power = float(prompt_sub_parts[1])
                    prompt_part = prompt_sub_parts[0]
                except (ValueError, IndexError):
                    print("Error parsing prompt power! Assuming 1")
            prompt_vector = model.get_learned_conditioning([prompt_part])
            if prompt_sum is None:
                prompt_sum = prompt_vector * prompt_power
            else:
                prompt_sum = prompt_sum + (prompt_vector * prompt_power)
            prompt_total_power = prompt_total_power + prompt_power
        return fix_batch(prompt_sum / prompt_total_power, batch_size)
    else:
        return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size)
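A hedged usage sketch of prompt_mixing (the model handle and the weights are made up for illustration): sub-prompts separated by "|" are embedded separately, summed with their optional ":power" weights, divided by the total weight, and repeated across the batch by fix_batch.

# Illustrative only: "model" is any wrapper exposing get_learned_conditioning(),
# such as self.model in the class below.
c = prompt_mixing(model, "a castle on a hill:1.0|oil painting:0.5", batch_size=4)
# c is the weighted average of the two prompt embeddings (weights 1.0 and 0.5,
# normalized by their sum 1.5), stacked to a batch of 4.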

class StableDiffusionModel(nn.Module):
    def __init__(self, config):

......
@@ -116,6 +161,97 @@ class StableDiffusionModel(nn.Module):
        return images

    @torch.no_grad()
    def sample_two_stages(self, request):
        request = DotMap(request)
        if request.seed is not None:
            torch.manual_seed(request.seed)
            np.random.seed(request.seed)

        if request.plms:
            sampler = self.plms
        else:
            sampler = self.ddim

        start_code = None
        if request.fixed_code:
            start_code = torch.randn([
                request.n_samples,
                request.latent_channels,
                request.height // request.downsampling_factor,
                request.width // request.downsampling_factor,
            ], device=self.device)

        prompt = [request.prompt] * request.n_samples
        prompt_condition = self.model.get_learned_conditioning(prompt)

        uc = None
        if request.scale != 1.0:
            uc = self.model.get_learned_conditioning(request.n_samples * [""])

        shape = [
            request.latent_channels,
            request.height // request.downsampling_factor,
            request.width // request.downsampling_factor
        ]

        # First stage: sample latents from noise at the requested resolution.
        with torch.autocast("cuda", enabled=self.config.amp):
            with self.model.ema_scope():
                samples, _ = sampler.sample(
                    S=request.steps,
                    conditioning=prompt_condition,
                    batch_size=request.n_samples,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=request.scale,
                    unconditional_conditioning=uc,
                    eta=request.ddim_eta,
                    dynamic_threshold=request.dynamic_threshold,
                    x_T=start_code,
                )

                x_samples_ddim = self.model.decode_first_stage(samples)
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        # Second stage: re-encode the decoded images, add noise, and denoise them
        # again img2img-style with the mixed prompt conditioning.
        with torch.autocast("cuda", enabled=self.config.amp):
            with self.model.ema_scope():
                init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim))
                sampler.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)

                t_enc = int(0.69 * request.steps)
                print("init latent shape:")
                print(init_latent.shape)
                init_latent = init_latent + (torch.randn_like(init_latent) * 0.667)

                uc = None
                if request.scale != 1.0:
                    uc = self.model.get_learned_conditioning(request.n_samples * [""])

                c = prompt_mixing(self.model, prompt[0], request.n_samples)

                # encode (scaled latent)
                start_code_terped = None
                z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * request.n_samples).to(self.device), noise=start_code_terped)
                # decode it
                samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=request.scale,
                                         unconditional_conditioning=uc)

                x_samples_ddim = self.model.decode_first_stage(samples)
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        images = []
        for x_sample in x_samples_ddim:
            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            x_sample = x_sample.astype(np.uint8)
            x_sample = np.ascontiguousarray(x_sample)
            images.append(x_sample)

        if request.seed is not None:
            torch.seed()
            np.random.seed()

        return images

    @torch.no_grad()
    def sample_from_image(self, request):
        return
......
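For reference, a hedged sketch of the request fields that sample_two_stages above actually reads (the field names come from the method; the concrete values are illustrative, and any validation done elsewhere by sanitize_input is omitted):

request = {
    "prompt": "a lighthouse at dawn:1.0|oil painting:0.3",  # "|"/":" weighting is applied by prompt_mixing in stage two
    "n_samples": 1,
    "steps": 50,
    "scale": 7.5,               # unconditional_guidance_scale; 1.0 skips the unconditional pass
    "ddim_eta": 0.0,
    "dynamic_threshold": None,
    "plms": False,              # False selects the DDIM sampler
    "fixed_code": False,
    "seed": 42,
    "latent_channels": 4,
    "downsampling_factor": 8,
    "height": 512,
    "width": 512,
}
images = model.sample_two_stages(request)   # list of HxWxC uint8 numpy arrays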
......
@@ -117,6 +117,51 @@ def generate(request: GenerationRequest):
            os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}

@app.post('/generate-advanced-stream')
def generate_advanced(request: GenerationRequest):
    t = time.perf_counter()
    try:
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        images = model.sample_two_stages(request)
        images_encoded = []
        for x in range(len(images)):
            image = simplejpeg.encode_jpeg(images[x], quality=95)
            # get base64 of the JPEG so it can be embedded in the event stream
            image = base64.b64encode(image).decode("ascii")
            images_encoded.append(image)
        del images

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")

        # Emit the images as server-sent events, one "newImage" event per image.
        data = ""
        ptr = 0
        for x in images_encoded:
            ptr += 1
            data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
        return Response(content=data, media_type="text/event-stream")
        #return GenerationOutput(output=images)

    except Exception as e:
        traceback.print_exc()
        capture_exception(e)
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
            "an illegal memory access" in e_s or "CUDA" in e_s:
            logger.error("GPU error, committing seppuku.")
            os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}
'''
@app.post('/image-to-image')
def image_to_image(request: GenerationRequest):
......