Commit b2c9a5ee authored by novelailab

Fuse model correctly
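This change reworks how the JIT-fused model interacts with EMA weights: instead of wrapping every sampling call in the ema_scope context manager, sample() and sample_two_stages() now copy the EMA weights into the live parameters once and track that with a copied_ema flag. A minimal sketch of that pattern follows; it assumes the store/copy_to/restore methods of ldm.modules.ema.LitEma and is illustrative only, not the exact repository code.

    # Sketch only (assumes a LatentDiffusion-style object `sd` with .model, .model_ema and a copied_ema flag)
    def ensure_ema_weights(sd, use_ema):
        # Copy EMA weights into the live UNet exactly once; restore the
        # original weights when EMA is switched off again.
        if use_ema and not sd.copied_ema:
            sd.model.model_ema.store(sd.model.model.parameters())    # stash current weights
            sd.model.model_ema.copy_to(sd.model.model)                # load EMA weights in place
            sd.copied_ema = True
        if not use_ema and sd.copied_ema:
            sd.model.model_ema.restore(sd.model.model.parameters())  # put original weights back
            sd.copied_ema = False

Calling single_step(ema) before torch.jit.trace_module in fuse_model presumably ensures the EMA copy has already happened when the model is traced, so the fused graph is built from the intended weights.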

parent 460d3b42
@@ -193,7 +193,7 @@ def init_config_model():
     # enable JIT
     if config.jit_optimize == "1":
-        model.fuse_model(ema=config.enable_ema == "1")
+        model.fuse_model()
     config.model = model
@@ -233,9 +233,9 @@ class StableDiffusionModel(nn.Module):
         self.model_config = model_config
         self.plms = PLMSSampler(model)
         self.ddim = DDIMSampler(model)
-        self.ema_manager = self.model.ema_scope
+        self.ema = True
         if self.config.enable_ema == "0":
-            self.ema_manager = contextlib.nullcontext
+            self.ema = False
             config.logger.info("Disabling EMA")
         else:
             config.logger.info(f"Using EMA")
@@ -251,8 +251,10 @@ class StableDiffusionModel(nn.Module):
         }
         if config.prior_path:
             self.prior = VectorAdjustPrior.load_model(config.prior_path).to(self.device)
+        self.copied_ema = True
-    def fuse_model(self):
+    def fuse_model(self, requires_grad=False):
+        ema = self.ema
         for param in self.model.model.parameters():
             param.requires_grad = False
@@ -264,15 +266,29 @@ class StableDiffusionModel(nn.Module):
         test_sigma = sigmas[1] * x_0.new_ones([x_0.shape[0]])
         with torch.autocast("cuda", torch.float16):
+            self.single_step(ema)
             x_two = torch.cat([x_0] * 2)
             cnd = torch.cat([uc, c])
             sigma_two = torch.cat([test_sigma] * 2)
             inputs = {'apply_model': (x_two, sigma_two, cnd)}
             traced_model = torch.jit.trace_module(self.model, inputs)
+        if requires_grad:
+            self.model.apply_model = lambda x, t, c : traced_model.apply_model(x, t, c)
         #traced_model = traced_model.half()
         self.k_model = K.external.CompVisDenoiser(traced_model)
-        self.k_model = K.external.StableInterface(self.k_model)
+        self.k_model = StableInterface(self.k_model)
+        self.single_step(ema)
+        for param in self.model.model.parameters():
+            param.requires_grad = requires_grad
+    def single_step(self, ema):
+        config = self.get_default_config
+        config.steps = 1
+        config.prompt = ""
+        self.sample(config, ema=ema)
     def from_folder(self, folder):
         folder = Path(folder)
@@ -317,6 +333,15 @@ class StableDiffusionModel(nn.Module):
     @torch.no_grad()
     @torch.autocast("cuda", enabled=True, dtype=torch.float16)
     def sample(self, request):
+        ema_manager = contextlib.nullcontext
+        if self.ema and not self.copied_ema:
+            self.model.model_ema.store(self.model.model.parameters())
+            self.model.model_ema.copy_to(self.model.model)
+            self.copied_ema = True
+        if not self.ema and self.copied_ema:
+            self.model.model_ema.restore(self.model.model.parameters())
+            self.copied_ema = False
         if request.module is not None:
             if request.module == "vanilla":
                 pass
@@ -466,6 +491,15 @@ class StableDiffusionModel(nn.Module):
     @torch.no_grad()
     def sample_two_stages(self, request):
+        ema_manager = contextlib.nullcontext
+        if self.ema and not self.copied_ema:
+            self.model.model_ema.store(self.model.model.parameters())
+            self.model.model_ema.copy_to(self.model.model)
+            self.copied_ema = True
+        if not self.ema and self.copied_ema:
+            self.model.model_ema.restore(self.model.model.parameters())
+            self.copied_ema = False
         request = DotMap(request)
         if request.seed is not None:
             seed_everything(request.seed)
@@ -498,7 +532,7 @@ class StableDiffusionModel(nn.Module):
             request.width // request.downsampling_factor
         ]
         with torch.autocast("cuda", enabled=self.config.amp):
-            with self.model.ema_scope():
+            with ema_manager():
                 samples, _ = sampler.sample(
                     S=request.steps,
                     conditioning=prompt_condition,
@@ -521,7 +555,7 @@ class StableDiffusionModel(nn.Module):
         np.random.seed(request.stage_two_seed)
         with torch.autocast("cuda", enabled=self.config.amp):
-            with self.model.ema_scope():
+            with ema_manager():
                 init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim))
                 self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
                 t_enc = int(request.strength * request.steps)