From 54cf3897c37f36888d6672ef5439417422f464ba Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Fri, 28 Feb 2025 17:01:47 -0800
Subject: [PATCH] Add basic performance metrics to native llama runner

---
 examples/models/llama/runner/generation.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 3e9ceb34af5..4ba645ffd87 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import time
 from abc import ABC, abstractmethod
 from typing import List, Optional
 
@@ -97,6 +98,7 @@ def generate(  # noqa: C901
         pos_base: int = 0,
     ) -> List[int]:
         # Prefill
+        prefill_start = time.time()
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
@@ -105,11 +107,13 @@ def generate(  # noqa: C901
                 else None
             ),
         )
+        prefill_time = time.time() - prefill_start
 
         current_token = next_token(logits, temperature, top_p)
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]
 
+        generate_start = time.time()
         while len(tokens) < max_seq_len:
             if self.use_kv_cache:
                 logits = self.forward(
@@ -140,6 +144,10 @@ def generate(  # noqa: C901
             print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
 
         print("\n")
+        generate_time = time.time() - generate_start
+        print(f"Prefill time: {prefill_time}")
+        print(f"Generation tok/s: {len(tokens) / generate_time}")
+
        return tokens if echo else tokens[len(prompt_tokens) :]
 
     def text_completion(
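
A note on the metrics as written: generate_time covers only the decode loop, but len(tokens) also counts the prompt tokens and the token sampled during prefill, so the printed tok/s overstates decode throughput for long prompts; prefill_time is also printed as a raw float in seconds. Below is a minimal standalone sketch of a stricter variant, not part of this patch: timed_decode and its step callback are hypothetical names, and it uses time.perf_counter() (monotonic, higher resolution than time.time()) while dividing by the number of tokens actually produced in the loop.

import time
from typing import Callable, List

def timed_decode(
    prompt_tokens: List[int],
    step: Callable[[List[int]], int],  # one forward + sample; returns next token id
    max_seq_len: int,
    stop_token: int,
) -> List[int]:
    # Prefill: first forward over the whole prompt, plus the first sampled token.
    prefill_start = time.perf_counter()
    tokens = prompt_tokens + [step(prompt_tokens)]
    prefill_time = time.perf_counter() - prefill_start

    # Decode loop: count only the tokens generated here.
    decode_start = time.perf_counter()
    decoded = 0
    while len(tokens) < max_seq_len and tokens[-1] != stop_token:
        tokens.append(step(tokens))
        decoded += 1
    decode_time = time.perf_counter() - decode_start

    print(f"Prefill time: {prefill_time:.3f} s")
    if decoded > 0 and decode_time > 0:
        print(f"Decode tok/s: {decoded / decode_time:.2f}")
    return tokens

A dummy step such as lambda toks: toks[-1] + 1 is enough to exercise the timing, e.g. timed_decode([1, 2, 3], step=lambda toks: toks[-1] + 1, max_seq_len=16, stop_token=-1).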