Commit b443fdcf authored by AUTOMATIC1111

prevent accidental creation of CLIP models in float32 when the user wants float16

parent 7ee2114c
@@ -61,9 +61,9 @@ class SD3Cond(torch.nn.Module):
         self.tokenizer = SD3Tokenizer()
 
         with torch.no_grad():
-            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
-            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
-            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
+            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
+            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
+            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
 
         self.weights_loaded = False
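The fix matters because the old code materialized every text-encoder weight in float32 and only cast it down afterwards, roughly doubling peak CPU memory while the SD3 conditioner was being built. Below is a minimal sketch of the before/after pattern, using torch.nn.Linear as a stand-in for the CLIP/T5 modules; the `devices` class is an assumption for illustration (webui's modules/devices.py picks the dtype from the user's half-precision settings), not the real module.

import torch

# Stand-in for modules.devices (assumption for illustration): webui sets
# devices.dtype to torch.float16 unless the user disables half precision.
class devices:
    dtype = torch.float16

# Old pattern: weights are allocated in float32 first, so peak memory is
# twice the final size until the cast to float16 completes.
clip_stand_in = torch.nn.Linear(4096, 4096, device="cpu", dtype=torch.float32)
clip_stand_in = clip_stand_in.to(devices.dtype)

# New pattern: allocate the parameters in the target dtype from the start,
# so no float32 copy ever exists.
clip_stand_in = torch.nn.Linear(4096, 4096, device="cpu", dtype=devices.dtype)

With the actual SD3 text encoders (CLIP-G, CLIP-L, and especially T5-XXL), the transient float32 allocations avoided this way are on the order of gigabytes.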
@@ -406,6 +406,7 @@ def set_model_fields(model):
     if not hasattr(model, 'latent_channels'):
         model.latent_channels = 4
 
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
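For context, the `timer` used in load_model_weights follows a record-by-category pattern: each call to record() logs the time elapsed since the previous checkpoint under a name such as "calculate hash". A minimal sketch under that assumption (the Timer class below is illustrative, not webui's exact modules/timer.py implementation):

import time

# Illustrative Timer matching the usage above: record() stores the time
# elapsed since the previous checkpoint under a named category.
class Timer:
    def __init__(self):
        self.start = time.time()
        self.records = {}

    def record(self, category):
        now = time.time()
        self.records[category] = self.records.get(category, 0) + now - self.start
        self.start = now

timer = Timer()
sd_model_hash = "abc1234567"  # placeholder for checkpoint_info.calculate_shorthash()
timer.record("calculate hash")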