diff --git a/.buildkite/features/Speculative_Decoding-_DFlash.yml b/.buildkite/features/Speculative_Decoding-_DFlash.yml new file mode 100644 index 0000000000..994ea98ee1 --- /dev/null +++ b/.buildkite/features/Speculative_Decoding-_DFlash.yml @@ -0,0 +1,68 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pipeline-name: Speculative Decoding: DFlash +# pipeline-type: feature support matrix +steps: + - label: "${TPU_VERSION:-tpu6e} Correctness tests for Speculative Decoding: DFlash" + key: "${TPU_VERSION:-tpu6e}_Speculative_Decoding-_DFlash_CorrectnessTest" + soft_fail: true + agents: + queue: "${TPU_QUEUE_SINGLE:-tpu_v6e_queue}" + env: + TPU_VERSION: "${TPU_VERSION:-tpu6e}" + commands: + - | + .buildkite/scripts/run_in_docker.sh \ + python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_speculative_decoding.py::test_dflash_correctness + - label: "${TPU_VERSION:-tpu6e} Record correctness test result for Speculative Decoding: DFlash" + key: "${TPU_VERSION:-tpu6e}_record_Speculative_Decoding-_DFlash_CorrectnessTest" + depends_on: "${TPU_VERSION:-tpu6e}_Speculative_Decoding-_DFlash_CorrectnessTest" + env: + CI_TPU_VERSION: "${TPU_VERSION:-tpu6e}" + CI_TARGET: "Speculative Decoding: DFlash" + CI_STAGE: "CorrectnessTest" + CI_CATEGORY: "feature support matrix" + agents: + queue: cpu + commands: + - | + .buildkite/scripts/record_step_result.sh ${TPU_VERSION:-tpu6e}_Speculative_Decoding-_DFlash_CorrectnessTest + + - label: 
"${TPU_VERSION:-tpu6e} Performance tests for Speculative Decoding: DFlash" + key: "${TPU_VERSION:-tpu6e}_Speculative_Decoding-_DFlash_PerformanceTest" + depends_on: "${TPU_VERSION:-tpu6e}_record_Speculative_Decoding-_DFlash_CorrectnessTest" + soft_fail: true + agents: + queue: "${TPU_QUEUE_SINGLE:-tpu_v6e_queue}" + env: + TPU_VERSION: "${TPU_VERSION:-tpu6e}" + commands: + - | + .buildkite/scripts/run_in_docker.sh \ + python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_speculative_decoding.py::test_dflash_performance + + - label: "${TPU_VERSION:-tpu6e} Record performance test result for Speculative Decoding: DFlash" + key: "${TPU_VERSION:-tpu6e}_record_Speculative_Decoding-_DFlash_PerformanceTest" + depends_on: "${TPU_VERSION:-tpu6e}_Speculative_Decoding-_DFlash_PerformanceTest" + env: + CI_TPU_VERSION: "${TPU_VERSION:-tpu6e}" + CI_TARGET: "Speculative Decoding: DFlash" + CI_STAGE: "PerformanceTest" + CI_CATEGORY: "feature support matrix" + agents: + queue: cpu + commands: + - | + .buildkite/scripts/record_step_result.sh ${TPU_VERSION:-tpu6e}_Speculative_Decoding-_DFlash_PerformanceTest diff --git a/tests/e2e/test_speculative_decoding.py b/tests/e2e/test_speculative_decoding.py index 1238f01187..2a2cb6c936 100644 --- a/tests/e2e/test_speculative_decoding.py +++ b/tests/e2e/test_speculative_decoding.py @@ -59,11 +59,24 @@ def get_eagle3_test_prompts(): return prompts +def get_dflash_test_prompts(): + num_prompts = 100 + prompts = [] + + for _ in range(num_prompts): + prompts.append( + "Predict the continuation of this sequence: 1 2 3 4 5 6 7 8") + + return prompts + + def get_test_prompts(speculative_config: dict): if speculative_config['method'] == 'ngram': return get_ngram_test_prompts() elif speculative_config['method'] == 'eagle3': return get_eagle3_test_prompts() + elif speculative_config['method'] == 'dflash': + return get_dflash_test_prompts() else: raise NotImplementedError( f"{speculative_config['method']} is not supported yet.") @@ -313,3 
+326,40 @@ def test_eagle3_performance( "num_speculative_tokens": 2, "draft_tensor_parallel_size": 1 }, 0.6 if _is_v7x() else 1.8) + + +def test_dflash_correctness( + monkeypatch: pytest.MonkeyPatch, + sampling_config: SamplingParams, +): + ''' + Verify that the outputs of the original LLM and the speculative LLM + are the same when using DFlash speculative decoding. + ''' + model_name = 'Qwen/Qwen3-4B' + + _test_correctness_helper( + monkeypatch, sampling_config, model_name, { + 'model': "z-lab/Qwen3-4B-DFlash-b16", + "num_speculative_tokens": 16, + "method": "dflash", + "draft_tensor_parallel_size": 1 + }) + + +def test_dflash_performance( + monkeypatch: pytest.MonkeyPatch, + sampling_config: SamplingParams, +): + ''' + Test that DFlash speculative decoding provides significant performance + improvement. Compares timing between reference LLM and speculative LLM + using Qwen3-4B. + ''' + _test_performance_helper( + monkeypatch, sampling_config, { + "method": "dflash", + "model": "z-lab/Qwen3-4B-DFlash-b16", + "num_speculative_tokens": 16, + "draft_tensor_parallel_size": 1 + }, 0.6 if _is_v7x() else 1.5)