Commit 45caeca0 authored by kurumuz's avatar kurumuz

modules, maybe works

parent ad4eaba6
...@@ -14,6 +14,7 @@ from hydra_node.models import StableDiffusionModel, DalleMiniModel ...@@ -14,6 +14,7 @@ from hydra_node.models import StableDiffusionModel, DalleMiniModel
import traceback import traceback
import zlib import zlib
from pathlib import Path from pathlib import Path
from ldm.modules.attention import CrossAttention, HyperLogic
model_map = {"stable-diffusion": StableDiffusionModel, "dalle-mini": DalleMiniModel} model_map = {"stable-diffusion": StableDiffusionModel, "dalle-mini": DalleMiniModel}
...@@ -41,6 +42,38 @@ def crc32(filename, chunksize=65536): ...@@ -41,6 +42,38 @@ def crc32(filename, chunksize=65536):
checksum = zlib.crc32(chunk, checksum) checksum = zlib.crc32(chunk, checksum)
return '%08X' % (checksum & 0xFFFFFFFF) return '%08X' % (checksum & 0xFFFFFFFF)
def load_modules(path):
path = Path(path)
modules = {}
if not path.is_dir():
return
for file in path.iterdir():
module = load_module(file, "cuda")
modules[file.stem] = module
print(f"Loaded module {file.stem}")
return modules
def load_module(path, device):
path = Path(path)
if not path.is_file():
print("Module path {} is not a file".format(path))
network = {
768: (HyperLogic(768).to(device), HyperLogic(768).to(device)),
1280: (HyperLogic(1280).to(device), HyperLogic(1280).to(device)),
640: (HyperLogic(640).to(device), HyperLogic(640).to(device)),
320: (HyperLogic(320).to(device), HyperLogic(320).to(device)),
}
state_dict = torch.load(path)
for key in state_dict.keys():
network[key][0].load_state_dict(state_dict[key][0])
network[key][1].load_state_dict(state_dict[key][1])
return network
def init_config_model(): def init_config_model():
config = DotMap() config = DotMap()
config.dtype = os.getenv("DTYPE", "float16") config.dtype = os.getenv("DTYPE", "float16")
...@@ -131,6 +164,11 @@ def init_config_model(): ...@@ -131,6 +164,11 @@ def init_config_model():
model_path = folder / "model.ckpt" model_path = folder / "model.ckpt"
model_hash = crc32(model_path) model_hash = crc32(model_path)
#Load Modules
modules = load_modules(folder)
#attach it to the model
model.premodules = modules
config.model = model config.model = model
# Mark that our model is loaded. # Mark that our model is loaded.
......
...@@ -12,6 +12,7 @@ from torchvision.utils import make_grid ...@@ -12,6 +12,7 @@ from torchvision.utils import make_grid
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.attention import CrossAttention, HyperLogic
import time import time
from PIL import Image from PIL import Image
import k_diffusion as K import k_diffusion as K
...@@ -154,6 +155,7 @@ class StableDiffusionModel(nn.Module): ...@@ -154,6 +155,7 @@ class StableDiffusionModel(nn.Module):
def __init__(self, config): def __init__(self, config):
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = config self.config = config
self.premodules = None
model, model_config = self.from_folder(config.model_path) model, model_config = self.from_folder(config.model_path)
if config.dtype == "float16": if config.dtype == "float16":
typex = torch.float16 typex = torch.float16
...@@ -210,6 +212,10 @@ class StableDiffusionModel(nn.Module): ...@@ -210,6 +212,10 @@ class StableDiffusionModel(nn.Module):
@torch.no_grad() @torch.no_grad()
@torch.autocast("cuda", enabled=True, dtype=torch.float16) @torch.autocast("cuda", enabled=True, dtype=torch.float16)
def sample(self, request): def sample(self, request):
if request.module is not None:
module = self.premodules[request.module]
CrossAttention.set_hypernetwork(module)
if request.seed is not None: if request.seed is not None:
torch.manual_seed(request.seed) torch.manual_seed(request.seed)
np.random.seed(request.seed) np.random.seed(request.seed)
......
...@@ -19,6 +19,7 @@ v1pp_defaults = { ...@@ -19,6 +19,7 @@ v1pp_defaults = {
'dynamic_threshold': None, 'dynamic_threshold': None,
'seed': None, 'seed': None,
'stage_two_seed': None, 'stage_two_seed': None,
'module': None,
} }
v1pp_forced_defaults = { v1pp_forced_defaults = {
...@@ -58,7 +59,7 @@ def closest_multiple(num, mult): ...@@ -58,7 +59,7 @@ def closest_multiple(num, mult):
ceil = math.ceil(num_int / mult) * mult ceil = math.ceil(num_int / mult) * mult
return floor if (num_int - floor) < (ceil - num_int) else ceil return floor if (num_int - floor) < (ceil - num_int) else ceil
def sanitize_stable_diffusion(request): def sanitize_stable_diffusion(request, config):
if request.width * request.height == 0: if request.width * request.height == 0:
return False, "width and height must be non-zero" return False, "width and height must be non-zero"
...@@ -101,6 +102,10 @@ def sanitize_stable_diffusion(request): ...@@ -101,6 +102,10 @@ def sanitize_stable_diffusion(request):
request.seed = random.randint(0, 2**32) request.seed = random.randint(0, 2**32)
random.setstate(state) random.setstate(state)
if request.module is not None:
if request.module not in config.model.premodules:
return False, "module should be one of {}".format(config.model.premodules.keys())
if request.image is not None: if request.image is not None:
#decode from base64 #decode from base64
try: try:
...@@ -148,7 +153,7 @@ def sanitize_input(config, request): ...@@ -148,7 +153,7 @@ def sanitize_input(config, request):
request[k] = v request[k] = v
if config.model_name == 'stable-diffusion': if config.model_name == 'stable-diffusion':
return sanitize_stable_diffusion(request) return sanitize_stable_diffusion(request, config)
elif config.model_name == 'dalle-mini': elif config.model_name == 'dalle-mini':
return sanitize_dalle_mini(request) return sanitize_dalle_mini(request)
\ No newline at end of file
...@@ -77,6 +77,7 @@ class GenerationRequest(BaseModel): ...@@ -77,6 +77,7 @@ class GenerationRequest(BaseModel):
strength: float = 0.69 strength: float = 0.69
noise: float = 0.667 noise: float = 0.667
mitigate: bool = False mitigate: bool = False
module: str = None
class GenerationOutput(BaseModel): class GenerationOutput(BaseModel):
output: List[str] output: List[str]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment