-
Notifications
You must be signed in to change notification settings - Fork 15.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: Add support for modle2vec embeddings (#28507)
This PR add an embeddings integration for model2vec, the `Model2vecEmbeddings` class. - **Description**: [Model2Vec](https://github.com/MinishLab/model2vec) lets you turn any sentence transformer into a really small static model and makes running the model faster. - **Issue**: - **Dependencies**: model2vec ([pypi](https://pypi.org/project/model2vec/)) - **Twitter handle:**: - [x] **Add tests and docs**: - [Test](https://github.com/blacksmithop/langchain/blob/model2vec_embeddings/libs/community/langchain_community/embeddings/model2vec.py), [docs](https://github.com/blacksmithop/langchain/blob/model2vec_embeddings/docs/docs/integrations/text_embedding/model2vec.ipynb) - [x] **Lint and test**: --------- Co-authored-by: Abhinav KM <[email protected]> Co-authored-by: Bagatur <[email protected]>
- Loading branch information
1 parent
fbf0704
commit 317a38b
Showing
5 changed files
with
284 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e8712110", | ||
"metadata": {}, | ||
"source": [ | ||
"## Overview\n", | ||
"\n", | ||
"Model2Vec is a technique to turn any sentence transformer into a really small static model\n", | ||
"[model2vec](https://github.com/MinishLab/model2vec) can be used to generate embeddings." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "266dd424", | ||
"metadata": {}, | ||
"source": [ | ||
"## Setup\n", | ||
"\n", | ||
"```bash\n", | ||
"pip install -U langchain-community\n", | ||
"```\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "78ab91a6", | ||
"metadata": {}, | ||
"source": [ | ||
"## Instantiation" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d06e7719", | ||
"metadata": {}, | ||
"source": [ | ||
"Ensure that `model2vec` is installed\n", | ||
"\n", | ||
"```bash\n", | ||
"pip install -U model2vec\n", | ||
"```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f8ea1ed5", | ||
"metadata": {}, | ||
"source": [ | ||
"## Indexing and Retrieval" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "d25dc22d-b656-46c6-a42d-eace958590cd", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2023-05-24T15:13:17.176956Z", | ||
"start_time": "2023-05-24T15:13:15.399076Z" | ||
}, | ||
"execution": { | ||
"iopub.execute_input": "2024-03-29T15:39:19.252281Z", | ||
"iopub.status.busy": "2024-03-29T15:39:19.252101Z", | ||
"iopub.status.idle": "2024-03-29T15:39:19.339106Z", | ||
"shell.execute_reply": "2024-03-29T15:39:19.338614Z", | ||
"shell.execute_reply.started": "2024-03-29T15:39:19.252260Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_community.embeddings import Model2vecEmbeddings" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "8397b91f-a1f9-4be6-a699-fedaada7c37a", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2023-05-24T15:13:17.193751Z", | ||
"start_time": "2023-05-24T15:13:17.182053Z" | ||
}, | ||
"execution": { | ||
"iopub.execute_input": "2024-03-29T15:39:19.901573Z", | ||
"iopub.status.busy": "2024-03-29T15:39:19.900935Z", | ||
"iopub.status.idle": "2024-03-29T15:39:19.906540Z", | ||
"shell.execute_reply": "2024-03-29T15:39:19.905345Z", | ||
"shell.execute_reply.started": "2024-03-29T15:39:19.901529Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"embeddings = Model2vecEmbeddings(\"minishlab/potion-base-8M\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "abcf98b7-424c-4691-a1cd-862c3d53be11", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2023-05-24T15:13:17.844903Z", | ||
"start_time": "2023-05-24T15:13:17.198751Z" | ||
}, | ||
"execution": { | ||
"iopub.execute_input": "2024-03-29T15:39:20.434581Z", | ||
"iopub.status.busy": "2024-03-29T15:39:20.433117Z", | ||
"iopub.status.idle": "2024-03-29T15:39:22.178650Z", | ||
"shell.execute_reply": "2024-03-29T15:39:22.176058Z", | ||
"shell.execute_reply.started": "2024-03-29T15:39:20.434501Z" | ||
}, | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"query_text = \"This is a test query.\"\n", | ||
"query_result = embeddings.embed_query(query_text)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "98897454-b280-4ee1-bbb9-2c6c15342f87", | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2023-05-24T15:13:18.605339Z", | ||
"start_time": "2023-05-24T15:13:17.845906Z" | ||
}, | ||
"execution": { | ||
"iopub.execute_input": "2024-03-29T15:39:28.164009Z", | ||
"iopub.status.busy": "2024-03-29T15:39:28.161759Z", | ||
"iopub.status.idle": "2024-03-29T15:39:30.217232Z", | ||
"shell.execute_reply": "2024-03-29T15:39:30.215348Z", | ||
"shell.execute_reply.started": "2024-03-29T15:39:28.163876Z" | ||
}, | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"document_text = \"This is a test document.\"\n", | ||
"document_result = embeddings.embed_documents([document_text])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "11bac134", | ||
"metadata": {}, | ||
"source": [ | ||
"## Direct Usage\n", | ||
"\n", | ||
"Here's how you would directly make use of `model2vec`\n", | ||
"\n", | ||
"```python\n", | ||
"from model2vec import StaticModel\n", | ||
"\n", | ||
"# Load a model from the HuggingFace hub (in this case the potion-base-8M model)\n", | ||
"model = StaticModel.from_pretrained(\"minishlab/potion-base-8M\")\n", | ||
"\n", | ||
"# Make embeddings\n", | ||
"embeddings = model.encode([\"It's dangerous to go alone!\", \"It's a secret to everybody.\"])\n", | ||
"\n", | ||
"# Make sequences of token embeddings\n", | ||
"token_embeddings = model.encode_as_sequence([\"It's dangerous to go alone!\", \"It's a secret to everybody.\"])\n", | ||
"```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d81e21aa", | ||
"metadata": {}, | ||
"source": [ | ||
"## API Reference\n", | ||
"\n", | ||
"For more information check out the model2vec github [repo](https://github.com/MinishLab/model2vec)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 66 additions & 0 deletions
66
libs/community/langchain_community/embeddings/model2vec.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
"""Wrapper around model2vec embedding models.""" | ||
|
||
from typing import List | ||
|
||
from langchain_core.embeddings import Embeddings | ||
|
||
|
||
class Model2vecEmbeddings(Embeddings): | ||
"""model2v embedding models. | ||
Install model2vec first, run 'pip install -U model2vec'. | ||
The github repository for model2vec is : https://github.com/MinishLab/model2vec | ||
Example: | ||
.. code-block:: python | ||
from langchain_community.embeddings import Model2vecEmbeddings | ||
embedding = Model2vecEmbeddings("minishlab/potion-base-8M") | ||
embedding.embed_documents([ | ||
"It's dangerous to go alone!", | ||
"It's a secret to everybody.", | ||
]) | ||
embedding.embed_query( | ||
"Take this with you." | ||
) | ||
""" | ||
|
||
def __init__(self, model: str): | ||
"""Initialize embeddings. | ||
Args: | ||
model: Model name. | ||
""" | ||
try: | ||
from model2vec import StaticModel | ||
except ImportError as e: | ||
raise ImportError( | ||
"Unable to import model2vec, please install with " | ||
"`pip install -U model2vec`." | ||
) from e | ||
self._model = StaticModel.from_pretrained(model) | ||
|
||
def embed_documents(self, texts: List[str]) -> List[List[float]]: | ||
"""Embed documents using the model2vec embeddings model. | ||
Args: | ||
texts: The list of texts to embed. | ||
Returns: | ||
List of embeddings, one for each text. | ||
""" | ||
|
||
return self._model.encode_as_sequence(texts) | ||
|
||
def embed_query(self, text: str) -> List[float]: | ||
"""Embed a query using the model2vec embeddings model. | ||
Args: | ||
text: The text to embed. | ||
Returns: | ||
Embeddings for the text. | ||
""" | ||
|
||
return self._model.encode(text) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
11 changes: 11 additions & 0 deletions
11
libs/community/tests/unit_tests/embeddings/test_model2vec.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from langchain_community.embeddings.model2vec import Model2vecEmbeddings | ||
|
||
|
||
def test_hugginggface_inferenceapi_embedding_documents_init() -> None: | ||
"""Test model2vec embeddings.""" | ||
try: | ||
embedding = Model2vecEmbeddings("minishlab/potion-base-8M") | ||
assert len(embedding.embed_query("hi")) == 256 | ||
except Exception: | ||
# model2vec is not installed | ||
assert True |