Integrate huggingface_hub inference support #651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits: ddccc66, 6bd7af7, 0a84f04, a645f3f, d0e4a54, 9aa02c8, 3f5ad7b, d9a47ce, 2438479, 0ebde9c, 85f1dd1
@@ -21,13 +21,17 @@
 # SOFTWARE.

 import asyncio
 import logging
 import time
 from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, Literal
+from typing import Callable, Literal, Optional

+from huggingface_hub import AsyncInferenceClient, InferenceTimeoutError
 from pydantic import BaseModel
+from requests.exceptions import HTTPError
 from tqdm import tqdm
+from tqdm.asyncio import tqdm_asyncio

 from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
 from lighteval.utils.utils import as_list
@@ -78,10 +82,28 @@ def __init__(
         model: str,
         templates: Callable,
         process_judge_response: Callable,
-        judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"],
+        judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm", "inference-providers"],
         url: str | None = None,
         api_key: str | None = None,
+        max_tokens: int = 512,
         response_format: BaseModel = None,
+        hf_provider: Optional[
+            Literal[
+                "black-forest-labs",
+                "cerebras",
+                "cohere",
+                "fal-ai",
+                "fireworks-ai",
+                "inference-providers",
+                "hyperbolic",
+                "nebius",
+                "novita",
+                "openai",
+                "replicate",
+                "sambanova",
+                "together",
+            ]
+        ] = None,
     ):
         self.model = model
         self.template = templates
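For reviewers, a minimal usage sketch of the new constructor arguments. The JudgeLM class name, the import path, and the two helper callables are assumptions made for illustration; only the parameter names and the inference-providers/hf_provider semantics come from this diff.

    from lighteval.metrics.llm_as_judge import JudgeLM  # assumed import path, not shown in this diff


    def my_prompt_template(question, answer, **kwargs):
        # Hypothetical template callable: builds the chat messages sent to the judge.
        return [{"role": "user", "content": f"Question: {question}\nAnswer: {answer}\nScore from 1 to 10:"}]


    def my_response_parser(response):
        # Hypothetical parser callable: turns the judge's raw text into a score.
        return response.strip()


    judge = JudgeLM(
        model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder judge model
        templates=my_prompt_template,
        process_judge_response=my_response_parser,
        judge_backend="inference-providers",  # new backend added in this PR
        hf_provider="together",  # required whenever judge_backend="inference-providers"
        api_key="hf_xxx",  # HF token, forwarded to AsyncInferenceClient
        max_tokens=512,  # new parameter controlling the judge's generation length
    )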
@@ -96,33 +118,47 @@ def __init__(
         self.url = url
         self.api_key = api_key
         self.backend = judge_backend
+        self.hf_provider = hf_provider
+        self.max_tokens = max_tokens

         self.response_format = response_format if response_format is not None else DEFAULT_FORMAT

+        # Validate that hf_provider is specified when using inference-providers backend
+        if self.backend == "inference-providers" and self.hf_provider is None:
+            raise ValueError("When using 'inference-providers' as backend, you must specify an 'hf_provider'")

     def __lazy_load_client(self):
         match self.backend:
-            # Wether we use openai or TGI models, we go through the openai API
-            # to route to the endpoint
-            case "openai" | "tgi" if is_openai_available():
+            # Both "openai" and "tgi" backends use the OpenAI-compatible API
+            # They are handled separately to allow for backend-specific validation and setup
+            case "openai" | "tgi":
+                if not is_openai_available():
+                    raise RuntimeError("OpenAI backend is not available.")
                 if self.client is None:
                     from openai import OpenAI

-                    if self.url is None:
-                        self.client = OpenAI(api_key=self.api_key)
-                    else:
-                        self.client = OpenAI(base_url=self.url, api_key=self.api_key)
+                    self.client = OpenAI(
+                        api_key=self.api_key if self.url is None else None, base_url=self.url if self.url else None
+                    )
                 return self.__call_api_parallel
-            case "litellm" if is_litellm_available():
+
+            case "litellm":
+                if not is_litellm_available():
+                    raise RuntimeError("litellm is not available.")
                 return self.__call_litellm
-            case "vllm" if is_vllm_available():
+
+            case "vllm":
+                if not is_vllm_available():
+                    raise RuntimeError("vllm is not available.")
                 if self.pipe is None:
                     from vllm import LLM, SamplingParams
                     from vllm.transformers_utils.tokenizer import get_tokenizer

-                    self.sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512)
+                    self.sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=self.max_tokens)
                     self.tokenizer = get_tokenizer(self.model, tokenizer_mode="auto")
                     self.pipe = LLM(model=self.model, max_model_len=2048, gpu_memory_utilization=0.5, dtype="float16")
                 return self.__call_vllm

             case "transformers":
                 if self.pipe is None:
                     import torch
@@ -136,11 +172,18 @@ def __lazy_load_client(self):
                         "text-generation",
                         model=transformers_model,
                         tokenizer=tokenizer,
-                        max_new_tokens=256,
+                        max_new_tokens=self.max_tokens,
                     )
                 return self.__call_transformers

+            case "inference-providers":
+                from huggingface_hub import AsyncInferenceClient
+
+                self.client = AsyncInferenceClient(token=self.api_key, base_url=self.url, provider=self.hf_provider)
+                return self.__call_hf_inference_async
+
             case _:
-                return lambda x: x
+                raise ValueError(f"Unsupported backend: {self.backend}")
Review comment: @NathanHB were you using this case for some specific use cases?

Reply: Oh it was because I wanted the details to be saved even with a failure in the judge so that we don't have to rerun everything.

     def dict_of_lists_to_list_of_dicts(self, dict_of_lists):
         """
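For background, a minimal standalone sketch of the huggingface_hub API the new inference-providers backend builds on. The provider, model, token, and prompt below are placeholders, not values from the PR.

    import asyncio

    from huggingface_hub import AsyncInferenceClient


    async def main():
        # Route a chat completion through a specific inference provider,
        # mirroring how the new backend instantiates its client.
        client = AsyncInferenceClient(provider="together", token="hf_xxx")  # placeholder provider and token
        response = await client.chat_completion(
            model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder judge model
            messages=[{"role": "user", "content": "Rate this answer from 1 to 10: ..."}],
            max_tokens=512,
        )
        print(response.choices[0].message.content)


    asyncio.run(main())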
@@ -287,6 +330,44 @@ def __call_api(prompt):

         return results

+    def __call_hf_inference_async(self, prompts):
+        async def run_all() -> list[str]:
+            """Wrap inference call into function"""
+            tasks = (self.__call_hf_inference(prompt) for prompt in prompts)
+            return await tqdm_asyncio.gather(*tasks, desc="HF inference", total=len(prompts))
+
+        try:
+            loop = asyncio.get_running_loop()
+            logger.debug("Existing event loop found, using loop.run_until_complete")
+            result = loop.run_until_complete(run_all())
+        except RuntimeError:
+            logger.debug("No running event loop found, using asyncio.run")
+            result = asyncio.run(run_all())
+
+        if None in result:
+            logger.warning("None found in inference results")
+
+        return result
+
+    async def __call_hf_inference(self, prompt):
+        self.client: AsyncInferenceClient
+        for _ in range(self.API_MAX_RETRY):
+            try:
+                response = await self.client.chat_completion(
+                    model=self.model,
+                    messages=prompt,
+                    max_tokens=self.max_tokens,
+                )
+                return response.choices[0].message.content
+            except (InferenceTimeoutError, HTTPError) as e:
+                logger.warning(f"HTTP error during HF inference: {e}")
+                await asyncio.sleep(self.API_RETRY_SLEEP)
+            except Exception as e:
+                logger.warning(f"Unexpected error during HF inference: {e}")
+                await asyncio.sleep(self.API_RETRY_SLEEP)
+
+        raise Exception("Failed to get response from the HF API")
+
     def __call_api_parallel(self, prompts):
         results = []
         with ThreadPoolExecutor(10) as executor:
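The retry-and-gather pattern used by __call_hf_inference_async and __call_hf_inference can be illustrated with a small self-contained sketch; the coroutine names and constants below are invented for illustration and are not part of the PR.

    import asyncio
    import random

    MAX_RETRY = 3  # illustrative stand-in for API_MAX_RETRY
    RETRY_SLEEP = 1.0  # illustrative stand-in for API_RETRY_SLEEP


    async def flaky_call(prompt: str) -> str:
        # Hypothetical remote call that fails roughly half the time.
        if random.random() < 0.5:
            raise RuntimeError("transient failure")
        return f"judged: {prompt}"


    async def call_with_retry(prompt: str) -> str:
        # Same shape as __call_hf_inference: retry a fixed number of times,
        # sleeping between attempts, then give up with an exception.
        for _ in range(MAX_RETRY):
            try:
                return await flaky_call(prompt)
            except RuntimeError as err:
                print(f"retrying after error: {err}")
                await asyncio.sleep(RETRY_SLEEP)
        raise RuntimeError("Failed to get a response")


    async def run_all(prompts: list[str]) -> list[str]:
        # Same shape as __call_hf_inference_async: one task per prompt,
        # gathered so results come back in the original order.
        return await asyncio.gather(*(call_with_retry(p) for p in prompts))


    print(asyncio.run(run_all(["prompt 1", "prompt 2"])))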