Commit cab1d839 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #14563 from Nuullll/model-loaded-callback

Execute model_loaded_callback after moving to target device
parents 71e00571 a183de04
...@@ -842,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): ...@@ -842,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack") timer.record("hijack")
script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
if not sd_model.lowvram: if not sd_model.lowvram:
sd_model.to(devices.device) sd_model.to(devices.device)
timer.record("move model to device") timer.record("move model to device")
script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
print(f"Weights loaded in {timer.summary()}.") print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model) model_data.set_sd_model(sd_model)
......
...@@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): ...@@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
load_vae(sd_model, vae_file, vae_source) load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
if not sd_model.lowvram: if not sd_model.lowvram:
sd_model.to(devices.device) sd_model.to(devices.device)
script_callbacks.model_loaded_callback(sd_model)
print("VAE weights loaded.") print("VAE weights loaded.")
return sd_model return sd_model
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment