mteb/models/youtu_models.py — 169 changes: 36 additions & 133 deletions
@@ -1,23 +1,14 @@
from __future__ import annotations

import json
import logging
import os
import time
from functools import partial
from typing import Any

import numpy as np
import tqdm

from mteb.encoder_interface import PromptType
from mteb.model_meta import ModelMeta
from mteb.models.wrapper import Wrapper
from mteb.requires_package import requires_package
from mteb.models.instruct_wrapper import InstructSentenceTransformerWrapper

logger = logging.getLogger(__name__)


youtu_instruction = {
"CmedqaRetrieval": {
"query": "Given a Chinese community medical question, retrieve replies that best answer the question",
@@ -82,13 +73,26 @@
"OnlineShopping": "Classify the customer review for online shopping into positive or negative",
"JDReview": "Classify the customer review for iPhone on e-commerce platform into positive or negative",
"MultilingualSentiment": "Classify sentiment of the customer review into positive, neutral, or negative",
"CLSClusteringS2S": "Identify the main category of scholar papers based on the titles",
"CLSClusteringP2P": "Identify the main category of scholar papers based on the titles and abstracts",
"CLSClusteringS2S": "Given the papers, find the appropriate fine-grained topic or theme based on the titles",
"CLSClusteringP2P": "Given the papers, find the appropriate fine-grained topic or theme based on the titles and abstracts",
"ThuNewsClusteringS2S": "Identify the topic or theme of the given news articles based on the titles",
"ThuNewsClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents",
}


def instruction_template(
instruction: str, prompt_type: PromptType | None = None
) -> str:
if not instruction or prompt_type == PromptType.passage:
return ""
if isinstance(instruction, dict):
if prompt_type is None:
instruction = list(instruction.values())[0]
else:
instruction = instruction[prompt_type]
return f"Instruction: {instruction} \nQuery: "


training_data = {
"T2Retrieval": ["train"],
"DuRetrieval": ["train"],
@@ -107,131 +111,30 @@
}


class YoutuEmbeddingWrapper(Wrapper):
def __init__(
self,
model_name: str,
**kwargs,
) -> None:
requires_package(
self,
"tencentcloud.common",
"tencentcloud.lkeap",
"pip install mteb[youtu]",
)

from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.lkeap.v20240522 import lkeap_client, models

secret_id = os.getenv("TENCENTCLOUD_SECRET_ID")
secret_key = os.getenv("TENCENTCLOUD_SECRET_KEY")
if not secret_id or not secret_key:
raise ValueError(
"TENCENTCLOUD_SECRET_ID and TENCENTCLOUD_SECRET_KEY environment variables must be set"
)
cred = credential.Credential(secret_id, secret_key)

httpProfile = HttpProfile()
httpProfile.endpoint = "lkeap.test.tencentcloudapi.com"

clientProfile = ClientProfile()
clientProfile.httpProfile = httpProfile

self.client = lkeap_client.LkeapClient(cred, "ap-guangzhou", clientProfile)
self.models = models
self.model_name = model_name

def _encode(
self, inputs: list[str], task_name: str, prompt_type: PromptType | None = None
):
default_instruction = (
"Given a search query, retrieve web passages that answer the question"
)
if prompt_type == PromptType("query") or prompt_type is None:
instruction = youtu_instruction.get(task_name, default_instruction)
if isinstance(instruction, dict):
instruction = instruction[prompt_type]
instruction = f"Instruction: {instruction} \nQuery: "
elif prompt_type == PromptType.document:
instruction = ""

params = {
"Model": self.model_name,
"Inputs": inputs,
"Instruction": instruction,
}

req = self.models.GetEmbeddingRequest()
req.from_json_string(json.dumps(params))

resp = self.client.GetEmbedding(req)
resp = json.loads(resp.to_json_string())
outputs = [item["Embedding"] for item in resp["Data"]]

return outputs

def encode(
self,
sentences: list[str],
task_name: str,
prompt_type: PromptType | None = None,
retries: int = 5,
**kwargs: Any,
) -> np.ndarray:
max_batch_size = kwargs.get("batch_size", 32)
sublists = [
sentences[i : i + max_batch_size]
for i in range(0, len(sentences), max_batch_size)
]

show_progress_bar = (
False
if "show_progress_bar" not in kwargs
else kwargs.pop("show_progress_bar")
)

all_embeddings = []

for i, sublist in enumerate(
tqdm.tqdm(sublists, leave=False, disable=not show_progress_bar)
):
embedding = self._encode(sublist, task_name, prompt_type)
all_embeddings.extend(embedding)
while retries > 0:
try:
embedding = self._encode(sublist, task_name, prompt_type)
all_embeddings.extend(embedding)
break
except Exception as e:
# Sleep due to too many requests
time.sleep(1)
logger.warning(
f"Retrying... {retries} retries left. Error: {str(e)}"
)
retries -= 1
if retries == 0:
raise e

return np.array(all_embeddings)


Youtu_Embedding_V1 = ModelMeta(
name="Youtu-RAG/Youtu-Embedding-V1",
loader=partial(
InstructSentenceTransformerWrapper,
"tencent/Youtu-Embedding",
revision="32e04afc24817c187a8422e7bdbb493b19796d47",
instruction_template=instruction_template,
apply_instruction_to_passages=False,
prompts_dict=youtu_instruction,
trust_remote_code=True,
max_seq_length=8192,
),
name="tencent/Youtu-Embedding",
languages=["zho-Hans"],
revision="1",
release_date="2025-09-02",
loader=partial(YoutuEmbeddingWrapper, model_name="youtu-embedding-v1"),
open_weights=False,
n_parameters=None,
revision="32e04afc24817c187a8422e7bdbb493b19796d47",
release_date="2025-09-28",
open_weights=True,
n_parameters=2672957440,
memory_usage_mb=None,
embed_dim=2304,
license=None,
max_tokens=4096,
reference="https://youtu-embedding-v1.github.io/",
embed_dim=2048,
license="apache-2.0",
max_tokens=8192,
reference="https://huggingface.co/tencent/Youtu-Embedding",
similarity_fn_name="cosine",
framework=["API"],
framework=["Sentence Transformers", "PyTorch"],
use_instructions=True,
public_training_code=None,
public_training_data=None,
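
For reference (not part of the diff): the `instruction_template` helper added above only prefixes queries, never passages, which is why the `ModelMeta` loader also sets `apply_instruction_to_passages=False`. A minimal sketch of the expected behaviour, assuming this branch of mteb is importable; the sample instruction string is made up for illustration:

```python
from mteb.encoder_interface import PromptType
from mteb.models.youtu_models import instruction_template

# Queries get the "Instruction: ... \nQuery: " prefix.
print(instruction_template("Given a web search query, retrieve relevant passages", PromptType.query))
# -> "Instruction: Given a web search query, retrieve relevant passages \nQuery: "

# Passages are embedded with no instruction prefix at all.
print(instruction_template("Given a web search query, retrieve relevant passages", PromptType.passage))
# -> ""
```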
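A quick end-to-end sketch of evaluating the re-registered model through mteb's standard entry points (`mteb.get_model`, `mteb.get_tasks`, `mteb.MTEB`); the task choice and output folder are arbitrary placeholders, and exact behaviour depends on the merged `ModelMeta` above (including the `trust_remote_code=True` the loader passes):

```python
import mteb

# Resolves the ModelMeta registered above by its `name` field.
model = mteb.get_model("tencent/Youtu-Embedding")

# Any C-MTEB task works here; T2Reranking is only an example.
tasks = mteb.get_tasks(tasks=["T2Reranking"])
evaluation = mteb.MTEB(tasks=tasks)

# Writes per-task JSON results under the given folder.
results = evaluation.run(model, output_folder="results/youtu-embedding")
```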