diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py
index e8f18502f..c1b5a50fe 100644
--- a/src/lighteval/models/endpoint_model.py
+++ b/src/lighteval/models/endpoint_model.py
@@ -23,6 +23,7 @@
 import asyncio
 from typing import Coroutine, List, Optional, Union
 
+import torch
 from huggingface_hub import (
     AsyncInferenceClient,
     InferenceClient,
@@ -314,17 +315,22 @@ def loglikelihood(
         ):
             dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
 
-            for batch in tqdm(dataloader, desc="Loglikleihoods", position=1, leave=False, disable=self.disable_tqdm):
+            for batch in tqdm(dataloader, desc="Loglikelihoods", position=1, leave=False, disable=self.disable_tqdm):
                 if self.use_async:
                     responses = asyncio.run(self.__async_process_batch_logprob(batch))
                 else:
                     responses = self.__process_batch_logprob(batch)
-                for ix, response in enumerate(responses):
-                    len_choice = len(batch[ix].tokenized_continuation)
+                for cur_request, response in zip(batch, responses):
+                    cont_toks = torch.tensor(cur_request.tokenized_continuation)
+                    len_choice = len(cont_toks)
+
                     logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
+
+                    greedy_tokens = torch.tensor(logits).argmax(dim=-1)
+                    max_equal = (greedy_tokens == cont_toks).all().squeeze(0)
                     results.append(
                         LoglikelihoodReturn(
-                            result=sum(logits),
+                            result=(sum(logits), bool(max_equal)),
                             input_tokens=[t.id for t in response.details.prefill[:-len_choice]],
                             generated_tokens=[t.id for t in response.details.prefill[-len_choice:]],
                             truncated_tokens_count=-1,
@@ -355,13 +361,16 @@ def loglikelihood_rolling(
         ):
             dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
 
-            for batch in tqdm(dataloader, desc="Loglikleihoods", position=1, leave=False, disable=self.disable_tqdm):
+            for batch in tqdm(
+                dataloader, desc="Loglikelihoods, rolling", position=1, leave=False, disable=self.disable_tqdm
+            ):
                 if self.use_async:
                     responses = asyncio.run(self.__async_process_batch_logprob(batch, rolling=True))
                 else:
                     responses = self.__process_batch_logprob(batch, rolling=True)
                 for response in responses:
                     logits = [t.logprob for t in response.details.tokens[:-1]]
+
                     results.append(
                         LoglikelihoodReturn(
                             result=sum(logits),
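
For context on the `result=(sum(logits), bool(max_equal))` change: the return value now carries both the summed logprob of the continuation and a flag for whether the continuation matched the model's greedy decode. A minimal sketch of how a downstream multiple-choice metric might consume that tuple; `pick_answer` and the input shape are hypothetical illustrations, not part of this diff or the lighteval API:

```python
# Hypothetical consumer of the (sum_logprob, is_greedy) tuple introduced above.
# Only the tuple shape comes from the diff; everything else is illustrative.
from typing import List, Tuple


def pick_answer(choice_results: List[Tuple[float, bool]]) -> Tuple[int, bool]:
    """Given one (sum_logprob, is_greedy) pair per answer choice, return the
    index of the highest-logprob choice and that choice's greedy flag."""
    logprobs = [r[0] for r in choice_results]
    best = max(range(len(logprobs)), key=lambda i: logprobs[i])
    # The greedy flag of the selected choice can feed exact-match style metrics.
    return best, choice_results[best][1]


# Example: three answer choices; the second has the highest total logprob.
print(pick_answer([(-12.3, False), (-4.1, True), (-9.8, False)]))  # -> (1, True)
```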