Commit 0fdb9d7

fixed the loglikelihood models: they were not returning the boolean value (#119)
1 parent af36b5b commit 0fdb9d7
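
In effect, the result field of LoglikelihoodReturn moves from a bare log-probability sum to a (log-prob sum, greedy-match boolean) pair; both lines below are taken directly from the diff:

    result=sum(logits)                      # before: log-prob sum only
    result=(sum(logits), bool(max_equal))   # after: also whether the continuation matches the greedy tokens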

File tree

1 file changed: +14 -5 lines changed

src/lighteval/models/endpoint_model.py

Lines changed: 14 additions & 5 deletions
@@ -23,6 +23,7 @@
 import asyncio
 from typing import Coroutine, List, Optional, Union
 
+import torch
 from huggingface_hub import (
     AsyncInferenceClient,
     InferenceClient,
@@ -314,17 +315,22 @@ def loglikelihood(
     ):
         dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
 
-        for batch in tqdm(dataloader, desc="Loglikleihoods", position=1, leave=False, disable=self.disable_tqdm):
+        for batch in tqdm(dataloader, desc="Loglikelihoods", position=1, leave=False, disable=self.disable_tqdm):
             if self.use_async:
                 responses = asyncio.run(self.__async_process_batch_logprob(batch))
             else:
                 responses = self.__process_batch_logprob(batch)
-            for ix, response in enumerate(responses):
-                len_choice = len(batch[ix].tokenized_continuation)
+            for cur_request, response in zip(batch, responses):
+                cont_toks = torch.tensor(cur_request.tokenized_continuation)
+                len_choice = len(cont_toks)
+
                 logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
+
+                greedy_tokens = torch.tensor(logits).argmax(dim=-1)
+                max_equal = (greedy_tokens == cont_toks).all().squeeze(0)
                 results.append(
                     LoglikelihoodReturn(
-                        result=sum(logits),
+                        result=(sum(logits), bool(max_equal)),
                         input_tokens=[t.id for t in response.details.prefill[:-len_choice]],
                         generated_tokens=[t.id for t in response.details.prefill[-len_choice:]],
                         truncated_tokens_count=-1,
@@ -355,13 +361,16 @@ def loglikelihood_rolling(
     ):
         dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
 
-        for batch in tqdm(dataloader, desc="Loglikleihoods", position=1, leave=False, disable=self.disable_tqdm):
+        for batch in tqdm(
+            dataloader, desc="Loglikelihoods, rolling", position=1, leave=False, disable=self.disable_tqdm
+        ):
             if self.use_async:
                 responses = asyncio.run(self.__async_process_batch_logprob(batch, rolling=True))
             else:
                 responses = self.__process_batch_logprob(batch, rolling=True)
             for response in responses:
                 logits = [t.logprob for t in response.details.tokens[:-1]]
+
                 results.append(
                     LoglikelihoodReturn(
                         result=sum(logits),
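
For reference, the pair convention this change lines up with (as used by lm-eval-style harnesses) is (sum of continuation log-probs, whether the continuation equals the model's greedy decode). Below is a minimal sketch of that computation, assuming full per-position vocabulary logits are available; loglikelihood_with_greedy_flag, position_logits, and continuation_ids are illustrative names, not lighteval API:

    import torch

    def loglikelihood_with_greedy_flag(
        position_logits: torch.Tensor,   # (seq_len, vocab_size) scores per continuation position
        continuation_ids: torch.Tensor,  # (seq_len,) gold token ids of the continuation
    ) -> tuple[float, bool]:
        logprobs = position_logits.log_softmax(dim=-1)
        # Log-probability the model assigns to each gold continuation token.
        gold_logprobs = logprobs.gather(-1, continuation_ids.unsqueeze(-1)).squeeze(-1)
        # The continuation is "greedy" iff every gold token is also the argmax token.
        is_greedy = bool((logprobs.argmax(dim=-1) == continuation_ids).all())
        return float(gold_logprobs.sum()), is_greedy

The boolean feeds accuracy-style metrics downstream; the float remains the ranking score for multiple-choice comparisons.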
