Commit 50a21cb0 authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf

Ensure the cached weight will not be affected

parent 110485d5
......@@ -435,9 +435,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
for module in model.modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
if shared.opts.cache_fp16_weight:
module.fp16_weight = module.weight.clone().half()
module.fp16_weight = module.weight.data.clone().cpu().half()
if module.bias is not None:
module.fp16_bias = module.bias.clone().half()
module.fp16_bias = module.bias.data.clone().cpu().half()
module.to(torch.float8_e4m3fn)
model.first_stage_model = first_stage
timer.record("apply fp8")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment