Commit 61788c05 authored by DepFA, committed by GitHub

shift embedding logic out of textual_inversion

parent e5fbf5c7
@@ -7,124 +7,11 @@ import tqdm
 import html
 import datetime
-from PIL import Image,PngImagePlugin,ImageDraw
-from ..images import captionImageOverlay
-import numpy as np
-import base64
-import json
-import zlib
+from PIL import Image,PngImagePlugin
 
 from modules import shared, devices, sd_hijack, processing, sd_models
 import modules.textual_inversion.dataset
 
-class EmbeddingEncoder(json.JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, torch.Tensor):
-            return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()}
-        return json.JSONEncoder.default(self, obj)
-
-class EmbeddingDecoder(json.JSONDecoder):
-    def __init__(self, *args, **kwargs):
-        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
-
-    def object_hook(self, d):
-        if 'TORCHTENSOR' in d:
-            return torch.from_numpy(np.array(d['TORCHTENSOR']))
-        return d
-
-def embeddingToB64(data):
-    d = json.dumps(data,cls=EmbeddingEncoder)
-    return base64.b64encode(d.encode())
-
-def embeddingFromB64(data):
-    d = base64.b64decode(data)
-    return json.loads(d,cls=EmbeddingDecoder)
-
-def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
-    while True:
-        seed = (a * seed + c) % m
-        yield seed
-
-def xorBlock(block):
-    g = lcg()
-    randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
-    return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F)
-
-def styleBlock(block,sequence):
-    im = Image.new('RGB',(block.shape[1],block.shape[0]))
-    draw = ImageDraw.Draw(im)
-    i=0
-    for x in range(-6,im.size[0],8):
-        for yi,y in enumerate(range(-6,im.size[1],8)):
-            offset=0
-            if yi%2==0:
-                offset=4
-            shade = sequence[i%len(sequence)]
-            i+=1
-            draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) )
-    fg = np.array(im).astype(np.uint8) & 0xF0
-    return block ^ fg
-
-def insertImageDataEmbed(image,data):
-    d = 3
-    data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9)
-    dnp = np.frombuffer(data_compressed,np.uint8).copy()
-    dnphigh = dnp >> 4
-    dnplow = dnp & 0x0F
-
-    h = image.size[1]
-    next_size = dnplow.shape[0] + (h-(dnplow.shape[0]%h))
-    next_size = next_size + ((h*d)-(next_size%(h*d)))
-
-    dnplow.resize(next_size)
-    dnplow = dnplow.reshape((h,-1,d))
-
-    dnphigh.resize(next_size)
-    dnphigh = dnphigh.reshape((h,-1,d))
-
-    edgeStyleWeights = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
-    edgeStyleWeights = (np.abs(edgeStyleWeights)/np.max(np.abs(edgeStyleWeights))*255).astype(np.uint8)
-
-    dnplow = styleBlock(dnplow,sequence=edgeStyleWeights)
-    dnplow = xorBlock(dnplow)
-    dnphigh = styleBlock(dnphigh,sequence=edgeStyleWeights[::-1])
-    dnphigh = xorBlock(dnphigh)
-
-    imlow = Image.fromarray(dnplow,mode='RGB')
-    imhigh = Image.fromarray(dnphigh,mode='RGB')
-
-    background = Image.new('RGB',(image.size[0]+imlow.size[0]+imhigh.size[0]+2,image.size[1]),(0,0,0))
-    background.paste(imlow,(0,0))
-    background.paste(image,(imlow.size[0]+1,0))
-    background.paste(imhigh,(imlow.size[0]+1+image.size[0]+1,0))
-
-    return background
-
-def crop_black(img,tol=0):
-    mask = (img>tol).all(2)
-    mask0,mask1 = mask.any(0),mask.any(1)
-    col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax()
-    row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax()
-    return img[row_start:row_end,col_start:col_end]
-
-def extractImageDataEmbed(image):
-    d=3
-    outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F
-    blackCols = np.where( np.sum(outarr, axis=(0,2))==0)
-    if blackCols[0].shape[0] < 2:
-        print('No Image data blocks found.')
-        return None
-
-    dataBlocklower = outarr[:,:blackCols[0].min(),:].astype(np.uint8)
-    dataBlockupper = outarr[:,blackCols[0].max()+1:,:].astype(np.uint8)
-
-    dataBlocklower = xorBlock(dataBlocklower)
-    dataBlockupper = xorBlock(dataBlockupper)
-
-    dataBlock = (dataBlockupper << 4) | (dataBlocklower)
-    dataBlock = dataBlock.flatten().tobytes()
-    data = zlib.decompress(dataBlock)
-    return json.loads(data,cls=EmbeddingDecoder)
-
 class Embedding:
     def __init__(self, vec, name, step=None):
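The block removed above is an invertible obfuscation pipeline: xorBlock XORs a block against a low-nibble mask drawn from the fixed-seed lcg generator (so applying it twice is the identity), and the byte-to-nibble split in insertImageDataEmbed is undone by the shift-and-or merge in extractImageDataEmbed. A minimal sketch of both round trips, assuming the pre-move definitions above are in scope (the sample values are made up):

import numpy as np

# A sample block in the low-nibble range the helpers operate on.
block = (np.arange(24, dtype=np.uint8) & 0x0F).reshape(2, 4, 3)

# xorBlock is self-inverse: lcg() always starts from seed=0, so the
# mask is reproducible and XOR-ing twice restores the input.
assert np.array_equal(xorBlock(xorBlock(block)), block)

# The nibble split mirrors the (upper << 4) | lower merge on extraction.
payload = np.frombuffer(b'embedded payload', dtype=np.uint8)
high, low = payload >> 4, payload & 0x0F
assert ((high << 4) | low).tobytes() == b'embedded payload'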
@@ -199,10 +86,10 @@ class EmbeddingDatabase:
         if filename.upper().endswith('.PNG'):
             embed_image = Image.open(path)
             if 'sd-ti-embedding' in embed_image.text:
-                data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
+                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                 name = data.get('name',name)
             else:
-                data = extractImageDataEmbed(embed_image)
+                data = extract_image_data_embed(embed_image)
                 name = data.get('name',name)
         else:
             data = torch.load(path, map_location="cpu")
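After the rename, the load path keeps its two branches: an embedding stored in a PNG either sits base64-encoded in the sd-ti-embedding tEXt chunk, or is hidden in the pixel data and recovered steganographically. A standalone sketch of that branch, assuming the moved helpers are importable as modules.textual_inversion.image_embedding (the new file itself is not shown in this diff):

from PIL import Image
# Assumed import path for the relocated helpers; not shown in this diff.
from modules.textual_inversion.image_embedding import (
    embedding_from_b64, extract_image_data_embed)

def load_png_embedding(path):
    """Recover embedding data from a PNG: try the tEXt chunk, then pixels."""
    embed_image = Image.open(path)
    if 'sd-ti-embedding' in embed_image.text:
        return embedding_from_b64(embed_image.text['sd-ti-embedding'])
    return extract_image_data_embed(embed_image)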
@@ -393,7 +280,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
             info = PngImagePlugin.PngInfo()
             data = torch.load(last_saved_file)
-            info.add_text("sd-ti-embedding", embeddingToB64(data))
+            info.add_text("sd-ti-embedding", embedding_to_b64(data))
 
             title = "<{}>".format(data.get('name','???'))
 
             checkpoint = sd_models.select_checkpoint()
@@ -401,8 +288,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
             footer_mid = '[{}]'.format(checkpoint.hash)
             footer_right = '{}'.format(embedding.step)
 
-            captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
-            captioned_image = insertImageDataEmbed(captioned_image,data)
+            captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right)
+            captioned_image = insert_image_data_embed(captioned_image,data)
 
             captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
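The save path composes the same helpers in the opposite direction: the embedding is written both into a tEXt chunk (embedding_to_b64) and into the pixels of the captioned preview (insert_image_data_embed), so either branch of the loader above can recover it. A sketch under the same module-path assumption; the file names and caption strings are placeholders:

import torch
from PIL import Image, PngImagePlugin
# Assumed import path, as above.
from modules.textual_inversion.image_embedding import (
    embedding_to_b64, insert_image_data_embed, caption_image_overlay)

data = torch.load('embedding.pt', map_location='cpu')  # placeholder path
image = Image.open('preview.png')                      # placeholder path

# Machine-readable copy in the PNG metadata...
info = PngImagePlugin.PngInfo()
info.add_text("sd-ti-embedding", embedding_to_b64(data))

# ...and a pixel-level copy that survives metadata stripping.
captioned_image = caption_image_overlay(image, '<name>', 'left', '[hash]', '0')
captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save('embedding-embedded.png', "PNG", pnginfo=info)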