Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mteb/models/overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
voyage_models,
voyage_v,
xyz_models,
youtu_models,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -207,6 +208,7 @@
mcinext_models,
vi_vn_models,
codi_models,
youtu_models,
]
MODEL_REGISTRY = {}

Expand Down
239 changes: 239 additions & 0 deletions mteb/models/youtu_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
from __future__ import annotations

import json
import logging
import os
import time
from functools import partial
from typing import Any

import numpy as np
import tqdm

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
from mteb.models.wrapper import Wrapper
from mteb.requires_package import requires_package

logger = logging.getLogger(__name__)


# Task-specific instruction prefixes, keyed by mteb task name.
# Values take one of two shapes:
#   * dict with "query"/"document" keys — asymmetric retrieval/reranking
#     tasks where queries get an instruction and documents get none;
#   * plain str — symmetric tasks (STS, classification, clustering) where
#     one instruction applies to every input.
# Tasks not listed here fall back to a generic default instruction at
# encode time.
youtu_instruction = {
    "CmedqaRetrieval": {
        "query": "Given a Chinese community medical question, retrieve replies that best answer the question",
        "document": "",
    },
    "CovidRetrieval": {
        "query": "Given a question on COVID-19, retrieve news articles that answer the question",
        "document": "",
    },
    "DuRetrieval": {
        "query": "Given a Chinese search query, retrieve web passages that answer the question",
        "document": "",
    },
    "EcomRetrieval": {
        "query": "Given a user query from an e-commerce website, retrieve description sentences of relevant products",
        "document": "",
    },
    "MedicalRetrieval": {
        "query": "Given a medical question, retrieve user replies that best answer the question",
        "document": "",
    },
    "MMarcoRetrieval": {
        "query": "Given a web search query, retrieve relevant passages that answer the query",
        "document": "",
    },
    "T2Retrieval": {
        "query": "Given a Chinese search query, retrieve web passages that answer the question",
        "document": "",
    },
    "VideoRetrieval": {
        "query": "Given a video search query, retrieve the titles of relevant videos",
        "document": "",
    },
    "AFQMC": "Represent the text in conversations between users and financial customer service, retrieve semantically similar text",
    "ATEC": "Represent the text in conversations between users and financial customer service, retrieve semantically similar text",
    "BQ": "Represent the user problem descriptions when handling bank credit business, retrieve semantically similar text",
    "LCQMC": "Represent the user question descriptions on general question-answering platforms, retrieve semantically similar text",
    "PAWSX": "Represent the Chinese Translations of English Encyclopedias, retrieve semantically similar text",
    "QBQTC": "Represent the web search query, retrieve semantically similar text",
    "STSB": "Represent the short general domain sentences, retrieve semantically similar text",
    "T2Reranking": {
        "query": "Given a Chinese search query, retrieve web passages that answer the question",
        "document": "",
    },
    "MMarcoReranking": {
        "query": "Given a web search query, retrieve relevant passages that answer the query",
        "document": "",
    },
    "CMedQAv1-reranking": {
        "query": "Given a Chinese community medical question, retrieve replies that best answer the question",
        "document": "",
    },
    "CMedQAv2-reranking": {
        "query": "Given a Chinese community medical question, retrieve replies that best answer the question",
        "document": "",
    },
    "Ocnli": "Retrieve semantically similar text",
    "Cmnli": "Retrieve semantically similar text",
    "TNews": "Classify the fine-grained category of the given news title",
    "IFlyTek": "Given an App description text, find the appropriate fine-grained category",
    "Waimai": "Classify the customer review from a food takeaway platform into positive or negative",
    "OnlineShopping": "Classify the customer review for online shopping into positive or negative",
    "JDReview": "Classify the customer review for iPhone on e-commerce platform into positive or negative",
    "MultilingualSentiment": "Classify sentiment of the customer review into positive, neutral, or negative",
    "CLSClusteringS2S": "Identify the main category of scholar papers based on the titles",
    "CLSClusteringP2P": "Identify the main category of scholar papers based on the titles and abstracts",
    "ThuNewsClusteringS2S": "Identify the topic or theme of the given news articles based on the titles",
    "ThuNewsClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents",
}


# mteb task name -> dataset splits this model was trained on. mteb uses
# this mapping to flag potential train/test leakage, so keys must match
# registered mteb task names exactly.
training_data = {
    "T2Retrieval": ["train"],
    "DuRetrieval": ["train"],
    "T2Reranking": ["train"],
    "MMarcoReranking": ["train"],
    "CmedqaRetrieval": ["train"],
    "CMedQAv1-reranking": ["train"],
    "CMedQAv2-reranking": ["train"],
    "BQ": ["train"],
    "LCQMC": ["train"],
    "PAWSX": ["train"],
    # Fix: was "STS-B", which is not an mteb task name; the Chinese STS
    # benchmark task is registered as "STSB" (matching youtu_instruction).
    "STSB": ["train"],
    "AFQMC": ["train"],
    "Cmnli": ["train"],
    "Ocnli": ["train"],
}


class YoutuEmbeddingWrapper(Wrapper):
    """API wrapper for Tencent Youtu embedding models served through the
    Tencent Cloud LKEAP ``GetEmbedding`` endpoint.

    Requires the ``TENCENTCLOUD_SECRET_ID`` and ``TENCENTCLOUD_SECRET_KEY``
    environment variables and the optional ``mteb[youtu]`` dependencies.
    """

    def __init__(
        self,
        model_name: str,
        **kwargs,
    ) -> None:
        requires_package(
            self,
            "tencentcloud.common",
            "tencentcloud.lkeap",
            "pip install mteb[youtu]",
        )

        from tencentcloud.common import credential
        from tencentcloud.common.profile.client_profile import ClientProfile
        from tencentcloud.common.profile.http_profile import HttpProfile
        from tencentcloud.lkeap.v20240522 import lkeap_client, models

        secret_id = os.getenv("TENCENTCLOUD_SECRET_ID")
        secret_key = os.getenv("TENCENTCLOUD_SECRET_KEY")
        if not secret_id or not secret_key:
            raise ValueError(
                "TENCENTCLOUD_SECRET_ID and TENCENTCLOUD_SECRET_KEY environment variables must be set"
            )
        cred = credential.Credential(secret_id, secret_key)

        httpProfile = HttpProfile()
        # NOTE(review): this is a *.test.* endpoint — confirm it is the
        # intended production endpoint before release.
        httpProfile.endpoint = "lkeap.test.tencentcloudapi.com"

        clientProfile = ClientProfile()
        clientProfile.httpProfile = httpProfile

        self.client = lkeap_client.LkeapClient(cred, "ap-guangzhou", clientProfile)
        # Keep a handle on the SDK's request/response model module.
        self.models = models
        self.model_name = model_name

    def _encode(
        self, inputs: list[str], task_name: str, prompt_type: PromptType | None = None
    ) -> list[list[float]]:
        """Embed a single batch of texts with one API call.

        Args:
            inputs: Batch of texts to embed.
            task_name: mteb task name, used to select an instruction prefix.
            prompt_type: ``PromptType.document`` embeds without instruction;
                ``PromptType.query`` or ``None`` (symmetric tasks) embeds
                with the task's query instruction.

        Returns:
            One embedding (list of floats) per input, in input order.
        """
        if prompt_type == PromptType.document:
            # Documents are embedded without any instruction prefix.
            instruction = ""
        else:
            default_instruction = (
                "Given a search query, retrieve web passages that answer the question"
            )
            instruction = youtu_instruction.get(task_name, default_instruction)
            if isinstance(instruction, dict):
                # Asymmetric tasks store per-side instructions; this branch
                # only serves the query side. Fix: the original indexed with
                # `prompt_type`, which raised KeyError when prompt_type was
                # None for a retrieval/reranking task.
                instruction = instruction["query"]
            instruction = f"Instruction: {instruction} \nQuery: "

        params = {
            "Model": self.model_name,
            "Inputs": inputs,
            "Instruction": instruction,
        }

        req = self.models.GetEmbeddingRequest()
        req.from_json_string(json.dumps(params))

        resp = self.client.GetEmbedding(req)
        resp = json.loads(resp.to_json_string())
        return [item["Embedding"] for item in resp["Data"]]

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        retries: int = 5,
        **kwargs: Any,
    ) -> np.ndarray:
        """Embed ``sentences`` in batches, retrying each batch on API errors.

        Args:
            sentences: Texts to embed.
            task_name: mteb task name forwarded to :meth:`_encode`.
            prompt_type: Query/document side, forwarded to :meth:`_encode`.
            retries: Maximum attempts *per batch* before re-raising.
            **kwargs: ``batch_size`` (default 32) and ``show_progress_bar``
                are honored; other keys are ignored.

        Returns:
            Array of shape ``(len(sentences), embed_dim)``.
        """
        batch_size = kwargs.get("batch_size", 32)
        show_progress_bar = kwargs.pop("show_progress_bar", False)

        batches = [
            sentences[i : i + batch_size]
            for i in range(0, len(sentences), batch_size)
        ]

        all_embeddings = []
        for batch in tqdm.tqdm(batches, leave=False, disable=not show_progress_bar):
            # Fix: the original encoded each batch once unconditionally and
            # then again inside the retry loop, duplicating every embedding
            # (and doubling API cost); the first call was also outside the
            # try, so a transient failure there was never retried. The retry
            # budget is now per batch instead of shared across all batches.
            attempts = retries
            while True:
                try:
                    all_embeddings.extend(self._encode(batch, task_name, prompt_type))
                    break
                except Exception as e:
                    attempts -= 1
                    if attempts <= 0:
                        raise
                    logger.warning(
                        f"Retrying... {attempts} retries left. Error: {str(e)}"
                    )
                    # Back off briefly; failures are typically rate limits.
                    time.sleep(1)

        return np.array(all_embeddings)


# Registry metadata for the Youtu embedding model. open_weights=False and
# framework=["API"]: the model is reachable only through the Tencent Cloud
# endpoint wrapped by YoutuEmbeddingWrapper, so no parameter count or
# memory usage is published.
Youtu_Embedding_V1 = ModelMeta(
    name="Youtu-RAG/Youtu-Embedding-V1",
    languages=["zho-Hans"],
    revision="1",
    release_date="2025-09-02",
    loader=partial(YoutuEmbeddingWrapper, model_name="youtu-embedding-v1"),
    open_weights=False,
    n_parameters=None,
    memory_usage_mb=None,
    embed_dim=2304,
    license=None,
    max_tokens=4096,
    reference="https://youtu-embedding-v1.github.io/",
    similarity_fn_name="cosine",
    framework=["API"],
    use_instructions=True,
    public_training_code=None,
    public_training_data=None,
    training_datasets=training_data,
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ open_clip_torch = ["open_clip_torch==2.31.0"]
xet = ["huggingface_hub>=0.32.0"]
ark = ["volcengine-python-sdk[ark]==3.0.2", "tiktoken>=0.8.0"]
colpali_engine = ["colpali_engine>=0.3.12"]
youtu = ["tencentcloud-sdk-python-common>=3.0.1454", "tencentcloud-sdk-python-lkeap>=3.0.1451"]

[dependency-groups]
lint = [
Expand Down
Loading