From 70ff54eace10b50d4f01a4f5a2d8636190ccca1b Mon Sep 17 00:00:00 2001 From: baichuan-assistant <139942740+baichuan-assistant@users.noreply.github.com> Date: Sat, 27 Jan 2024 04:57:26 +0800 Subject: [PATCH] community[minor]: Add Baichuan Text Embedding Model and Baichuan Inc introduction (#16568) - **Description:** Adding Baichuan Text Embedding Model and Baichuan Inc introduction. Baichuan Text Embedding ranks #1 in C-MTEB leaderboard: https://huggingface.co/spaces/mteb/leaderboard Co-authored-by: BaiChuanHelper --- docs/docs/integrations/providers/baichuan.mdx | 13 ++ .../text_embedding/baichuan.ipynb | 75 ++++++++++++ .../integrations/vectorstores/kdbai.ipynb | 54 +++++---- .../embeddings/__init__.py | 2 + .../embeddings/baichuan.py | 113 ++++++++++++++++++ .../embeddings/test_baichuan.py | 19 +++ .../unit_tests/embeddings/test_imports.py | 1 + 7 files changed, 252 insertions(+), 25 deletions(-) create mode 100644 docs/docs/integrations/providers/baichuan.mdx create mode 100644 docs/docs/integrations/text_embedding/baichuan.ipynb create mode 100644 libs/community/langchain_community/embeddings/baichuan.py create mode 100644 libs/community/tests/integration_tests/embeddings/test_baichuan.py diff --git a/docs/docs/integrations/providers/baichuan.mdx b/docs/docs/integrations/providers/baichuan.mdx new file mode 100644 index 0000000000000..b73e74b457941 --- /dev/null +++ b/docs/docs/integrations/providers/baichuan.mdx @@ -0,0 +1,13 @@ +# Baichuan + +>[Baichuan Inc.](https://www.baichuan-ai.com/) is a Chinese startup in the era of AGI, dedicated to addressing fundamental human needs: Efficiency, Health, and Happiness. + +## Visit Us +Visit us at https://www.baichuan-ai.com/. +Register and get an API key if you are trying out our APIs. + +## Baichuan Chat Model +An example is available at [example](/docs/integrations/chat/baichuan). + +## Baichuan Text Embedding Model +An example is available at [example] (/docs/integrations/text_embedding/baichuan) diff --git a/docs/docs/integrations/text_embedding/baichuan.ipynb b/docs/docs/integrations/text_embedding/baichuan.ipynb new file mode 100644 index 0000000000000..8b5d57a2ddbcb --- /dev/null +++ b/docs/docs/integrations/text_embedding/baichuan.ipynb @@ -0,0 +1,75 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Baichuan Text Embeddings\n", + "\n", + "As of today (Jan 25th, 2024) BaichuanTextEmbeddings ranks #1 in C-MTEB (Chinese Multi-Task Embedding Benchmark) leaderboard.\n", + "\n", + "Leaderboard (Under Overall -> Chinese section): https://huggingface.co/spaces/mteb/leaderboard\n", + "\n", + "Official Website: https://platform.baichuan-ai.com/docs/text-Embedding\n", + "An API-key is required to use this embedding model. You can get one by registering at https://platform.baichuan-ai.com/docs/text-Embedding.\n", + "BaichuanTextEmbeddings support 512 token window and preduces vectors with 1024 dimensions. \n", + "\n", + "Please NOTE that BaichuanTextEmbeddings only supports Chinese text embedding. Multi-language support is coming soon.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "from langchain_community.embeddings import BaichuanTextEmbeddings\n", + "\n", + "# Place your Baichuan API-key here.\n", + "embeddings = BaichuanTextEmbeddings(baichuan_api_key=\"sk-*\")\n", + "\n", + "text_1 = \"今天天气不错\"\n", + "text_2 = \"今天阳光很好\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "query_result = embeddings.embed_query(text_1)\n", + "query_result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "doc_result = embeddings.embed_documents([text_1, text_2])\n", + "doc_result" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/docs/integrations/vectorstores/kdbai.ipynb b/docs/docs/integrations/vectorstores/kdbai.ipynb index 39c1f1fdec300..e0121796ad077 100644 --- a/docs/docs/integrations/vectorstores/kdbai.ipynb +++ b/docs/docs/integrations/vectorstores/kdbai.ipynb @@ -167,9 +167,9 @@ ], "source": [ "%%time\n", - "URL = 'https://www.conseil-constitutionnel.fr/node/3850/pdf'\n", - "PDF = 'Déclaration_des_droits_de_l_homme_et_du_citoyen.pdf'\n", - "open(PDF, 'wb').write(requests.get(URL).content)" + "URL = \"https://www.conseil-constitutionnel.fr/node/3850/pdf\"\n", + "PDF = \"Déclaration_des_droits_de_l_homme_et_du_citoyen.pdf\"\n", + "open(PDF, \"wb\").write(requests.get(URL).content)" ] }, { @@ -208,7 +208,7 @@ ], "source": [ "%%time\n", - "print('Read a PDF...')\n", + "print(\"Read a PDF...\")\n", "loader = PyPDFLoader(PDF)\n", "pages = loader.load_and_split()\n", "len(pages)" @@ -252,12 +252,14 @@ ], "source": [ "%%time\n", - "print('Create a Vector Database from PDF text...')\n", - "embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')\n", + "print(\"Create a Vector Database from PDF text...\")\n", + "embeddings = OpenAIEmbeddings(model=\"text-embedding-ada-002\")\n", "texts = [p.page_content for p in pages]\n", "metadata = pd.DataFrame(index=list(range(len(texts))))\n", - "metadata['tag'] = 'law'\n", - "metadata['title'] = 'Déclaration des Droits de l\\'Homme et du Citoyen de 1789'.encode('utf-8')\n", + "metadata[\"tag\"] = \"law\"\n", + "metadata[\"title\"] = \"Déclaration des Droits de l'Homme et du Citoyen de 1789\".encode(\n", + " \"utf-8\"\n", + ")\n", "vectordb = KDBAI(table, embeddings)\n", "vectordb.add_texts(texts=texts, metadatas=metadata)" ] @@ -288,11 +290,13 @@ ], "source": [ "%%time\n", - "print('Create LangChain Pipeline...')\n", - "qabot = RetrievalQA.from_chain_type(chain_type='stuff',\n", - " llm=ChatOpenAI(model='gpt-3.5-turbo-16k', temperature=TEMP), \n", - " retriever=vectordb.as_retriever(search_kwargs=dict(k=K)),\n", - " return_source_documents=True)" + "print(\"Create LangChain Pipeline...\")\n", + "qabot = RetrievalQA.from_chain_type(\n", + " chain_type=\"stuff\",\n", + " llm=ChatOpenAI(model=\"gpt-3.5-turbo-16k\", temperature=TEMP),\n", + " retriever=vectordb.as_retriever(search_kwargs=dict(k=K)),\n", + " return_source_documents=True,\n", + ")" ] }, { @@ -325,9 +329,9 @@ ], "source": [ "%%time\n", - "Q = 'Summarize the document in English:'\n", - "print(f'\\n\\n{Q}\\n')\n", - "print(qabot.invoke(dict(query=Q))['result'])" + "Q = \"Summarize the document in English:\"\n", + "print(f\"\\n\\n{Q}\\n\")\n", + "print(qabot.invoke(dict(query=Q))[\"result\"])" ] }, { @@ -362,9 +366,9 @@ ], "source": [ "%%time\n", - "Q = 'Is it a fair law and why ?'\n", - "print(f'\\n\\n{Q}\\n')\n", - "print(qabot.invoke(dict(query=Q))['result'])" + "Q = \"Is it a fair law and why ?\"\n", + "print(f\"\\n\\n{Q}\\n\")\n", + "print(qabot.invoke(dict(query=Q))[\"result\"])" ] }, { @@ -414,9 +418,9 @@ ], "source": [ "%%time\n", - "Q = 'What are the rights and duties of the man, the citizen and the society ?'\n", - "print(f'\\n\\n{Q}\\n')\n", - "print(qabot.invoke(dict(query=Q))['result'])" + "Q = \"What are the rights and duties of the man, the citizen and the society ?\"\n", + "print(f\"\\n\\n{Q}\\n\")\n", + "print(qabot.invoke(dict(query=Q))[\"result\"])" ] }, { @@ -441,9 +445,9 @@ ], "source": [ "%%time\n", - "Q = 'Is this law practical ?'\n", - "print(f'\\n\\n{Q}\\n')\n", - "print(qabot.invoke(dict(query=Q))['result'])" + "Q = \"Is this law practical ?\"\n", + "print(f\"\\n\\n{Q}\\n\")\n", + "print(qabot.invoke(dict(query=Q))[\"result\"])" ] }, { diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index aaf893ba51b05..8c50c3f1ab79d 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -20,6 +20,7 @@ ) from langchain_community.embeddings.awa import AwaEmbeddings from langchain_community.embeddings.azure_openai import AzureOpenAIEmbeddings +from langchain_community.embeddings.baichuan import BaichuanTextEmbeddings from langchain_community.embeddings.baidu_qianfan_endpoint import ( QianfanEmbeddingsEndpoint, ) @@ -92,6 +93,7 @@ __all__ = [ "OpenAIEmbeddings", "AzureOpenAIEmbeddings", + "BaichuanTextEmbeddings", "ClarifaiEmbeddings", "CohereEmbeddings", "DatabricksEmbeddings", diff --git a/libs/community/langchain_community/embeddings/baichuan.py b/libs/community/langchain_community/embeddings/baichuan.py new file mode 100644 index 0000000000000..b31e6092d5925 --- /dev/null +++ b/libs/community/langchain_community/embeddings/baichuan.py @@ -0,0 +1,113 @@ +from typing import Any, Dict, List, Optional + +import requests +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env + +BAICHUAN_API_URL: str = "http://api.baichuan-ai.com/v1/embeddings" + +# BaichuanTextEmbeddings is an embedding model provided by Baichuan Inc. (https://www.baichuan-ai.com/home). +# As of today (Jan 25th, 2024) BaichuanTextEmbeddings ranks #1 in C-MTEB +# (Chinese Multi-Task Embedding Benchmark) leaderboard. +# Leaderboard (Under Overall -> Chinese section): https://huggingface.co/spaces/mteb/leaderboard + +# Official Website: https://platform.baichuan-ai.com/docs/text-Embedding +# An API-key is required to use this embedding model. You can get one by registering +# at https://platform.baichuan-ai.com/docs/text-Embedding. +# BaichuanTextEmbeddings support 512 token window and preduces vectors with +# 1024 dimensions. + + +# NOTE!! BaichuanTextEmbeddings only supports Chinese text embedding. +# Multi-language support is coming soon. +class BaichuanTextEmbeddings(BaseModel, Embeddings): + """Baichuan Text Embedding models.""" + + session: Any #: :meta private: + model_name: str = "Baichuan-Text-Embedding" + baichuan_api_key: Optional[SecretStr] = None + + @root_validator(allow_reuse=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that auth token exists in environment.""" + try: + baichuan_api_key = convert_to_secret_str( + get_from_dict_or_env(values, "baichuan_api_key", "BAICHUAN_API_KEY") + ) + except ValueError as original_exc: + try: + baichuan_api_key = convert_to_secret_str( + get_from_dict_or_env( + values, "baichuan_auth_token", "BAICHUAN_AUTH_TOKEN" + ) + ) + except ValueError: + raise original_exc + session = requests.Session() + session.headers.update( + { + "Authorization": f"Bearer {baichuan_api_key.get_secret_value()}", + "Accept-Encoding": "identity", + "Content-type": "application/json", + } + ) + values["session"] = session + return values + + def _embed(self, texts: List[str]) -> Optional[List[List[float]]]: + """Internal method to call Baichuan Embedding API and return embeddings. + + Args: + texts: A list of texts to embed. + + Returns: + A list of list of floats representing the embeddings, or None if an + error occurs. + """ + try: + response = self.session.post( + BAICHUAN_API_URL, json={"input": texts, "model": self.model_name} + ) + # Check if the response status code indicates success + if response.status_code == 200: + resp = response.json() + embeddings = resp.get("data", []) + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e.get("index", 0)) + # Return just the embeddings + return [result.get("embedding", []) for result in sorted_embeddings] + else: + # Log error or handle unsuccessful response appropriately + print( + f"""Error: Received status code {response.status_code} from + embedding API""" + ) + return None + except Exception as e: + # Log the exception or handle it as needed + print(f"Exception occurred while trying to get embeddings: {str(e)}") + return None + + def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: + """Public method to get embeddings for a list of documents. + + Args: + texts: The list of texts to embed. + + Returns: + A list of embeddings, one for each text, or None if an error occurs. + """ + return self._embed(texts) + + def embed_query(self, text: str) -> Optional[List[float]]: + """Public method to get embedding for a single query text. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text, or None if an error occurs. + """ + result = self._embed([text]) + return result[0] if result is not None else None diff --git a/libs/community/tests/integration_tests/embeddings/test_baichuan.py b/libs/community/tests/integration_tests/embeddings/test_baichuan.py new file mode 100644 index 0000000000000..008dc0f97dfee --- /dev/null +++ b/libs/community/tests/integration_tests/embeddings/test_baichuan.py @@ -0,0 +1,19 @@ +"""Test Baichuan Text Embedding.""" +from langchain_community.embeddings.baichuan import BaichuanTextEmbeddings + + +def test_baichuan_embedding_documents() -> None: + """Test Baichuan Text Embedding for documents.""" + documents = ["今天天气不错", "今天阳光灿烂"] + embedding = BaichuanTextEmbeddings() + output = embedding.embed_documents(documents) + assert len(output) == 2 + assert len(output[0]) == 1024 + + +def test_baichuan_embedding_query() -> None: + """Test Baichuan Text Embedding for query.""" + document = "所有的小学生都会学过只因兔同笼问题。" + embedding = BaichuanTextEmbeddings() + output = embedding.embed_query(document) + assert len(output) == 1024 diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py index 1bb872607d85a..2220d7446f90a 100644 --- a/libs/community/tests/unit_tests/embeddings/test_imports.py +++ b/libs/community/tests/unit_tests/embeddings/test_imports.py @@ -3,6 +3,7 @@ EXPECTED_ALL = [ "OpenAIEmbeddings", "AzureOpenAIEmbeddings", + "BaichuanTextEmbeddings", "ClarifaiEmbeddings", "CohereEmbeddings", "DatabricksEmbeddings",