Commit 45caeca0 authored by kurumuz's avatar kurumuz

modules, maybe works

parent ad4eaba6
......@@ -14,6 +14,7 @@ from hydra_node.models import StableDiffusionModel, DalleMiniModel
import traceback
import zlib
from pathlib import Path
from ldm.modules.attention import CrossAttention, HyperLogic
model_map = {"stable-diffusion": StableDiffusionModel, "dalle-mini": DalleMiniModel}
......@@ -41,6 +42,38 @@ def crc32(filename, chunksize=65536):
checksum = zlib.crc32(chunk, checksum)
return '%08X' % (checksum & 0xFFFFFFFF)
def load_modules(path):
    """Load every hypernetwork module checkpoint found in *path*.

    Parameters:
        path: directory containing serialized module checkpoints
            (one file per module; the file stem becomes the module name).

    Returns:
        dict mapping module name -> network returned by load_module().
        Returns an empty dict when *path* is not a directory (the original
        returned None here, which broke downstream membership tests such as
        ``request.module not in config.model.premodules``).
    """
    path = Path(path)
    modules = {}
    if not path.is_dir():
        return modules
    for file in path.iterdir():
        # Skip subdirectories and other non-files — torch.load would
        # crash on them. NOTE(review): the caller passes the model folder,
        # which may also contain model.ckpt; confirm whether that file
        # should be excluded from module loading as well.
        if not file.is_file():
            continue
        module = load_module(file, "cuda")
        modules[file.stem] = module
        print(f"Loaded module {file.stem}")
    return modules
def load_module(path, device):
    """Load a single hypernetwork checkpoint onto *device*.

    The checkpoint is expected to be a torch-serialized dict mapping
    attention widths (320, 640, 768, 1280) to pairs of HyperLogic
    state dicts.

    Parameters:
        path: checkpoint file to load.
        device: device string/object the HyperLogic pairs are moved to.

    Returns:
        dict mapping width -> (HyperLogic, HyperLogic) tuple, or None
        when *path* is not a regular file.
    """
    path = Path(path)
    if not path.is_file():
        print("Module path {} is not a file".format(path))
        # Bail out early — the original fell through to torch.load(path)
        # here and crashed on directories / missing files.
        return None

    network = {
        768: (HyperLogic(768).to(device), HyperLogic(768).to(device)),
        1280: (HyperLogic(1280).to(device), HyperLogic(1280).to(device)),
        640: (HyperLogic(640).to(device), HyperLogic(640).to(device)),
        320: (HyperLogic(320).to(device), HyperLogic(320).to(device)),
    }

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # module files from trusted sources.
    state_dict = torch.load(path)
    for key in state_dict:
        network[key][0].load_state_dict(state_dict[key][0])
        network[key][1].load_state_dict(state_dict[key][1])
    return network
def init_config_model():
config = DotMap()
config.dtype = os.getenv("DTYPE", "float16")
......@@ -131,6 +164,11 @@ def init_config_model():
model_path = folder / "model.ckpt"
model_hash = crc32(model_path)
#Load Modules
modules = load_modules(folder)
#attach it to the model
model.premodules = modules
config.model = model
# Mark that our model is loaded.
......
......@@ -12,6 +12,7 @@ from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.attention import CrossAttention, HyperLogic
import time
from PIL import Image
import k_diffusion as K
......@@ -154,6 +155,7 @@ class StableDiffusionModel(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
self.config = config
self.premodules = None
model, model_config = self.from_folder(config.model_path)
if config.dtype == "float16":
typex = torch.float16
......@@ -210,6 +212,10 @@ class StableDiffusionModel(nn.Module):
@torch.no_grad()
@torch.autocast("cuda", enabled=True, dtype=torch.float16)
def sample(self, request):
if request.module is not None:
module = self.premodules[request.module]
CrossAttention.set_hypernetwork(module)
if request.seed is not None:
torch.manual_seed(request.seed)
np.random.seed(request.seed)
......
......@@ -19,6 +19,7 @@ v1pp_defaults = {
'dynamic_threshold': None,
'seed': None,
'stage_two_seed': None,
'module': None,
}
v1pp_forced_defaults = {
......@@ -58,7 +59,7 @@ def closest_multiple(num, mult):
ceil = math.ceil(num_int / mult) * mult
return floor if (num_int - floor) < (ceil - num_int) else ceil
def sanitize_stable_diffusion(request):
def sanitize_stable_diffusion(request, config):
if request.width * request.height == 0:
return False, "width and height must be non-zero"
......@@ -101,6 +102,10 @@ def sanitize_stable_diffusion(request):
request.seed = random.randint(0, 2**32)
random.setstate(state)
if request.module is not None:
if request.module not in config.model.premodules:
return False, "module should be one of {}".format(config.model.premodules.keys())
if request.image is not None:
#decode from base64
try:
......@@ -148,7 +153,7 @@ def sanitize_input(config, request):
request[k] = v
if config.model_name == 'stable-diffusion':
return sanitize_stable_diffusion(request)
return sanitize_stable_diffusion(request, config)
elif config.model_name == 'dalle-mini':
return sanitize_dalle_mini(request)
\ No newline at end of file
......@@ -77,6 +77,7 @@ class GenerationRequest(BaseModel):
strength: float = 0.69
noise: float = 0.667
mitigate: bool = False
module: str = None
class GenerationOutput(BaseModel):
output: List[str]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment