Commit 27b9ec60 authored by Brad Smith's avatar Brad Smith

sort embeddings by name (case insensitive)

parent 22bcc7be
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
import sys import sys
import traceback import traceback
import inspect import inspect
from collections import namedtuple from collections import namedtuple, OrderedDict
import torch import torch
import tqdm import tqdm
...@@ -108,7 +108,7 @@ class DirWithTextualInversionEmbeddings: ...@@ -108,7 +108,7 @@ class DirWithTextualInversionEmbeddings:
class EmbeddingDatabase: class EmbeddingDatabase:
def __init__(self): def __init__(self):
self.ids_lookup = {} self.ids_lookup = {}
self.word_embeddings = {} self.word_embeddings = OrderedDict()
self.skipped_embeddings = {} self.skipped_embeddings = {}
self.expected_shape = -1 self.expected_shape = -1
self.embedding_dirs = {} self.embedding_dirs = {}
...@@ -233,6 +233,9 @@ class EmbeddingDatabase: ...@@ -233,6 +233,9 @@ class EmbeddingDatabase:
self.load_from_dir(embdir) self.load_from_dir(embdir)
embdir.update() embdir.update()
# re-sort word_embeddings because load_from_dir may not load in alphabetic order.
self.word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
if self.previously_displayed_embeddings != displayed_embeddings: if self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings self.previously_displayed_embeddings = displayed_embeddings
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment