Commit d1f64ed9 authored by gd1551, committed by GitHub

add metadata to output images, remove unused endpoint, fix model load path

parent ba6cb51b
@@ -12,6 +12,7 @@ from sentry_sdk import capture_exception
 from sentry_sdk.integrations.threading import ThreadingIntegration
 from hydra_node.models import StableDiffusionModel, DalleMiniModel
 import traceback
+import zlib

 model_map = {"stable-diffusion": StableDiffusionModel, "dalle-mini": DalleMiniModel}

@@ -31,6 +32,14 @@ def no_init(loading_code):
     return result

+def crc32(filename, chunksize=65536):
+    """Compute the CRC-32 checksum of the contents of the given filename."""
+    with open(filename, "rb") as f:
+        checksum = 0
+        while (chunk := f.read(chunksize)):
+            checksum = zlib.crc32(chunk, checksum)
+        return '%08X' % (checksum & 0xFFFFFFFF)
+
 def init_config_model():
     config = DotMap()
     config.dtype = os.getenv("DTYPE", "float16")
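The checksum serves as a short, stable fingerprint for multi-gigabyte checkpoint files; because it streams the file in 64 KiB chunks, memory use stays constant regardless of file size. A minimal self-contained check of the helper (the temp file and its contents are hypothetical, not a real checkpoint):

    # Sketch: verify the streamed CRC matches zlib's one-shot result.
    import tempfile, zlib

    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(b"hello world")
    assert crc32(f.name) == '%08X' % (zlib.crc32(b"hello world") & 0xFFFFFFFF)  # '0D4A1185'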
@@ -113,6 +122,14 @@ def init_config_model():
         #exit gunicorn
         sys.exit(4)

+    if config.model_name == "stable-diffusion":
+        folder = Path(config.model_path)
+        if (folder / "pruned.ckpt").is_file():
+            model_path = folder / "pruned.ckpt"
+        else:
+            model_path = folder / "model.ckpt"
+        model_hash = crc32(model_path)
+
     config.model = model
     # Mark that our model is loaded.
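Note that `model_hash` is bound only on the `stable-diffusion` branch, while the `return` statement below references it unconditionally, so a `dalle-mini` deployment would raise a `NameError`. A defensive variant of the hunk, assuming an empty string is an acceptable placeholder for callers:

    # Sketch: pre-bind the hash so the unconditional return below is always safe.
    model_hash = ""  # assumed placeholder for models without a checkpoint hash
    if config.model_name == "stable-diffusion":
        folder = Path(config.model_path)
        model_path = folder / ("pruned.ckpt" if (folder / "pruned.ckpt").is_file() else "model.ckpt")
        model_hash = crc32(model_path)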
@@ -125,4 +142,4 @@ def init_config_model():
     time_load = time.time() - load_time
     logger.info(f"Models loaded in {time_load:.2f}s")
-    return model, config
+    return model, config, model_hash
\ No newline at end of file
@@ -178,7 +178,7 @@ class StableDiffusionModel(nn.Module):
             model_path = folder / "pruned.ckpt"
         else:
             model_path = folder / "model.ckpt"
-        model = self.load_model_from_config(model_config, folder / "pruned.ckpt")
+        model = self.load_model_from_config(model_config, model_path)
         return model, model_config

    def load_model_from_config(self, config, ckpt, verbose=False):
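This one-line change is the "fix model load path" item from the commit message: the old call hard-coded `pruned.ckpt`, so the `model.ckpt` fallback selected just above was dead code, and a checkout without a pruned checkpoint failed to load. The corrected flow, condensed for illustration (the folder path here is hypothetical):

    # Sketch: prefer the pruned checkpoint, fall back to the full one,
    # and actually pass the chosen path to the loader.
    folder = Path("models/stable-diffusion")  # assumed layout
    name = "pruned.ckpt" if (folder / "pruned.ckpt").is_file() else "model.ckpt"
    model = self.load_model_from_config(model_config, folder / name)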
@@ -20,9 +20,11 @@ import simplejpeg
 import base64
 import traceback
 from PIL import Image
+from PIL.PngImagePlugin import PngInfo
+import json

 #Initialize model and config
-model, config = init_config_model()
+model, config, model_hash = init_config_model()
 logger = config.logger
 config.mainpid = int(open("gunicorn.pid", "r").read())
 mainpid = config.mainpid
@@ -101,69 +103,27 @@ def generate(request: GenerationRequest):
         else:
             images = model.sample(request)

+        seed = request.seed
         images_encoded = []
         for x in range(len(images)):
+            metadata = PngInfo()
+            metadata.add_text("Title", "AI generated image")
+            metadata.add_text("Description", request.prompt)
+            metadata.add_text("Software", "NovelAI")
+            metadata.add_text("Source", "Stable Diffusion "+model_hash)
+            request_copy = request.copy()
+            del request_copy.prompt
+            if request_copy.image is not None:
+                del request_copy.image
+            if seed is not None:
+                request_copy.seed = seed
+                seed += 1
+            metadata.add_text("Comment", json.dumps(request_copy))
             #image = simplejpeg.encode_jpeg(images[x], quality=95)
             image = Image.fromarray(images[x])
             #save pillow image with bytesIO
             output = io.BytesIO()
-            image.save(output, format='PNG')
+            image.save(output, format='PNG', pnginfo=metadata)
-            image = output.getvalue()
-            #get base64 of image
-            image = base64.b64encode(image).decode("ascii")
-            images_encoded.append(image)
-        del images
-        process_time = time.perf_counter() - t
-        logger.info(f"Request took {process_time:0.3f} seconds")
-        data = ""
-        ptr = 0
-        for x in images_encoded:
-            ptr += 1
-            data += ("event: newImage\nid: {}\ndata:{}\n\n").format(ptr, x)
-        return Response(content=data, media_type="text/event-stream")
-        #return GenerationOutput(output=images)
-    except Exception as e:
-        traceback.print_exc()
-        capture_exception(e)
-        logger.error(str(e))
-        e_s = str(e)
-        gc.collect()
-        if "CUDA out of memory" in e_s or \
-            "an illegal memory access" in e_s or "CUDA" in e_s:
-            logger.error("GPU error, committing seppuku.")
-            os.kill(mainpid, signal.SIGTERM)
-        return {"error": str(e)}
-
-@app.post('/generate-advanced-stream')
-def generate_advanced(request: GenerationRequest):
-    t = time.perf_counter()
-    try:
-        output = sanitize_input(config, request)
-        if output[0]:
-            request = output[1]
-        else:
-            return ErrorOutput(error=output[1])
-        if request.advanced:
-            if request.n_samples > 1:
-                return ErrorOutput(error="advanced mode does not support n_samples > 1")
-            images = model.sample_two_stages(request)
-        else:
-            images = model.sample(request)
-        images_encoded = []
-        for x in range(len(images)):
-            #image = simplejpeg.encode_jpeg(images[x], quality=95)
-            image = Image.fromarray(images[x])
-            #save pillow image with bytesIO
-            output = io.BytesIO()
-            image.save(output, format='PNG')
             image = output.getvalue()
             #get base64 of image
             image = base64.b64encode(image).decode("ascii")
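With `pnginfo=` supplied at save time, each text entry is written as a PNG tEXt chunk and survives the base64 round trip; Pillow exposes the chunks via `PngImageFile.text`. A minimal readback sketch to verify (the `png_bytes` variable is hypothetical). One caveat worth hedging: `json.dumps(request_copy)` assumes the request object is directly JSON-serializable; if `GenerationRequest` is a pydantic model, as is usual with FastAPI, `json.dumps(request_copy.dict())` would be needed instead.

    # Sketch: decode a returned image and inspect the embedded metadata.
    import base64, io, json
    from PIL import Image

    png_bytes = base64.b64decode(images_encoded[0])  # hypothetical: first encoded image
    img = Image.open(io.BytesIO(png_bytes))
    print(img.text["Title"])                         # "AI generated image"
    params = json.loads(img.text["Comment"])         # request fields minus prompt/image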
@@ -193,7 +153,6 @@ def generate_advanced(request: GenerationRequest):
         os.kill(mainpid, signal.SIGTERM)
         return {"error": str(e)}
 '''
 @app.post('/image-to-image')
 def image_to_image(request: GenerationRequest):