Merged
109 changes: 95 additions & 14 deletions src/lighteval/metrics/llm_as_judge.py
@@ -21,13 +21,17 @@
# SOFTWARE.


import asyncio
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Literal
from typing import Callable, Literal, Optional

from huggingface_hub import AsyncInferenceClient, InferenceTimeoutError
from pydantic import BaseModel
from requests.exceptions import HTTPError
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
from lighteval.utils.utils import as_list
@@ -78,10 +82,28 @@ def __init__(
model: str,
templates: Callable,
process_judge_response: Callable,
judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"],
judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm", "inference-providers"],
url: str | None = None,
api_key: str | None = None,
max_tokens: int = 512,
response_format: BaseModel = None,
hf_provider: Optional[
Literal[
"black-forest-labs",
"cerebras",
"cohere",
"fal-ai",
"fireworks-ai",
"inference-providers",
"hyperbolic",
"nebius",
"novita",
"openai",
"replicate",
"sambanova",
"together",
]
] = None,
):
self.model = model
self.template = templates
Expand All @@ -96,33 +118,47 @@ def __init__(
self.url = url
self.api_key = api_key
self.backend = judge_backend
self.hf_provider = hf_provider
self.max_tokens = max_tokens

self.response_format = response_format if response_format is not None else DEFAULT_FORMAT

# Validate that hf_provider is specified when using inference-providers backend
if self.backend == "inference-providers" and self.hf_provider is None:
raise ValueError("When using 'inference-providers' as backend, you must specify an 'hf_provider'")

def __lazy_load_client(self):
match self.backend:
# Whether we use openai or TGI models, we go through the openai API
# to route to the endpoint
case "openai" | "tgi" if is_openai_available():
# Both "openai" and "tgi" backends use the OpenAI-compatible API
# They are handled separately to allow for backend-specific validation and setup
case "openai" | "tgi":
if not is_openai_available():
raise RuntimeError("OpenAI backend is not available.")
if self.client is None:
from openai import OpenAI

if self.url is None:
self.client = OpenAI(api_key=self.api_key)
else:
self.client = OpenAI(base_url=self.url, api_key=self.api_key)
self.client = OpenAI(
api_key=self.api_key if self.url is None else None, base_url=self.url if self.url else None
)
return self.__call_api_parallel
case "litellm" if is_litellm_available():

case "litellm":
if not is_litellm_available():
raise RuntimeError("litellm is not available.")
return self.__call_litellm
case "vllm" if is_vllm_available():

case "vllm":
if not is_vllm_available():
raise RuntimeError("vllm is not available.")
if self.pipe is None:
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

self.sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=512)
self.sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=self.max_tokens)
self.tokenizer = get_tokenizer(self.model, tokenizer_mode="auto")
self.pipe = LLM(model=self.model, max_model_len=2048, gpu_memory_utilization=0.5, dtype="float16")
return self.__call_vllm

case "transformers":
if self.pipe is None:
import torch
Expand All @@ -136,11 +172,18 @@ def __lazy_load_client(self):
"text-generation",
model=transformers_model,
tokenizer=tokenizer,
max_new_tokens=256,
max_new_tokens=self.max_tokens,
)
return self.__call_transformers

case "inference-providers":
from huggingface_hub import AsyncInferenceClient

self.client = AsyncInferenceClient(token=self.api_key, base_url=self.url, provider=self.hf_provider)
return self.__call_hf_inference_async

case _:
return lambda x: x
raise ValueError(f"Unsupported backend: {self.backend}")
Reviewer comment: @NathanHB were you using this case for some specific use cases?

Reply: Oh, it was because I wanted the details to be saved even when the judge fails, so that we don't have to rerun everything. I forgot to add a warning, that's my bad :/
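
As context for the discussion above, here is a minimal sketch of what the removed permissive fallback could look like if it were kept with an explicit warning instead of the new hard `ValueError`. `resolve_judge_call` and `real_loader` are illustrative placeholders, not lighteval API:

```python
import logging

logger = logging.getLogger(__name__)


def resolve_judge_call(backend: str, real_loader):
    """Sketch only: keep the old permissive fallback, but surface the problem loudly."""
    known = {"openai", "tgi", "litellm", "vllm", "transformers", "inference-providers"}
    if backend in known:
        return real_loader(backend)
    # Keep the evaluation (and its saved details) running instead of failing,
    # but warn so the misconfiguration is visible in the logs.
    logger.warning(
        "Unsupported judge backend %r; judge responses will be passed through unchanged.", backend
    )
    return lambda prompts: prompts
```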


def dict_of_lists_to_list_of_dicts(self, dict_of_lists):
"""
@@ -287,6 +330,44 @@ def __call_api(prompt):

return results

def __call_hf_inference_async(self, prompts):
async def run_all() -> list[str]:
"""Wrap inference call into function"""
tasks = (self.__call_hf_inference(prompt) for prompt in prompts)
return await tqdm_asyncio.gather(*tasks, desc="HF inference", total=len(prompts))

try:
loop = asyncio.get_running_loop()
logger.debug("Exting event loop is found, using loop.create_task")
result = loop.run_until_complete(run_all())
except RuntimeError:
logger.debug("No running event loop found, using asyncio.run")
result = asyncio.run(run_all())

if None in result:
logger.warning("None found in inference results")

return result

async def __call_hf_inference(self, prompt):
self.client: AsyncInferenceClient
for _ in range(self.API_MAX_RETRY):
try:
response = await self.client.chat_completion(
model=self.model,
messages=prompt,
max_tokens=self.max_tokens,
)
return response.choices[0].message.content
except (InferenceTimeoutError, HTTPError) as e:
logger.warning(f"HTTP error during HF inference: {e}")
await asyncio.sleep(self.API_RETRY_SLEEP)
except Exception as e:
logger.warning(f"Unexpected error during HF inference: {e}")
await asyncio.sleep(self.API_RETRY_SLEEP)

raise Exception("Failed to get response from the HF API")

def __call_api_parallel(self, prompts):
results = []
with ThreadPoolExecutor(10) as executor:
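Outside the diff, the new async judging path can be exercised on its own. The sketch below mirrors the fan-out and fixed-sleep retry pattern added above; it assumes huggingface_hub >= 0.28 (for the `provider` argument), an `HF_TOKEN` in the environment, and placeholder values for the retry constants, model, and provider:

```python
import asyncio
import logging

from huggingface_hub import AsyncInferenceClient, InferenceTimeoutError
from requests.exceptions import HTTPError
from tqdm.asyncio import tqdm_asyncio

logger = logging.getLogger(__name__)

API_MAX_RETRY = 3    # placeholder values; lighteval keeps its own constants on JudgeLM
API_RETRY_SLEEP = 10


async def judge_one(client: AsyncInferenceClient, model: str, prompt: list[dict], max_tokens: int = 512) -> str:
    """One judged completion, using the same fixed-sleep retry policy as the PR."""
    for _ in range(API_MAX_RETRY):
        try:
            response = await client.chat_completion(model=model, messages=prompt, max_tokens=max_tokens)
            return response.choices[0].message.content
        except (InferenceTimeoutError, HTTPError) as e:
            logger.warning(f"HTTP error during HF inference: {e}")
            await asyncio.sleep(API_RETRY_SLEEP)
    raise RuntimeError("Failed to get a response from the HF API")


async def judge_all(prompts: list[list[dict]], model: str, provider: str) -> list[str]:
    # Reads HF_TOKEN from the environment; pass token=... explicitly if preferred.
    client = AsyncInferenceClient(provider=provider)
    tasks = (judge_one(client, model, p) for p in prompts)
    return await tqdm_asyncio.gather(*tasks, desc="HF inference", total=len(prompts))


if __name__ == "__main__":
    messages = [[{"role": "user", "content": "Rate this answer from 1 to 10: 'Paris is the capital of France.'"}]]
    print(asyncio.run(judge_all(messages, model="meta-llama/Llama-3.3-70B-Instruct", provider="together")))
```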
40 changes: 28 additions & 12 deletions src/lighteval/metrics/metrics_sample.py
@@ -872,30 +872,44 @@ def __init__(
judge_model_name: str,
template: Callable,
process_judge_response: Callable,
judge_backend: Literal["litellm", "openai", "transformers", "vllm", "tgi"],
judge_backend: Literal["litellm", "openai", "transformers", "vllm", "tgi", "inference-providers"],
short_judge_name: str | None = None,
response_format: BaseModel = None,
url: str | None = None,
hf_provider: str | None = None,
max_tokens: int | None = None,
) -> None:
logger.debug(f"Initializing JudgeLLM with backend: {judge_backend}, model: {judge_model_name}")

api_key = None

match judge_backend:
case "openai":
if judge_model_name not in self.available_models_openai:
raise ValueError(f"{judge_model_name} not in available models for llm as a judge metric")
else:
api_key = os.getenv("OPENAI_API_KEY")
url = None
api_key = os.getenv("OPENAI_API_KEY")
logger.debug("Using OpenAI backend for llm as a judge metric")

case "tgi":
api_key = os.getenv("HF_TOKEN")
url = "https://api-inference.huggingface.co/v1/"
if url is None:
url = "https://api-inference.huggingface.co/v1/"
logger.debug("Using TGI backend")

case "inference-providers":
api_key = os.getenv("HF_TOKEN")
logger.debug("Using Hugging Face Inference backend")

case "litellm":
api_key = None
url = None
logger.debug("Using LiteLLM backend for llm as a judge metric")

case "transformers" | "vllm":
logger.debug("Checking availability of Transformers or VLLM model")
api = HfApi()
models = api.list_models(model_name=judge_model_name)
url = None
api_key = None
if not models:
raise ValueError(f"{judge_model_name} not in available models for llm as a judge metric")
raise ValueError(f"{judge_model_name} not found on Hugging Face Hub")

case _:
raise ValueError(f"{judge_backend} is not a valid backend for llm as a judge metric")

Expand All @@ -904,10 +918,12 @@ def __init__(
model=judge_model_name,
templates=template,
process_judge_response=process_judge_response,
api_key=api_key,
url=url,
judge_backend=judge_backend,
response_format=response_format,
api_key=api_key,
url=url,
hf_provider=hf_provider,
max_tokens=max_tokens,
)

def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
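Finally, a hedged sketch of how the updated JudgeLLM signature might be wired to the new backend. The import path mirrors the file touched in this diff, the template and parser are user-supplied placeholders, and the model/provider names are illustrative only:

```python
import os

from lighteval.metrics.metrics_sample import JudgeLLM  # path as touched in this diff


def judge_template(question: str, answer: str, **kwargs) -> list[dict]:
    """Placeholder prompt builder: returns chat messages for the judge."""
    return [{"role": "user", "content": f"Question: {question}\nAnswer: {answer}\nScore from 1 to 10:"}]


def parse_judge_response(response: str) -> float:
    """Placeholder parser: pull the first integer out of the judge's reply."""
    digits = [int(token) for token in response.split() if token.isdigit()]
    return float(digits[0]) if digits else 0.0


assert os.getenv("HF_TOKEN"), "HF_TOKEN must be set; JudgeLLM reads it via os.getenv('HF_TOKEN')"

judge_metric = JudgeLLM(
    judge_model_name="meta-llama/Llama-3.3-70B-Instruct",  # illustrative model
    template=judge_template,
    process_judge_response=parse_judge_response,
    judge_backend="inference-providers",
    hf_provider="together",  # illustrative provider from the Literal added above
    max_tokens=1024,
)
```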