Commit 542a3d3a authored by AUTOMATIC's avatar AUTOMATIC

fix btoken hypernetworks in XY plot

parent 77a71964
...@@ -49,15 +49,18 @@ def list_hypernetworks(path): ...@@ -49,15 +49,18 @@ def list_hypernetworks(path):
def load_hypernetwork(filename): def load_hypernetwork(filename):
print(f"Loading hypernetwork {filename}")
path = shared.hypernetworks.get(filename, None) path = shared.hypernetworks.get(filename, None)
if (path is not None): if path is not None:
print(f"Loading hypernetwork {filename}")
try: try:
shared.loaded_hypernetwork = Hypernetwork(path) shared.loaded_hypernetwork = Hypernetwork(path)
except Exception: except Exception:
print(f"Error loading hypernetwork {path}", file=sys.stderr) print(f"Error loading hypernetwork {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
else: else:
if shared.loaded_hypernetwork is not None:
print(f"Unloading hypernetwork")
shared.loaded_hypernetwork = None shared.loaded_hypernetwork = None
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import images from modules import images, hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler from modules.processing import process_images, Processed, get_correct_sampler
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
...@@ -80,8 +80,7 @@ def apply_checkpoint(p, x, xs): ...@@ -80,8 +80,7 @@ def apply_checkpoint(p, x, xs):
def apply_hypernetwork(p, x, xs): def apply_hypernetwork(p, x, xs):
hn = shared.hypernetworks.get(x, None) hypernetwork.load_hypernetwork(x)
opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None'
def format_value_add_label(p, opt, x): def format_value_add_label(p, opt, x):
...@@ -203,8 +202,6 @@ class Script(scripts.Script): ...@@ -203,8 +202,6 @@ class Script(scripts.Script):
p.batch_size = 1 p.batch_size = 1
initial_hn = opts.sd_hypernetwork
def process_axis(opt, vals): def process_axis(opt, vals):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
...@@ -321,6 +318,6 @@ class Script(scripts.Script): ...@@ -321,6 +318,6 @@ class Script(scripts.Script):
# restore checkpoint in case it was changed by axes # restore checkpoint in case it was changed by axes
modules.sd_models.reload_model_weights(shared.sd_model) modules.sd_models.reload_model_weights(shared.sd_model)
opts.data["sd_hypernetwork"] = initial_hn hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
return processed return processed
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment