Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixup OpenAI Embeddings - fix the weighted mean #3778

Merged
merged 1 commit into from
May 1, 2023

Conversation

ravwojdyla
Copy link
Contributor

@ravwojdyla ravwojdyla commented Apr 29, 2023

Re: #3777

Copy pasting from the issue:

While working on #3722 I have noticed that there might be a bug in the current implementation of the OpenAI length safe embeddings in _get_len_safe_embeddings, which before #3722 was actually the default implementation regardless of the length of the context (via #2330).

It appears the weights used are constant and the length of the embedding vector (1536) and NOT the number of tokens in the batch, as in the reference implementation at https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb


Here's some debug info:

image


We can also validate this against the reference implementation:

Reference implementation (click to unroll)

This implementation is copy pasted from https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb

import openai
from itertools import islice
import numpy as np
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_not_exception_type


EMBEDDING_MODEL = 'text-embedding-ada-002'
EMBEDDING_CTX_LENGTH = 8191
EMBEDDING_ENCODING = 'cl100k_base'

# let's make sure to not retry on an invalid request, because that is what we want to demonstrate
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(openai.InvalidRequestError))
def get_embedding(text_or_tokens, model=EMBEDDING_MODEL):
    return openai.Embedding.create(input=text_or_tokens, model=model)["data"][0]["embedding"]

def batched(iterable, n):
    """Batch data into tuples of length n. The last batch may be shorter."""
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while (batch := tuple(islice(it, n))):
        yield batch
        
def chunked_tokens(text, encoding_name, chunk_length):
    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)
    chunks_iterator = batched(tokens, chunk_length)
    yield from chunks_iterator


def reference_safe_get_embedding(text, model=EMBEDDING_MODEL, max_tokens=EMBEDDING_CTX_LENGTH, encoding_name=EMBEDDING_ENCODING, average=True):
    chunk_embeddings = []
    chunk_lens = []
    for chunk in chunked_tokens(text, encoding_name=encoding_name, chunk_length=max_tokens):
        chunk_embeddings.append(get_embedding(chunk, model=model))
        chunk_lens.append(len(chunk))

    if average:
        chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
        chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)  # normalizes length to 1
        chunk_embeddings = chunk_embeddings.tolist()
    return chunk_embeddings
long_text = 'foo bar' * 5000

reference_safe_get_embedding(long_text, average=True)[:10]

# Here's the first 10 floats from the reference embeddings:
[0.004407593824276758,
 0.0017611146161865465,
 -0.019824815970984996,
 -0.02177626039794025,
 -0.012060967454897886,
 0.0017955296329155309,
 -0.015609168983609643,
 -0.012059823076681351,
 -0.016990468527792825,
 -0.004970484452089445]


# and now langchain implementation
from langchain.embeddings.openai import OpenAIEmbeddings
OpenAIEmbeddings().embed_query(long_text)[:10]

[0.003791506184693747,
 0.0025310066579390025,
 -0.019282322699514628,
 -0.021492679249899803,
 -0.012598522213242891,
 0.0022181168611315662,
 -0.015858940621301307,
 -0.011754004130791204,
 -0.016402944319627515,
 -0.004125287485127554]

# clearly they are different ^

@ravwojdyla
Copy link
Contributor Author

ravwojdyla commented Apr 29, 2023

The original implementation of the length safe embeddings comes from #991. Would love your input @Hase-U in case there's something I've misunderstood here.

@ravwojdyla ravwojdyla force-pushed the rav-fixup-embeddings branch from 00bc52b to 2a20444 Compare April 29, 2023 06:48
@ravwojdyla
Copy link
Contributor Author

Revert the style changes ^

@Hase-U
Copy link
Contributor

Hase-U commented Apr 29, 2023

Thank you for your careful correction.
Certainly, token length should be used for that part

@ravwojdyla
Copy link
Contributor Author

Just want to point out for posterity that since #2330 _get_len_safe_embeddings became the default embedding method, which means some embeddings computed for texts longer than context max token length may be corrupted (the weight on the last batch may be too high).

Copy link
Contributor

@hwchase17 hwchase17 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks - great catch

@hwchase17 hwchase17 merged commit 039b672 into langchain-ai:master May 1, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants