Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions extensions/openai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,21 @@ def load_embedding_model(model: str) -> SentenceTransformer:
initialize_embedding_params()
global embeddings_device, embeddings_model
try:
embeddings_model = 'loading...' # flag
print(f"\Try embedding model: {model} on {embeddings_device}")
# see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
emb_model = SentenceTransformer(model, device=embeddings_device)
# ... emb_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
print(f"\nLoaded embedding model: {model} on {emb_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {emb_model.max_seq_length}")
embeddings_model = SentenceTransformer(model, device=embeddings_device)
# ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
print(f"\nLoaded embedding model: {model} on {embeddings_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {embeddings_model.max_seq_length}")
except Exception as e:
embeddings_model = None
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))

return emb_model


def get_embeddings_model() -> SentenceTransformer:
initialize_embedding_params()
global embeddings_model, st_model
if st_model and not embeddings_model:
embeddings_model = load_embedding_model(st_model) # lazy load the model
load_embedding_model(st_model) # lazy load the model
return embeddings_model


Expand All @@ -53,7 +51,11 @@ def get_embeddings_model_name() -> str:


def get_embeddings(input: list) -> np.ndarray:
return get_embeddings_model().encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False, device=embeddings_device)
model = get_embeddings_model()
debug_msg(f"embedding model : {model}")
embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
debug_msg(f"embedding result : {embedding}") # might be too long even for debug, use at you own will
return embedding


def embeddings(input: list, encoding_format: str) -> dict:
Expand Down