Fix eval for .pte (pytorch#1053)
Co-authored-by: vmpuri <[email protected]>
vmpuri authored Aug 27, 2024
1 parent 9dc9eff commit 0922e65
Showing 1 changed file with 20 additions and 1 deletion.
eval.py (20 additions, 1 deletion)
@@ -86,6 +86,7 @@ def __init__(
         model_forward: Optional[Callable] = None,
         max_seq_length: Optional[int] = None,
         device="cpu",
+        is_pte_model: bool = False,
     ):
         super().__init__(device=device)
         self._model = model
@@ -98,6 +99,7 @@ def __init__(
         self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
         self.times = []
+        self.is_pte_model = is_pte_model
 
     @property
     def eot_token_id(self):
@@ -143,7 +145,21 @@ def _model_call(self, inps):
         )
         x = seq.index_select(0, input_pos).view(1, -1)
         with measure_time(message=None) as measure:
-            logits = self._model_forward(x, input_pos)
+            if (
+                self.is_pte_model
+            ):  # Sequential Prefill required for ExecuTorch (.pte) models since the prompt length can introduce dynamism
+                width = x.size(1)
+                assert input_pos.size(0) == width
+                logits = torch.zeros(1, width, self._model.config.vocab_size).to(
+                    x.device
+                )
+                for i in range(width):
+                    x_sliced, ip_sliced = x[:, i].view(1, -1), input_pos[i].view(-1)
+                    logits[0, i] = self._model_forward(
+                        x_sliced, ip_sliced
+                    )  # (x[:, i], input_pos[i])
+            else:
+                logits = self._model_forward(x, input_pos)
         self.times.append(measure.get_time())
         return logits

@@ -160,6 +176,7 @@ def eval(
     limit: Optional[int] = None,
     max_seq_length: Optional[int] = None,
     device: str = "cpu",
+    is_pte_model: bool = False,
 ) -> dict:
     """
     Evaluates a language model on a specified task using the lm-evaluation-harness library.
@@ -183,6 +200,7 @@ def eval(
         model_forward=model_forward,
         max_seq_length=max_seq_length,
         device=device,
+        is_pte_model=is_pte_model,
     )
 
     try:
@@ -257,6 +275,7 @@ def main(args) -> None:
             limit,
             max_seq_length,
             device=builder_args.device,
+            is_pte_model=builder_args.pte_path is not None,
         )
 
     times = torch.tensor(result["times"])
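
For context on the new branch: as the inline comment in the diff notes, a variable-length prompt would introduce dynamism that the exported .pte graph does not handle, so prefill is done one token at a time. Below is a minimal standalone sketch of that sequential-prefill pattern, not the committed code: forward_one_token is a hypothetical stand-in for self._model_forward, and VOCAB_SIZE is a made-up constant.

import torch

VOCAB_SIZE = 32  # made-up vocabulary size for this sketch


def forward_one_token(x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for self._model_forward on a .pte model: it only
    # ever sees a single token (shape (1, 1)) and its position (shape (1,)),
    # so the call never needs a dynamic sequence dimension.
    return torch.randn(VOCAB_SIZE)  # dummy logits for that one position


def sequential_prefill(prompt: torch.Tensor) -> torch.Tensor:
    # Mirror of the is_pte_model branch: feed the prompt one token at a time
    # and collect per-position logits into a preallocated buffer.
    width = prompt.size(1)
    input_pos = torch.arange(width)
    logits = torch.zeros(1, width, VOCAB_SIZE)
    for i in range(width):
        x_sliced = prompt[:, i].view(1, -1)  # shape (1, 1): one token
        ip_sliced = input_pos[i].view(-1)    # shape (1,): its position
        logits[0, i] = forward_one_token(x_sliced, ip_sliced)
    return logits


tokens = torch.randint(0, VOCAB_SIZE, (1, 5))  # batch of 1, prompt length 5
print(sequential_prefill(tokens).shape)        # torch.Size([1, 5, 32])

The trade-off is the same one the commit accepts: prefill costs one forward call per prompt token instead of a single batched call, which is why the change is gated on is_pte_model.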