Commit 0087079c authored by DepFA's avatar DepFA Committed by GitHub

allow overwrite old embedding

parent 166be391
...@@ -153,7 +153,7 @@ class EmbeddingDatabase: ...@@ -153,7 +153,7 @@ class EmbeddingDatabase:
return None, None return None, None
def create_embedding(name, num_vectors_per_token, init_text='*'): def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
...@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): ...@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists" if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name) embedding = Embedding(vec, name)
embedding.step = 0 embedding.step = 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment