Skip to content

Commit

Permalink
Fixup OpenAI Embeddings - fix the weighted mean
Browse files Browse the repository at this point in the history
  • Loading branch information
ravwojdyla committed Apr 29, 2023
1 parent 399065e commit 2a20444
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions langchain/embeddings/openai.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -150,7 +150,7 @@ def validate_environment(cls, values: Dict) -> Dict:
def _get_len_safe_embeddings(
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
) -> List[List[float]]:
- embeddings: List[List[float]] = [[] for i in range(len(texts))]
+ embeddings: List[List[float]] = [[] for _ in range(len(texts))]
try:
import tiktoken

Expand Down Expand Up @@ -180,10 +180,10 @@ def _get_len_safe_embeddings(
batched_embeddings += [r["embedding"] for r in response["data"]]

results: List[List[List[float]]] = [[] for _ in range(len(texts))]
- lens: List[List[int]] = [[] for _ in range(len(texts))]
+ num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
- lens[indices[i]].append(len(batched_embeddings[i]))
+ num_tokens_in_batch[indices[i]].append(len(tokens[i]))

for i in range(len(texts)):
_result = results[i]
Expand All @@ -192,7 +192,9 @@ def _get_len_safe_embeddings(
"data"
][0]["embedding"]
else:
- average = np.average(_result, axis=0, weights=lens[i])
+ average = np.average(
+ _result, axis=0, weights=num_tokens_in_batch[i]
+ )
embeddings[i] = (average / np.linalg.norm(average)).tolist()

return embeddings
Expand Down

0 comments on commit 2a20444

Please sign in to comment.