diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
index 8ea8c890613a..5c7eab0b6b11 100644
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -299,6 +299,18 @@ def test_get_paddings():
     actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                           padding_gap)
     assert actual_paddings == expected_paddings
+    # Exponential padding.
+    max_token_size, padding_gap = 1024, 0
+    expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
+    assert actual_paddings == expected_paddings
+    # Exponential padding with max_token_size not a power of two.
+    max_token_size = 317
+    expected_paddings = [16, 32, 64, 128, 256, 512]
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
+    assert actual_paddings == expected_paddings


 def test_get_padded_token_len():
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 69251d8bbb31..6300f16c0b3f 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -1040,9 +1040,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int,

     if padding_gap == 0:
         logger.info("Using exponential token paddings:")
-        while num <= max_token_size:
+        while True:
             logger.info("    %d", num)
             paddings.append(num)
+            if num >= max_token_size:
+                break
             num *= 2
     else:
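
For illustration only, not part of the patch: a minimal, self-contained sketch of the exponential branch, with the logging and the `padding_gap != 0` branch stripped out. The helper name `exponential_paddings` is hypothetical. With the old condition `while num <= max_token_size`, a non-power-of-two `max_token_size` such as 317 yielded `[16, 32, 64, 128, 256]`, leaving no padding large enough to cover 317; checking the bound after appending guarantees the final padding covers `max_token_size`.

```python
# Simplified sketch of the fixed exponential-padding loop (hypothetical
# helper name; logging and the padding_gap != 0 branch are omitted).
def exponential_paddings(min_token_size: int, max_token_size: int) -> list[int]:
    """Double from min_token_size until a padding covers max_token_size."""
    paddings = []
    num = min_token_size
    while True:
        paddings.append(num)
        # Break *after* appending, so a max_token_size that is not a
        # power of two (e.g. 317) still gets a covering padding (512).
        if num >= max_token_size:
            break
        num *= 2
    return paddings

# Mirrors the expectations in the new test cases above.
assert exponential_paddings(16, 1024) == [16, 32, 64, 128, 256, 512, 1024]
assert exponential_paddings(16, 317) == [16, 32, 64, 128, 256, 512]
```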