
Conversation

@kacperlukawski

Problem

The encode method raises a ValueError when we request a precision other than float32 together with output_value="token_embeddings", as reported in #2882.

Solution

This PR provides a fix that concatenates all the token embeddings into a single array, runs the quantization on it, and then reconstructs the original per-example shapes so the token embeddings coming from each input example can still be told apart.
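For illustration, here is a minimal sketch of that flatten-quantize-split idea; the dummy data and variable names below are purely illustrative and are not the PR's actual code:

    import numpy as np
    from sentence_transformers.quantization import quantize_embeddings

    # Dummy token embeddings with a different sequence length per input
    token_embeddings = [np.random.rand(5, 384), np.random.rand(8, 384)]
    lengths = [embedding.shape[0] for embedding in token_embeddings]

    # Flatten into one (sum(lengths), dim) array and quantize it in one go
    quantized = quantize_embeddings(np.concatenate(token_embeddings), precision="int8")

    # Split back into per-input arrays using the remembered lengths
    per_input = np.split(quantized, np.cumsum(lengths)[:-1])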

@kacperlukawski changed the title to [fix] Quantization of token embeddings on Aug 8, 2024
        # It will happen when we request token_embeddings
        lengths = [embedding.shape[0] for embedding in embeddings]
        embeddings = np.concatenate(embeddings)
        if isinstance(embeddings[0], Tensor):
@ir2718
Contributor

Shouldn't this if statement be above the previous if, as sending in a list of Tensors is also valid?

@kacperlukawski
Author

@ir2718 You were absolutely right, thank you! Changed the order of the statements.

@tomaarsen
Member

Hello!

Apologies, I haven't yet had time to look into this a bit deeper, but I think an edge case that might be missed is output_value=None. This is not very well documented, but it returns both the sentence embedding and the token embeddings. I can imagine that this might be valuable for some use cases.

  • Tom Aarsen
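For reference, a minimal sketch of the output_value=None behaviour Tom describes; the model name is just an example and the exact set of returned keys may vary by version:

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")

    # output_value=None returns, for each input, a dict with both representations
    outputs = model.encode(["An example sentence"], output_value=None)
    print(outputs[0]["sentence_embedding"].shape)  # pooled sentence embedding
    print(outputs[0]["token_embeddings"].shape)    # one embedding per token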

@ir2718
Contributor

ir2718 commented Aug 9, 2024

I'm not sure if I can modify the PR, but following up on the output_value=None edge case Tom mentioned, I think adding this should suffice:

        if isinstance(embeddings[0], dict):
            # output_value=None: each element holds both the sentence and the token embeddings
            sentence_embeddings = [x["sentence_embedding"].unsqueeze(0).cpu().numpy() for x in embeddings]

            token_embeddings = []
            for emb_dict in embeddings:
                token_emb = emb_dict["token_embeddings"]
                attention = emb_dict["attention_mask"]
                # Trim trailing padding tokens using the attention mask
                last_mask_id = len(attention) - 1
                while last_mask_id > 0 and attention[last_mask_id].item() == 0:
                    last_mask_id -= 1

                token_embeddings.append(token_emb[0 : last_mask_id + 1])

            token_embeddings = [x.cpu().numpy() for x in token_embeddings]
            embeddings = token_embeddings + sentence_embeddings
            lengths = [x.shape[0] for x in embeddings]

with a modification in SentenceTransformer.py, line 638, right before the return statement:

        if output_value is None:
            # The first half of all_embeddings holds token embeddings, the second half sentence embeddings
            return {
                "token_embeddings": all_embeddings[:len(all_embeddings)//2],
                "sentence_embedding": all_embeddings[len(all_embeddings)//2:]
            }

@kacperlukawski
Author

Thanks, @ir2718! I wonder whether we should return a dictionary, though. That would break the interface of the encode method. @tomaarsen Would that be the expected behaviour?

@kacperlukawski
Author

kacperlukawski commented Aug 30, 2024

I decided to implement the quantization for this edge case differently from what was suggested. quantize_embeddings itself wasn't modified; instead, I extended the encode method. all_embeddings was already a dictionary there, so I combined the token and sentence embeddings and passed them all together to the quantization step. The output dictionary structure remains unchanged apart from the different precision.

@ir2718 I deliberately didn't use the attention mask. I thought it would be best to keep the shapes consistent, no matter whether we use float32 or any other precision.
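An illustrative sketch of that combine-then-split idea for the dict case; the data and variable names below are made up for demonstration and are not the PR's actual code:

    import numpy as np
    from sentence_transformers.quantization import quantize_embeddings

    # Dummy per-input dicts, shaped like what encode(..., output_value=None) yields
    all_embeddings = [
        {"sentence_embedding": np.random.rand(384), "token_embeddings": np.random.rand(6, 384)},
        {"sentence_embedding": np.random.rand(384), "token_embeddings": np.random.rand(9, 384)},
    ]

    # Flatten token and sentence embeddings into one array, remembering the chunk sizes
    pieces = []
    for emb in all_embeddings:
        pieces.append(emb["token_embeddings"])
        pieces.append(emb["sentence_embedding"][None, :])
    lengths = [piece.shape[0] for piece in pieces]

    quantized = quantize_embeddings(np.concatenate(pieces), precision="int8")

    # Write the quantized chunks back into the original dict structure
    chunks = np.split(quantized, np.cumsum(lengths)[:-1])
    for emb, token_chunk, sentence_chunk in zip(all_embeddings, chunks[0::2], chunks[1::2]):
        emb["token_embeddings"] = token_chunk
        emb["sentence_embedding"] = sentence_chunk[0]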

@ir2718
Contributor

ir2718 commented Aug 30, 2024

I wonder whether we should return a dictionary, though. That would break the interface of the encode method.

Agreed, I was thinking about that myself, but since transformers mostly handles things with dicts, my first idea was to implement it that way. Not breaking the interface is probably the better solution, but it requires adding some kind of note in the docs about the ordering of the embeddings.

