Commit 883a3277 authored by kurumuz

push working

parent 87844d91
@@ -3,6 +3,7 @@ import json
 import base64
 from PIL import Image
 import io
+import random
 #server hosts on 0.0.0.0
 IP_ADDR = '0.0.0.0'
 PORT = '4315'
@@ -15,16 +16,14 @@ masks = [
 ]
 payload = {
-    'prompt': 'test',
+    'prompt': 'Tags: red bikini',
     "width": 512,
     "height": 512,
     "scale": 12,
-    "sampler": "k_lms",
+    "sampler": "k_euler_ancestral",
     "steps": 50,
-    "seed": 3808250753,
+    "seed": random.randint(0, 2**32),
     "n_samples": 1,
-    "strength": 0.7,
-    "noise": 0.6,
     "masks": None
 }
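The payload above is what the test client sends to the generation server started at the bottom of this commit. A minimal sketch of issuing that request follows; the /generate path and the way the response is inspected are illustrative assumptions, not taken from this commit:

    import requests

    # IP_ADDR, PORT and payload come from the script above; a real client would use
    # the server's reachable address rather than 0.0.0.0.
    resp = requests.post(f"http://{IP_ADDR}:{PORT}/generate", json=payload)
    resp.raise_for_status()
    print(resp.status_code, len(resp.content))  # inspect whatever the server returned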
...
@@ -139,6 +139,7 @@ def init_config_model():
     # Resolve where we get our model and data from.
     config.model_path = os.getenv('MODEL_PATH', None)
+    config.enable_ema = os.getenv('ENABLE_EMA', "1")
     config.vae_path = os.getenv('VAE_PATH', None)
     config.module_path = os.getenv('MODULE_PATH', None)
     config.prior_path = os.getenv('PRIOR_PATH', None)
...
@@ -16,6 +16,7 @@ from ldm.modules.attention import CrossAttention, HyperLogic
 import time
 from PIL import Image
 import k_diffusion as K
+import contextlib
 def pil_upscale(image, scale=1):
     device = image.device
@@ -86,6 +87,8 @@ def encode_image(image, model):
     if isinstance(image, np.ndarray):
         image = torch.from_numpy(image)
+    dtype = image.dtype
+    image = image.to(torch.float32)
     #gets image as numpy array and returns as tensor
     def preprocess_vqgan(x):
         x = x / 255.0
@@ -95,6 +98,7 @@ def encode_image(image, model):
     image = image.permute(2, 0, 1).unsqueeze(0).float().cuda()
     image = preprocess_vqgan(image)
     image = model.encode(image).sample()
+    image = image.to(dtype)
     return image
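The two encode_image hunks above introduce a cast-to-float32-then-cast-back pattern: the VAE always encodes in full precision even if the caller hands in a half-precision tensor, and the latent is returned in the caller's original dtype. A stripped-down sketch of the same idea (function and variable names are illustrative):

    import torch

    def encode_full_precision(image, vae):
        orig_dtype = image.dtype                                # e.g. torch.float16 under AMP
        latent = vae.encode(image.to(torch.float32)).sample()   # run the VAE in float32
        return latent.to(orig_dtype)                            # hand the latent back in the caller's dtype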
@@ -104,15 +108,14 @@ def decode_image(image, model):
         x = x.detach().float().cpu()
         x = torch.clamp(x, -1., 1.)
         x = (x + 1.)/2.
-        x = x.permute(1,2,0).numpy()
-        x = (255*x).astype(np.uint8)
-        x = Image.fromarray(x)
-        if not x.mode == "RGB":
-            x = x.convert("RGB")
+        x = x.permute(0, 2, 3, 1)#.numpy()
+        #x = (255*x).astype(np.uint8)
+        #x = Image.fromarray(x)
+        #if not x.mode == "RGB":
+        #    x = x.convert("RGB")
         return x
     image = model.decode(image)
-    image = image.squeeze(0)
     image = custom_to_pil(image)
     return image
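With squeeze(0) dropped and the PIL conversion commented out, custom_to_pil now returns a batched (N, H, W, C) float tensor scaled to [0, 1] rather than a single PIL image, so any PIL conversion has to happen downstream. A possible helper for that step (not part of this commit) could look like:

    import numpy as np
    from PIL import Image

    def tensor_batch_to_pil(x):
        # x: (N, H, W, C) float tensor already scaled to [0, 1]
        arr = (255 * x.cpu().numpy()).astype(np.uint8)
        return [Image.fromarray(a) for a in arr]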
@@ -178,7 +181,10 @@ class StableDiffusionModel(nn.Module):
                     loss.append(i)
             for i in loss:
                 del ckpt["state_dict"][i]
+            model.first_stage_model = model.first_stage_model.float()
             model.first_stage_model.load_state_dict(ckpt["state_dict"])
+            model.first_stage_model = model.first_stage_model.float()
             del ckpt
             del loss
@@ -188,6 +194,9 @@ class StableDiffusionModel(nn.Module):
         self.model_config = model_config
         self.plms = PLMSSampler(model)
         self.ddim = DDIMSampler(model)
+        self.ema_manager = self.model.ema_scope
+        if self.config.enable_ema == "0":
+            self.ema_manager = contextlib.nullcontext
         self.sampler_map = {
             'plms': self.plms.sample,
             'ddim': self.ddim.sample,
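The ema_manager indirection works because both model.ema_scope and contextlib.nullcontext are callables that return a context manager, so the sampling code below can call self.ema_manager() without caring which one is active. A minimal standalone illustration (everything except nullcontext is made up):

    import contextlib

    class DummyModel:
        @contextlib.contextmanager
        def ema_scope(self):
            print("swapped to EMA weights")   # stand-in for LatentDiffusion's EMA swap
            yield

    model, use_ema = DummyModel(), False
    manager = model.ema_scope if use_ema else contextlib.nullcontext
    with manager():                           # a no-op when EMA is disabled
        pass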
@@ -216,7 +225,7 @@ class StableDiffusionModel(nn.Module):
         pl_sd = torch.load(ckpt, map_location="cpu")
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")
-        sd = pl_sd["state_dict"]
+        sd = pl_sd
         model = instantiate_from_config(config.model)
         m, u = model.load_state_dict(sd, strict=False)
         if len(m) > 0 and verbose:
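The sd = pl_sd change implies the new checkpoint stores its weights at the top level rather than under a "state_dict" key. A more defensive loader that accepts both layouts (an alternative sketch, not what the commit does) might be:

    import torch

    pl_sd = torch.load(ckpt, map_location="cpu")  # ckpt as in the surrounding code
    sd = pl_sd["state_dict"] if isinstance(pl_sd, dict) and "state_dict" in pl_sd else pl_sd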
@@ -321,7 +330,7 @@ class StableDiffusionModel(nn.Module):
             request.width // request.downsampling_factor
         ]
         if sampler == "normal":
-            with self.model.ema_scope():
+            with self.ema_manager():
                 samples, _ = self.sampler_map[request.sampler](
                     S=request.steps,
                     conditioning=prompt_condition,
@@ -336,7 +345,7 @@ class StableDiffusionModel(nn.Module):
                 )
         elif sampler == "k-diffusion":
-            with self.model.ema_scope():
+            with self.ema_manager():
                 sigmas = self.k_model.get_sigmas(request.steps)
                 if request.image is not None:
                     noise = main_noise * sigmas[request.steps - t_enc - 1]
@@ -349,8 +358,11 @@ class StableDiffusionModel(nn.Module):
                 extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale}
                 samples = self.sampler_map[request.sampler](self.k_model, start_code, sigmas, extra_args=extra_args)
-        x_samples_ddim = self.model.decode_first_stage(samples)
-        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        with torch.autocast("cuda", enabled=False):
+            x_samples_ddim = self.model.decode_first_stage(samples.float())
+            #x_samples_ddim = decode_image(samples, self.model.first_stage_model)
+            #x_samples_ddim = self.model.first_stage_model.decode(samples.float())
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
         images = []
         for x_sample in x_samples_ddim:
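Disabling autocast around decode_first_stage keeps the VAE decode in float32 even when the request handler runs inside an AMP autocast region. The pattern in isolation, with model and samples standing in for the objects in the surrounding code:

    import torch

    with torch.autocast("cuda", enabled=False):              # step out of any enclosing AMP region
        decoded = model.decode_first_stage(samples.float())  # decode latents in full precision
    decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)  # map [-1, 1] output to [0, 1]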
...
@@ -2,9 +2,11 @@ export DTYPE="float32"
 export AMP="1"
 export MODEL="stable-diffusion"
 export DEV="True"
-export MODEL_PATH="/home/xuser/nvme1/stableckpt/v13"
+export MODEL_PATH="/home/xuser/nvme1/stableckpt/anime5000"
 export MODULE_PATH="/home/xuser/nvme1/stableckpt/modules"
 export TRANSFORMERS_CACHE="/home/xuser/nvme1/transformer_cache"
 export SENTRY_URL="https://49ca8adcf4444f82a10eae1b3fd4182f@o846434.ingest.sentry.io/6612448"
+export ENABLE_EMA="0"
+export VAE_PATH="/home/xuser/nvme1/stableckpt/animevae.pt"
 export PYTHONDONTWRITEBYTECODE=1
 gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:4315