Commit c6954d2d authored by novelailab

upscaler

parent d8c57f5b
@@ -7,12 +7,57 @@ from dotmap import DotMap
import numpy as np
import base64
from torch import autocast
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import time
from PIL import Image
def pil_upscale(image, scale=1):
    # Round-trip a (C, H, W) tensor in [0, 1] through PIL for LANCZOS
    # upscaling, then return a (1, C, H*scale, W*scale) tensor in [-1, 1].
    device = image.device
    image = Image.fromarray((image.cpu().permute(1, 2, 0).numpy().astype(np.float32) * 255.).astype(np.uint8))
    if scale > 1:
        image = image.resize((int(image.width * scale), int(image.height * scale)), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    image = 2. * image - 1.
    return image.to(device)
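# Hedged usage sketch (names assumed; pil_upscale is not called elsewhere in
# this commit): given a decoded sample in [0, 1] with shape (C, H, W), it
# returns a (1, C, 2H, 2W) tensor in [-1, 1] that the first stage can re-encode:
#
#   upscaled = pil_upscale(x_sample, scale=2)
#   init_latent = model.get_first_stage_encoding(model.encode_first_stage(upscaled))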
def fix_batch(tensor, bs):
    # Tile a single-sample tensor along the batch dimension to batch size bs.
    return torch.stack([tensor.squeeze(0)] * bs, dim=0)
# Mix conditioning vectors for prompts of the form "part:power|part:power|...".
def prompt_mixing(model, prompt_body, batch_size):
    if "|" in prompt_body:
        prompt_parts = prompt_body.split("|")
        prompt_total_power = 0
        prompt_sum = None
        for prompt_part in prompt_parts:
            prompt_power = 1
            if ":" in prompt_part:
                prompt_sub_parts = prompt_part.split(":")
                try:
                    prompt_power = float(prompt_sub_parts[1])
                    prompt_part = prompt_sub_parts[0]
                except (ValueError, IndexError):
                    print("Error parsing prompt power! Assuming 1")
            prompt_vector = model.get_learned_conditioning([prompt_part])
            if prompt_sum is None:
                prompt_sum = prompt_vector * prompt_power
            else:
                prompt_sum = prompt_sum + (prompt_vector * prompt_power)
            prompt_total_power = prompt_total_power + prompt_power
        # Weighted average of the per-part conditioning vectors, tiled to the batch.
        return fix_batch(prompt_sum / prompt_total_power, batch_size)
    else:
        return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size)
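# Example (assumption, not from this commit): with a weighted prompt the
# conditioning is the power-weighted mean of the per-part embeddings,
# tiled across the batch by fix_batch:
#
#   c = prompt_mixing(model, "a castle at dawn:0.7|oil painting:0.3", batch_size=4)
#   # equivalent to (0.7 * e1 + 0.3 * e2) / (0.7 + 0.3), stacked 4 times, where
#   # e1, e2 = model.get_learned_conditioning([part]) for each part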
class StableDiffusionModel(nn.Module):
    def __init__(self, config):
@@ -116,6 +161,97 @@ class StableDiffusionModel(nn.Module):
        return images
    @torch.no_grad()
    def sample_two_stages(self, request):
        request = DotMap(request)
        if request.seed is not None:
            torch.manual_seed(request.seed)
            np.random.seed(request.seed)

        if request.plms:
            sampler = self.plms
        else:
            sampler = self.ddim

        start_code = None
        if request.fixed_code:
            start_code = torch.randn([
                request.n_samples,
                request.latent_channels,
                request.height // request.downsampling_factor,
                request.width // request.downsampling_factor,
            ], device=self.device)

        prompt = [request.prompt] * request.n_samples
        prompt_condition = self.model.get_learned_conditioning(prompt)

        uc = None
        if request.scale != 1.0:
            uc = self.model.get_learned_conditioning(request.n_samples * [""])
        shape = [
            request.latent_channels,
            request.height // request.downsampling_factor,
            request.width // request.downsampling_factor
        ]

        # Stage 1: sample a base image from noise.
        with torch.autocast("cuda", enabled=self.config.amp):
            with self.model.ema_scope():
                samples, _ = sampler.sample(
                    S=request.steps,
                    conditioning=prompt_condition,
                    batch_size=request.n_samples,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=request.scale,
                    unconditional_conditioning=uc,
                    eta=request.ddim_eta,
                    dynamic_threshold=request.dynamic_threshold,
                    x_T=start_code,
                )

                x_samples_ddim = self.model.decode_first_stage(samples)
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        # Stage 2: re-encode the stage-1 output and refine it img2img-style.
        with torch.autocast("cuda", enabled=self.config.amp):
            with self.model.ema_scope():
                init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim))
                sampler.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)

                # Perturb the latent, then re-noise to ~69% of the schedule.
                t_enc = int(0.69 * request.steps)
                print("init latent shape:")
                print(init_latent.shape)
                init_latent = init_latent + (torch.randn_like(init_latent) * 0.667)

                uc = None
                if request.scale != 1.0:
                    uc = self.model.get_learned_conditioning(request.n_samples * [""])
                c = prompt_mixing(self.model, prompt[0], request.n_samples)

                # Encode (scaled latent), then decode it.
                start_code_terped = None
                z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * request.n_samples).to(self.device), noise=start_code_terped)
                samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=request.scale,
                                         unconditional_conditioning=uc)

                x_samples_ddim = self.model.decode_first_stage(samples)
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        images = []
        for x_sample in x_samples_ddim:
            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            x_sample = x_sample.astype(np.uint8)
            images.append(np.ascontiguousarray(x_sample))

        # Re-randomize the global RNGs if we seeded them for this request.
        if request.seed is not None:
            torch.seed()
            np.random.seed()

        return images
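    # The second stage above is SDEdit-style img2img refinement: the stage-1
    # latent is perturbed, re-noised to t_enc = int(0.69 * steps) with
    # stochastic_encode, and denoised back with decode. A minimal standalone
    # sketch of that step (argument names are assumptions):
    #
    #   def refine(sampler, init_latent, c, uc, steps, strength, scale, eta, bs, device):
    #       sampler.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=False)
    #       t_enc = int(strength * steps)
    #       z = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * bs).to(device))
    #       return sampler.decode(z, c, t_enc, unconditional_guidance_scale=scale,
    #                             unconditional_conditioning=uc)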
    @torch.no_grad()
    def sample_from_image(self, request):
        return
...
@@ -117,6 +117,51 @@ def generate(request: GenerationRequest):
            os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}
@app.post('/generate-advanced-stream')
def generate_advanced(request: GenerationRequest):
    t = time.perf_counter()
    try:
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])

        images = model.sample_two_stages(request)

        # JPEG-encode each sample and base64 it for the event stream.
        images_encoded = []
        for image in images:
            image = simplejpeg.encode_jpeg(image, quality=95)
            images_encoded.append(base64.b64encode(image).decode("ascii"))
        del images

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")

        # Emit one server-sent event per image, all in a single response body.
        data = ""
        for ptr, x in enumerate(images_encoded, start=1):
            data += "event: newImage\nid: {}\ndata:{}\n\n".format(ptr, x)
        return Response(content=data, media_type="text/event-stream")
    except Exception as e:
        traceback.print_exc()
        capture_exception(e)
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
                "an illegal memory access" in e_s or "CUDA" in e_s:
            logger.error("GPU error, committing seppuku.")
            os.kill(mainpid, signal.SIGTERM)
        return {"error": str(e)}
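# Hedged client sketch (not part of this commit): the endpoint builds the whole
# SSE body up front and returns it in one Response, so a client can simply scan
# the data: lines; each "newImage" event carries one base64-encoded JPEG.
#
#   import base64, requests
#   resp = requests.post("http://localhost:80/generate-advanced-stream",
#                        json=payload, stream=True)  # URL and payload assumed
#   for line in resp.iter_lines(decode_unicode=True):
#       if line and line.startswith("data:"):
#           jpeg_bytes = base64.b64decode(line[len("data:"):])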
'''
@app.post('/image-to-image')
def image_to_image(request: GenerationRequest):
...