Commit 1df6c8bf authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf

fp8 for TE

parent 9c1eba2a
......@@ -407,6 +407,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
module.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet for cpu")
else:
if model.is_sdxl:
cond_stage = model.conditioner
else:
cond_stage = model.cond_stage_model
for module in cond_stage.modules():
if isinstance(module, torch.nn.Linear):
module.to(torch.float8_e4m3fn)
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment