|
23 | 23 | import asyncio
|
24 | 24 | from typing import Coroutine, List, Optional, Union
|
25 | 25 |
|
| 26 | +import torch |
26 | 27 | from huggingface_hub import (
|
27 | 28 | AsyncInferenceClient,
|
28 | 29 | InferenceClient,
|
@@ -314,17 +315,22 @@ def loglikelihood(
|
314 | 315 | ):
|
315 | 316 | dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
|
316 | 317 |
|
317 |
| - for batch in tqdm(dataloader, desc="Loglikleihoods", position=1, leave=False, disable=self.disable_tqdm): |
| 318 | + for batch in tqdm(dataloader, desc="Loglikelihoods", position=1, leave=False, disable=self.disable_tqdm): |
318 | 319 | if self.use_async:
|
319 | 320 | responses = asyncio.run(self.__async_process_batch_logprob(batch))
|
320 | 321 | else:
|
321 | 322 | responses = self.__process_batch_logprob(batch)
|
322 |
| - for ix, response in enumerate(responses): |
323 |
| - len_choice = len(batch[ix].tokenized_continuation) |
| 323 | + for cur_request, response in zip(batch, responses): |
| 324 | + cont_toks = torch.tensor(cur_request.tokenized_continuation) |
| 325 | + len_choice = len(cont_toks) |
| 326 | + |
324 | 327 | logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
|
| 328 | + |
| 329 | + greedy_tokens = torch.tensor(logits).argmax(dim=-1)  # NOTE(review): `logits` is a 1-D list of per-token logprobs, so argmax returns a position index, not a token id — comparing it to `cont_toks` (token ids) below cannot detect a greedy match; verify against full-vocab logits |
| 330 | + max_equal = (greedy_tokens == cont_toks).all().squeeze(0) |
325 | 331 | results.append(
|
326 | 332 | LoglikelihoodReturn(
|
327 |
| - result=sum(logits), |
| 333 | + result=(sum(logits), bool(max_equal)), |
328 | 334 | input_tokens=[t.id for t in response.details.prefill[:-len_choice]],
|
329 | 335 | generated_tokens=[t.id for t in response.details.prefill[-len_choice:]],
|
330 | 336 | truncated_tokens_count=-1,
|
@@ -355,13 +361,16 @@ def loglikelihood_rolling(
|
355 | 361 | ):
|
356 | 362 | dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch)
|
357 | 363 |
|
358 |
| - for batch in tqdm(dataloader, desc="Loglikleihoods", position=1, leave=False, disable=self.disable_tqdm): |
| 364 | + for batch in tqdm( |
| 365 | + dataloader, desc="Loglikelihoods, rolling", position=1, leave=False, disable=self.disable_tqdm |
| 366 | + ): |
359 | 367 | if self.use_async:
|
360 | 368 | responses = asyncio.run(self.__async_process_batch_logprob(batch, rolling=True))
|
361 | 369 | else:
|
362 | 370 | responses = self.__process_batch_logprob(batch, rolling=True)
|
363 | 371 | for response in responses:
|
364 | 372 | logits = [t.logprob for t in response.details.tokens[:-1]]
|
| 373 | + |
365 | 374 | results.append(
|
366 | 375 | LoglikelihoodReturn(
|
367 | 376 | result=sum(logits),
|
|
0 commit comments