Commit f30be23f authored by novelailab's avatar novelailab

fix no init

parent 02b5dfc0
...@@ -12,6 +12,22 @@ from sentry_sdk import capture_exception ...@@ -12,6 +12,22 @@ from sentry_sdk import capture_exception
from sentry_sdk.integrations.threading import ThreadingIntegration from sentry_sdk.integrations.threading import ThreadingIntegration
from hydra_node.models import StableDiffusionModel from hydra_node.models import StableDiffusionModel
def no_init(loading_code):
def dummy(self):
return
modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
original = {}
for mod in modules:
original[mod] = mod.reset_parameters
mod.reset_parameters = dummy
result = loading_code()
for mod in modules:
mod.reset_parameters = original[mod]
return result
def init_config_model(): def init_config_model():
config = DotMap() config = DotMap()
config.dtype = os.getenv("DTYPE", "float16") config.dtype = os.getenv("DTYPE", "float16")
...@@ -81,7 +97,7 @@ def init_config_model(): ...@@ -81,7 +97,7 @@ def init_config_model():
load_time = time.time() load_time = time.time()
try: try:
model = StableDiffusionModel(config) model = no_init(lambda: StableDiffusionModel(config))
except Exception as e: except Exception as e:
logger.error(f"Failed to load model: {str(e)}") logger.error(f"Failed to load model: {str(e)}")
capture_exception(e) capture_exception(e)
...@@ -97,7 +113,7 @@ def init_config_model(): ...@@ -97,7 +113,7 @@ def init_config_model():
f = open("/tmp/health_readiness", "w") f = open("/tmp/health_readiness", "w")
f.close() f.close()
time_load = time.time() - load_time time_load = time.time() - load_time
logger.info(f"Models loaded in {time_load:.2f}s") logger.info(f"Models loaded in {time_load:.2f}s")
......
...@@ -14,27 +14,11 @@ from ldm.models.diffusion.ddim import DDIMSampler ...@@ -14,27 +14,11 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
import time import time
def no_init(loading_code):
def dummy(self):
return
modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
original = {}
for mod in modules:
original[mod] = mod.reset_parameters
mod.reset_parameters = dummy
result = loading_code()
for mod in modules:
mod.reset_parameters = original[mod]
return result
class StableDiffusionModel(nn.Module): class StableDiffusionModel(nn.Module):
def __init__(self, config): def __init__(self, config):
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = config self.config = config
model, model_config = no_init(lambda: self.from_folder(config.model_path)) model, model_config = self.from_folder(config.model_path)
if config.dtype == "float16": if config.dtype == "float16":
typex = torch.float16 typex = torch.float16
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment