diff --git a/mteb/models/youtu_models.py b/mteb/models/youtu_models.py index 19127e853b..c092af8632 100644 --- a/mteb/models/youtu_models.py +++ b/mteb/models/youtu_models.py @@ -1,23 +1,14 @@ from __future__ import annotations -import json import logging -import os -import time from functools import partial -from typing import Any - -import numpy as np -import tqdm from mteb.encoder_interface import PromptType from mteb.model_meta import ModelMeta -from mteb.models.wrapper import Wrapper -from mteb.requires_package import requires_package +from mteb.models.instruct_wrapper import InstructSentenceTransformerWrapper logger = logging.getLogger(__name__) - youtu_instruction = { "CmedqaRetrieval": { "query": "Given a Chinese community medical question, retrieve replies that best answer the question", @@ -82,13 +73,26 @@ "OnlineShopping": "Classify the customer review for online shopping into positive or negative", "JDReview": "Classify the customer review for iPhone on e-commerce platform into positive or negative", "MultilingualSentiment": "Classify sentiment of the customer review into positive, neutral, or negative", - "CLSClusteringS2S": "Identify the main category of scholar papers based on the titles", - "CLSClusteringP2P": "Identify the main category of scholar papers based on the titles and abstracts", + "CLSClusteringS2S": "Given the papers, find the appropriate fine-grained topic or theme based on the titles", + "CLSClusteringP2P": "Given the papers, find the appropriate fine-grained topic or theme based on the titles and abstracts", "ThuNewsClusteringS2S": "Identify the topic or theme of the given news articles based on the titles", "ThuNewsClusteringP2P": "Identify the topic or theme of the given news articles based on the titles and contents", } +def instruction_template( + instruction: str, prompt_type: PromptType | None = None +) -> str: + if not instruction or prompt_type == PromptType.passage: + return "" + if isinstance(instruction, dict): + if prompt_type is None: + instruction = list(instruction.values())[0] + else: + instruction = instruction[prompt_type] + return f"Instruction: {instruction} \nQuery: " + + training_data = { "T2Retrieval": ["train"], "DuRetrieval": ["train"], @@ -107,131 +111,30 @@ } -class YoutuEmbeddingWrapper(Wrapper): - def __init__( - self, - model_name: str, - **kwargs, - ) -> None: - requires_package( - self, - "tencentcloud.common", - "tencentcloud.lkeap", - "pip install mteb[youtu]", - ) - - from tencentcloud.common import credential - from tencentcloud.common.profile.client_profile import ClientProfile - from tencentcloud.common.profile.http_profile import HttpProfile - from tencentcloud.lkeap.v20240522 import lkeap_client, models - - secret_id = os.getenv("TENCENTCLOUD_SECRET_ID") - secret_key = os.getenv("TENCENTCLOUD_SECRET_KEY") - if not secret_id or not secret_key: - raise ValueError( - "TENCENTCLOUD_SECRET_ID and TENCENTCLOUD_SECRET_KEY environment variables must be set" - ) - cred = credential.Credential(secret_id, secret_key) - - httpProfile = HttpProfile() - httpProfile.endpoint = "lkeap.test.tencentcloudapi.com" - - clientProfile = ClientProfile() - clientProfile.httpProfile = httpProfile - - self.client = lkeap_client.LkeapClient(cred, "ap-guangzhou", clientProfile) - self.models = models - self.model_name = model_name - - def _encode( - self, inputs: list[str], task_name: str, prompt_type: PromptType | None = None - ): - default_instruction = ( - "Given a search query, retrieve web passages that answer the question" - ) - if prompt_type == PromptType("query") or prompt_type is None: - instruction = youtu_instruction.get(task_name, default_instruction) - if isinstance(instruction, dict): - instruction = instruction[prompt_type] - instruction = f"Instruction: {instruction} \nQuery: " - elif prompt_type == PromptType.document: - instruction = "" - - params = { - "Model": self.model_name, - "Inputs": inputs, - "Instruction": instruction, - } - - req = self.models.GetEmbeddingRequest() - req.from_json_string(json.dumps(params)) - - resp = self.client.GetEmbedding(req) - resp = json.loads(resp.to_json_string()) - outputs = [item["Embedding"] for item in resp["Data"]] - - return outputs - - def encode( - self, - sentences: list[str], - task_name: str, - prompt_type: PromptType | None = None, - retries: int = 5, - **kwargs: Any, - ) -> np.ndarray: - max_batch_size = kwargs.get("batch_size", 32) - sublists = [ - sentences[i : i + max_batch_size] - for i in range(0, len(sentences), max_batch_size) - ] - - show_progress_bar = ( - False - if "show_progress_bar" not in kwargs - else kwargs.pop("show_progress_bar") - ) - - all_embeddings = [] - - for i, sublist in enumerate( - tqdm.tqdm(sublists, leave=False, disable=not show_progress_bar) - ): - embedding = self._encode(sublist, task_name, prompt_type) - all_embeddings.extend(embedding) - while retries > 0: - try: - embedding = self._encode(sublist, task_name, prompt_type) - all_embeddings.extend(embedding) - break - except Exception as e: - # Sleep due to too many requests - time.sleep(1) - logger.warning( - f"Retrying... {retries} retries left. Error: {str(e)}" - ) - retries -= 1 - if retries == 0: - raise e - - return np.array(all_embeddings) - - Youtu_Embedding_V1 = ModelMeta( - name="Youtu-RAG/Youtu-Embedding-V1", + loader=partial( + InstructSentenceTransformerWrapper, + "tencent/Youtu-Embedding", + revision="32e04afc24817c187a8422e7bdbb493b19796d47", + instruction_template=instruction_template, + apply_instruction_to_passages=False, + prompts_dict=youtu_instruction, + trust_remote_code=True, + max_seq_length=8192, + ), + name="tencent/Youtu-Embedding", languages=["zho-Hans"], - revision="1", - release_date="2025-09-02", - loader=partial(YoutuEmbeddingWrapper, model_name="youtu-embedding-v1"), - open_weights=False, - n_parameters=None, + revision="32e04afc24817c187a8422e7bdbb493b19796d47", + release_date="2025-09-28", + open_weights=True, + n_parameters=2672957440, memory_usage_mb=None, - embed_dim=2304, - license=None, - max_tokens=4096, - reference="https://youtu-embedding-v1.github.io/", + embed_dim=2048, + license="apache-2.0", + max_tokens=8192, + reference="https://huggingface.co/tencent/Youtu-Embedding", similarity_fn_name="cosine", - framework=["API"], + framework=["Sentence Transformers", "PyTorch"], use_instructions=True, public_training_code=None, public_training_data=None,