From 795b662cffe79fa0fa9a3f13a65113abdb4f96a9 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 6 Sep 2024 20:18:16 -0700 Subject: [PATCH] Enable Random Prefix Caching in Serving Profiling Tool (benchmark_serving.py) (#8241) --- benchmarks/benchmark_serving.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index bdfa81be4208e..9ba3f649810b7 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -195,8 +195,16 @@ def sample_sonnet_requests( def sample_random_requests( - input_len: int, output_len: int, num_prompts: int, range_ratio: float, - tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]: + prefix_len: int, + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, +) -> List[Tuple[str, int, int]]: + prefix_token_ids = np.random.randint(0, + tokenizer.vocab_size, + size=prefix_len).tolist() input_lens = np.random.randint( int(input_len * range_ratio), @@ -211,10 +219,12 @@ def sample_random_requests( offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) input_requests = [] for i in range(num_prompts): - prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size + prompt = tokenizer.decode(prefix_token_ids + + [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) + input_requests.append( - (prompt, int(input_lens[i]), int(output_lens[i]))) + (prompt, int(prefix_len + input_lens[i]), int(output_lens[i]))) return input_requests @@ -567,6 +577,7 @@ def main(args: argparse.Namespace): elif args.dataset_name == "random": input_requests = sample_random_requests( + prefix_len=args.random_prefix_len, input_len=args.random_input_len, output_len=args.random_output_len, num_prompts=args.num_prompts, @@ -765,6 +776,14 @@ def main(args: argparse.Namespace): help="Range of sampled ratio of input/output length, " "used only for random sampling.", ) + parser.add_argument( + "--random-prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before random " + " context. The length range of context in a random " + " request is [random-prefix-len, " + " random-prefix-len + random-prefix-len * random-range-ratio).") parser.add_argument( "--request-rate", type=float,