diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml index a2b0499b78b..bb634ba5312 100644 --- a/.github/workflows/scripts/config.yaml +++ b/.github/workflows/scripts/config.yaml @@ -166,3 +166,5 @@ e2e-multicard-4-cards: estimated_time: 1391 - name: tests/e2e/multicard/4-cards/test_pipeline_parallel.py estimated_time: 679 +- name: tests/e2e/multicard/4-cards/test_profiling_chunk_performance.py + estimated_time: 1300 diff --git a/tests/e2e/multicard/4-cards/test_profiling_chunk_performance.py b/tests/e2e/multicard/4-cards/test_profiling_chunk_performance.py new file mode 100644 index 00000000000..c3c1f25c3d1 --- /dev/null +++ b/tests/e2e/multicard/4-cards/test_profiling_chunk_performance.py @@ -0,0 +1,82 @@ +"""Performance guard for profiling-based dynamic chunk sizing (PP scenario). + +Measures Time-To-First-Token (TTFT) on 64k-token prefill requests with +profiling_chunk_config enabled. The test runs against +DeepSeek-V2-Lite-Chat served with PP=2, TP=2 (4 NPU cards total). + +Test flow: + 1. Create an LLM engine with profiling_chunk_config enabled. + 2. Run NUM_WARMUP sequential requests (64k tokens, max_tokens=1) to warm + up both the NPU and the profiling predictor. + 3. Run NUM_TEST sequential requests, recording TTFT for each. + 4. Assert that the median TTFT does not exceed BASELINE_TTFT_S seconds. +""" + +import os +import statistics +import time + +from tests.e2e.conftest import VllmRunner + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" +os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1" + +MODEL = "Qwen/Qwen3-30B-A3B" + +# ~64k tokens +_WORD = "hello " +INPUT_64K_TOKENS = _WORD * (384_000 // len(_WORD)) + +NUM_WARMUP = 5 +NUM_TEST = 5 + +# NOTE: Any changes to this baseline must be approved by team members. +# Measured on Qwen3-30B-A3B, PP=2, TP=2, 64k prefill, profiling_chunk enabled. +BASELINE_TTFT_S = 5.2 + + +def test_profiling_chunk_ttft_performance() -> None: + with VllmRunner( + MODEL, + max_model_len=70000, + tensor_parallel_size=2, + pipeline_parallel_size=2, + block_size=128, + enable_expert_parallel=True, + enable_prefix_caching=False, + gpu_memory_utilization=0.9, + max_num_batched_tokens=12288, + distributed_executor_backend="mp", + enforce_eager=True, + async_scheduling=False, + additional_config={"profiling_chunk_config": {"enabled":True, "smooth_factor":0.9}, "enable_cpu_binding": False}, + hf_overrides={"rope_parameters": {"rope_type":"yarn","rope_theta":1000,"factor":5,"original_max_position_embeddings":262144}} + ) as vllm_model: + # With max_tokens=1, total latency ≈ prefill time ≈ TTFT + prompts = [INPUT_64K_TOKENS] + + # ── Warmup ────────────────────────────────────────────────────────── + for _ in range(NUM_WARMUP): + vllm_model.generate_greedy(prompts, max_tokens=1) + + # ── Measurement ───────────────────────────────────────────────────── + ttfts: list[float] = [] + for _ in range(NUM_TEST): + start = time.perf_counter() + vllm_model.generate_greedy(prompts, max_tokens=1) + ttfts.append(time.perf_counter() - start) + + median_ttft = statistics.median(ttfts) + ttft_str = ", ".join(f"{t:.2f}s" for t in ttfts) + print( + f"\n[profiling_chunk perf] TTFT per request: [{ttft_str}]" + f"\n[profiling_chunk perf] Median TTFT: {median_ttft:.2f}s " + f"(baseline: {BASELINE_TTFT_S}s)" + ) + + assert median_ttft <= BASELINE_TTFT_S, ( + f"TTFT performance regression: median TTFT {median_ttft:.2f}s " + f"exceeds baseline {BASELINE_TTFT_S}s. 
" + f"Individual TTFTs: [{ttft_str}]" + ) \ No newline at end of file diff --git a/tests/ut/core/test_profiling_chunk.py b/tests/ut/core/test_profiling_chunk.py new file mode 100644 index 00000000000..c36a7bc0224 --- /dev/null +++ b/tests/ut/core/test_profiling_chunk.py @@ -0,0 +1,435 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Any, Dict, Optional +from unittest.mock import MagicMock, patch + +import torch +from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VllmConfig) +from vllm.sampling_params import SamplingParams +from vllm.utils.hashing import sha256 +from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, + init_none_hash) +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + +from tests.ut.base import TestBase +from vllm_ascend.ascend_config import (ProfilingChunkConfig, + clear_ascend_config, init_ascend_config) +from vllm_ascend.core.profiling_chunk_predictor import (ChunkSizePredictor, + ProfilingChunkManager) +from vllm_ascend.core.scheduler_profiling_chunk import \ + ProfilingChunkScheduler + + +MODEL = "Qwen/Qwen3-0.6B" +BLOCK_SIZE = 16 +MAX_NUM_BATCHED_TOKENS = 8192 +MAX_NUM_SEQS = 16 + + +def create_requests(num_requests, num_tokens=10, max_tokens=16): + init_none_hash(sha256) + sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens) + requests = [] + for i in range(num_requests): + request = Request( + request_id=f"{i}", + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + pooling_params=None, + block_hasher=get_request_block_hasher(BLOCK_SIZE, sha256), + ) + requests.append(request) + return requests + + +def make_output(scheduler): + req_ids = [req.request_id for req in scheduler.running] + req_id_to_index = { + req.request_id: i + for i, req in enumerate(scheduler.running) + } + sampled_token_ids = [[1000]] * len(scheduler.running) + return ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + + +# =================================================================== +# ProfilingChunkConfig +# =================================================================== + + +class TestProfilingChunkConfig(TestBase): + + def test_default_values(self): + cfg = ProfilingChunkConfig() + self.assertFalse(cfg.enabled) + self.assertAlmostEqual(cfg.smooth_factor, 0.8) + self.assertEqual(cfg.min_chunk, 4096) + + def test_invalid_smooth_factor_raises(self): + with self.assertRaises(ValueError): + ProfilingChunkConfig({"smooth_factor": 0.0}) + with self.assertRaises(ValueError): + ProfilingChunkConfig({"smooth_factor": 1.5}) + + def 
test_invalid_min_chunk_raises(self): + with self.assertRaises(ValueError): + ProfilingChunkConfig({"min_chunk": 0}) + + @patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config") + def test_enabled_without_pp_raises(self, _mock): + clear_ascend_config() + vllm_config = VllmConfig() + vllm_config.additional_config = { + "profiling_chunk_config": {"enabled": True}, + "refresh": True, + } + vllm_config.parallel_config.pipeline_parallel_size = 1 + with self.assertRaises(ValueError) as ctx: + init_ascend_config(vllm_config) + self.assertIn("pipeline parallelism", str(ctx.exception)) + clear_ascend_config() + + @patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config") + def test_enabled_with_pp_ok(self, _mock): + clear_ascend_config() + vllm_config = VllmConfig() + vllm_config.additional_config = { + "profiling_chunk_config": {"enabled": True}, + "refresh": True, + } + vllm_config.parallel_config.pipeline_parallel_size = 2 + ascend_config = init_ascend_config(vllm_config) + self.assertTrue(ascend_config.profiling_chunk_config.enabled) + clear_ascend_config() + + @patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config") + def test_disabled_without_pp_ok(self, _mock): + clear_ascend_config() + vllm_config = VllmConfig() + vllm_config.additional_config = {"refresh": True} + ascend_config = init_ascend_config(vllm_config) + self.assertFalse(ascend_config.profiling_chunk_config.enabled) + clear_ascend_config() + + +# =================================================================== +# ChunkSizePredictor +# =================================================================== + + +class TestChunkSizePredictor(TestBase): + + @staticmethod + def _make_data(a, b, c, seq_lens): + return [a * l * l + b * l + c for l in seq_lens] + + def test_fit_and_predict(self): + predictor = ChunkSizePredictor() + seq_lens = list(range(64, 8256, 128)) + latencies = self._make_data(1e-6, 0.01, 1.0, seq_lens) + + self.assertTrue(predictor.fit(seq_lens, latencies)) + predictor.set_target_latency(8192) + predictor.is_ready = True + + chunk = predictor.predict( + num_computed_tokens=0, base_chunk_size=8192, page_size=128) + self.assertIsNotNone(chunk) + self.assertEqual(chunk % 128, 0) + + def test_predict_decreases_with_history(self): + predictor = ChunkSizePredictor() + seq_lens = list(range(64, 8256, 128)) + latencies = self._make_data(1e-6, 0.01, 1.0, seq_lens) + predictor.fit(seq_lens, latencies) + predictor.set_target_latency(8192) + predictor.is_ready = True + + c0 = predictor.predict(0, 8192, 128) + c1 = predictor.predict(4096, 8192, 128) + c2 = predictor.predict(16384, 8192, 128) + self.assertGreaterEqual(c0, c1) + self.assertGreaterEqual(c1, c2) + + def test_predict_not_ready_returns_none(self): + predictor = ChunkSizePredictor() + self.assertIsNone(predictor.predict(0, 8192, 128)) + + def test_fit_chunk_and_predict_with_history(self): + predictor = ChunkSizePredictor() + predictor.is_ready = True + predictor.target_latency = 50.0 + + data = [] + for i in range(10): + c, h = 1000 + i * 100, i * 500 + data.append([(c + h) * c, c + h, 1, 1e-9 * (c + h) * c + 0.001 * (c + h) + 0.5]) + self.assertTrue(predictor.fit_chunk(data)) + predictor.with_history_ready = True + + result = predictor.predict_with_history(1000, 8192, 128) + self.assertIsNotNone(result) + self.assertEqual(result % 128, 0) + + +# =================================================================== +# ProfilingChunkManager +# =================================================================== + + +class 
TestProfilingChunkManager(TestBase): + + def test_not_ready_before_profiling(self): + mgr = ProfilingChunkManager(base_chunk_size=8192, page_size=128) + self.assertFalse(mgr.is_ready) + self.assertIsNone(mgr.predict_chunk_size(0, 1.0)) + + def test_run_profiling_success(self): + mgr = ProfilingChunkManager(base_chunk_size=8192, page_size=128) + seq_lens = list(range(64, 8256, 128)) + latencies = [1e-6 * l * l + 0.01 * l + 1.0 for l in seq_lens] + self.assertTrue(mgr.predictor.fit(seq_lens, latencies)) + mgr.predictor.set_target_latency(8192) + mgr.predictor.is_ready = True + mgr._profiling_done = True + + self.assertTrue(mgr.is_ready) + self.assertIsNotNone(mgr.predict_chunk_size(0, 1.0)) + + def test_run_profiling_all_fail(self): + mgr = ProfilingChunkManager(base_chunk_size=8192, page_size=128) + too_few_seq_lens = [64, 128, 256] + too_few_latencies = [1.0, 2.0, 3.0] + self.assertFalse(mgr.predictor.fit(too_few_seq_lens, too_few_latencies)) + self.assertFalse(mgr.is_ready) + self.assertIsNone(mgr.predict_chunk_size(0, 1.0)) + + def test_record_batch_refines_model(self): + mgr = ProfilingChunkManager(base_chunk_size=8192, page_size=128) + seq_lens = list(range(64, 8256, 128)) + latencies = [1e-6 * l * l + 0.01 * l + 1.0 for l in seq_lens] + mgr.predictor.fit(seq_lens, latencies) + mgr.predictor.set_target_latency(8192) + mgr.predictor.is_ready = True + mgr._profiling_done = True + + for i in range(10): + mgr.record_batch_execution_time( + [(4096 - i * 100, i * 500)], 0.05 + i * 0.01) + self.assertGreaterEqual(len(mgr.chunked_fit_data), 10) + self.assertTrue(mgr.history_ready) + + +# =================================================================== +# ProfilingChunkScheduler +# =================================================================== + + +class TestProfilingChunkScheduler(TestBase): + + @patch("vllm_ascend.ascend_config.AscendConfig.__init__", MagicMock(return_value=None)) + @patch("vllm_ascend.ascend_config.get_ascend_config") + @patch("vllm.config.ModelConfig.__post_init__", MagicMock()) + @patch("vllm.config.VllmConfig.__post_init__", MagicMock()) + def create_scheduler(self, mock_get_ascend_config): + profiling_cfg = MagicMock() + profiling_cfg.enabled = True + profiling_cfg.smooth_factor = 0.8 + profiling_cfg.min_chunk = 256 + mock_get_ascend_config.return_value = MagicMock( + profiling_chunk_config=profiling_cfg) + + mock_hf_config = MagicMock() + mock_hf_config.model_type = "qwen3" + mock_hf_config.is_encoder_decoder = False + mock_hf_config.architectures = ["Qwen3ForCausalLM"] + model_config = ModelConfig( + model=MODEL, + tokenizer=MODEL, + trust_remote_code=True, + dtype="float16", + seed=42, + max_model_len=MAX_NUM_BATCHED_TOKENS, + ) + model_config.hf_config = mock_hf_config + model_config.hf_text_config = MagicMock() + model_config.hf_text_config.is_encoder_decoder = False + + scheduler_config = SchedulerConfig( + max_num_seqs=MAX_NUM_SEQS, + max_model_len=MAX_NUM_BATCHED_TOKENS, + long_prefill_token_threshold=0, + disable_chunked_mm_input=False, + enable_chunked_prefill=True, + max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, + is_encoder_decoder=False, + ) + scheduler_config.max_num_encoder_input_tokens = 10000 + scheduler_config.encoder_cache_size = 10000 + scheduler_config.chunked_prefill_enabled = True + + cache_config = CacheConfig( + block_size=BLOCK_SIZE, gpu_memory_utilization=0.9, + cache_dtype="auto", + ) + + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + ) + 
vllm_config.parallel_config.pipeline_parallel_size = 2 + from unittest.mock import PropertyMock + type(model_config).is_encoder_decoder = PropertyMock(return_value=False) + vllm_config.model_config.hf_config.is_encoder_decoder = False + + kv_cache_config = KVCacheConfig( + num_blocks=10000, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ['layer'], + FullAttentionSpec( + block_size=BLOCK_SIZE, + num_kv_heads=1, + head_size=1, + dtype=torch.float32 + ) + ) + ], + ) + kv_cache_config.hash_block_size = BLOCK_SIZE + cache_config.num_gpu_blocks = 10000 + + scheduler = ProfilingChunkScheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + block_size=BLOCK_SIZE, + log_stats=True, + structured_output_manager=MagicMock(spec=StructuredOutputManager), + ) + + should_advance = MagicMock() + should_advance.return_value = False + scheduler.structured_output_manager.should_advance = should_advance + + return scheduler + + def test_scheduler_init(self): + scheduler = self.create_scheduler() + self.assertIsNotNone(scheduler.profiling_chunk_manager) + self.assertFalse(scheduler._profiling_initialized) + + def test_run_profiling_chunk_init_success(self): + scheduler = self.create_scheduler() + mock_executor = MagicMock() + mock_executor.collective_rpc.return_value = [10.0] + + scheduler.run_profiling_chunk_init(mock_executor) + + self.assertTrue(scheduler._profiling_initialized) + self.assertTrue(scheduler.profiling_chunk_manager.is_ready) + + def test_run_profiling_chunk_init_skips_second_call(self): + scheduler = self.create_scheduler() + mock_executor = MagicMock() + mock_executor.collective_rpc.return_value = [10.0] + + scheduler.run_profiling_chunk_init(mock_executor) + call_count = mock_executor.collective_rpc.call_count + + scheduler.run_profiling_chunk_init(mock_executor) + self.assertEqual(mock_executor.collective_rpc.call_count, call_count) + + def test_run_profiling_chunk_init_none_executor(self): + scheduler = self.create_scheduler() + scheduler.run_profiling_chunk_init(None) + self.assertTrue(scheduler._profiling_initialized) + self.assertFalse(scheduler.profiling_chunk_manager.is_ready) + + def test_schedule_new_requests(self): + scheduler = self.create_scheduler() + requests = create_requests(num_requests=5) + for req in requests: + scheduler.add_request(req) + + output = scheduler.schedule() + self.assertEqual(len(output.scheduled_new_reqs), 5) + self.assertEqual(len(scheduler.waiting), 0) + self.assertEqual(len(scheduler.running), 5) + + def test_schedule_with_profiling_ready(self): + """After profiling is ready, schedule() should still work correctly.""" + scheduler = self.create_scheduler() + mock_executor = MagicMock() + mock_executor.collective_rpc.return_value = [10.0] + scheduler.run_profiling_chunk_init(mock_executor) + self.assertTrue(scheduler.profiling_chunk_manager.is_ready) + + requests = create_requests(num_requests=3, num_tokens=100) + for req in requests: + scheduler.add_request(req) + + output = scheduler.schedule() + self.assertGreater(len(output.scheduled_new_reqs), 0) + total = sum(output.num_scheduled_tokens.values()) + self.assertGreater(total, 0) + + def test_schedule_chunked_prefill_running(self): + """Running requests with num_computed_tokens > 0 get dynamic chunk.""" + scheduler = self.create_scheduler() + mock_executor = MagicMock() + mock_executor.collective_rpc.return_value = [10.0] + scheduler.run_profiling_chunk_init(mock_executor) + + requests = create_requests(num_requests=1, num_tokens=2000, + max_tokens=16) + for req in 
requests: + scheduler.add_request(req) + + output1 = scheduler.schedule() + self.assertEqual(len(output1.scheduled_new_reqs), 1) + + model_output = make_output(scheduler) + scheduler.update_from_output(output1, model_output) + + output2 = scheduler.schedule() + self.assertGreater(output2.total_num_scheduled_tokens, 0) + + def test_update_from_output(self): + scheduler = self.create_scheduler() + requests = create_requests(num_requests=3) + for req in requests: + scheduler.add_request(req) + + output = scheduler.schedule() + model_output = make_output(scheduler) + scheduler.update_from_output(output, model_output) + + self.assertEqual(len(scheduler.running), 3) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 0da3c0211a3..c40c739da78 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -50,6 +50,23 @@ def __init__(self, vllm_config: "VllmConfig"): weight_prefetch_config = additional_config.get("weight_prefetch_config", {}) self.weight_prefetch_config = WeightPrefetchConfig(weight_prefetch_config) + profiling_chunk_config = additional_config.get("profiling_chunk_config", {}) + self.profiling_chunk_config = ProfilingChunkConfig(profiling_chunk_config) + if self.profiling_chunk_config.enabled and vllm_config.parallel_config.pipeline_parallel_size <= 1: + raise ValueError( + "profiling_chunk_config requires pipeline parallelism (pp > 1). " + "Please set --pipeline-parallel-size to a value greater than 1, " + "or disable profiling_chunk_config." + ) + + from vllm_ascend import envs as ascend_envs + + if self.profiling_chunk_config.enabled and ascend_envs.VLLM_ASCEND_BALANCE_SCHEDULING: + raise ValueError( + "profiling_chunk_config and balance scheduling (VLLM_ASCEND_BALANCE_SCHEDULING) " + "cannot be enabled at the same time. Please disable one of them." + ) + # Dump / PrecisionDebugger configuration self.dump_config_path = additional_config.get("dump_config_path", None) self.layer_sharding = additional_config.get("layer_sharding", None) @@ -425,6 +442,36 @@ def __init__(self, weight_prefetch_config: dict): self.prefetch_ratio = weight_prefetch_config.get("prefetch_ratio", self.prefetch_ratio) +class ProfilingChunkConfig: + """Configuration for profiling-based dynamic chunk sizing. + + When enabled, the scheduler profiles prefill latency during initialization + and uses a quadratic model to predict optimal chunk sizes at runtime. + + Usage (online):: + + vllm serve --additional-config '{"profiling_chunk_config": {"enabled": true}}' + + Usage (offline):: + + llm = LLM(model, additional_config={"profiling_chunk_config": {"enabled": true}}) + """ + + def __init__(self, config: dict | None = None): + if config is None: + config = {} + self.enabled: bool = config.get("enabled", False) + self.smooth_factor: float = float(config.get("smooth_factor", 0.8)) + self.min_chunk: int = int(config.get("min_chunk", 4096)) + self._validate() + + def _validate(self): + if not (0 < self.smooth_factor <= 1.0): + raise ValueError(f"profiling_chunk_config.smooth_factor must be in (0, 1], got {self.smooth_factor}") + if self.min_chunk <= 0: + raise ValueError(f"profiling_chunk_config.min_chunk must be positive, got {self.min_chunk}") + + class EplbConfig: """ Configuration Object for xlite_graph_config from additional_config @@ -491,14 +538,31 @@ def _validate_config(self): _ASCEND_CONFIG: AscendConfig | None = None +def _is_ascend_config_initialized(config: AscendConfig | None) -> bool: + """Check whether a config object has essential initialized fields. 
+ + Some unit tests monkeypatch ``AscendConfig.__init__`` to bypass heavy + initialization. In that case, the singleton cache can be polluted with a + partially initialized instance. This guard prevents reusing such instances + across tests. + """ + if config is None: + return False + return hasattr(config, "ascend_compilation_config") and hasattr(config, "eplb_config") + + def init_ascend_config(vllm_config): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} refresh = additional_config.get("refresh", False) if additional_config else False global _ASCEND_CONFIG - if _ASCEND_CONFIG is not None and not refresh: + if _ASCEND_CONFIG is not None and not refresh and _is_ascend_config_initialized(_ASCEND_CONFIG): return _ASCEND_CONFIG - _ASCEND_CONFIG = AscendConfig(vllm_config) - return _ASCEND_CONFIG + new_config = AscendConfig(vllm_config) + if _is_ascend_config_initialized(new_config): + _ASCEND_CONFIG = new_config + else: + logger.warning("Ascend config instance is not fully initialized; skip singleton cache update.") + return new_config def clear_ascend_config(): @@ -508,6 +572,6 @@ def clear_ascend_config(): def get_ascend_config(): global _ASCEND_CONFIG - if _ASCEND_CONFIG is None: + if _ASCEND_CONFIG is None or not _is_ascend_config_initialized(_ASCEND_CONFIG): raise RuntimeError("Ascend config is not initialized. Please call init_ascend_config first.") return _ASCEND_CONFIG diff --git a/vllm_ascend/core/profiling_chunk_predictor.py b/vllm_ascend/core/profiling_chunk_predictor.py new file mode 100644 index 00000000000..193cd388e91 --- /dev/null +++ b/vllm_ascend/core/profiling_chunk_predictor.py @@ -0,0 +1,379 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Profiling-based Dynamic Chunk Size Predictor. + +This module implements a dynamic chunk sizing strategy based on profiling prefill +latency and fitting a quadratic model. + +The approach: +1. Profile: Run forward passes with different chunk sizes to measure latency +2. Fit: Use quadratic model f(l) = a*l^2 + b*l + c to fit the latency data +3. Predict: Given current num_computed_tokens, solve for chunk size that achieves + target latency +""" + +import math + +import numpy as np +from vllm.logger import logger + + +class ChunkSizePredictor: + """Predictor for dynamic chunk size based on quadratic latency model. 
+ + Models latency as: f(l) = a*l^2 + b*l + c + + Given a target latency T and current history length L, predicts next + chunk size x such that: f(L+x) - f(L) = T + + This expands to the quadratic equation: a*x^2 + (2aL+b)*x - T = 0 + """ + + def __init__(self, smooth_factor: float = 0.8, min_chunk: int = 4096): + self.quadratic_coeff_a: float = 0.0 + self.linear_coeff_b: float = 0.0 + self.constant_coeff_c: float = 0.0 + + self.quadratic_chunk_a: float = 0.0 + self.linear_chunk_b: float = 0.0 + self.constant_chunk_c: float = 0.0 + + self.target_latency: float | None = None + self.is_ready: bool = False + self.with_history_ready: bool = False + self.smooth_factor = smooth_factor + self.min_chunk = min_chunk + self.history_fitted = False + + def clamp_quadratic_and_linear_if_negative(self, fitted_a: float, fitted_b: float) -> tuple[float, float]: + """In theory, for the Transfomur structure of LLM, the fitted quadratic and linear + terms should not be negative. Can perform zero clamping for inaccurate fitting + """ + if fitted_a < 0: + logger.warning("Fitted a=%.2e is not positive. Setting a=1e-9.", fitted_a) + fitted_a = 1e-9 + if fitted_b < 0: + logger.warning("Fitted b=%.2e is not positive. Setting b=0.0.", fitted_b) + fitted_b = 1e-9 + + return fitted_a, fitted_b + + def fit(self, seq_lens: list[int], latencies: list[float]) -> bool: + """Fit quadratic coefficients f(l) = al^2 + bl + c from data points. + + Returns: + True if fitting succeeded, False otherwise + """ + L = np.array(seq_lens, dtype=np.float64) + T = np.array(latencies, dtype=np.float64) + MIN_FIT_POINTS_NO_CHUNK = 8 + + if len(L) < MIN_FIT_POINTS_NO_CHUNK: + logger.warning( + "Not enough data points for quadratic fitting (%d < 8)", + len(L), + ) + return False + + X = np.column_stack([L * L, L, np.ones_like(L)]) + + try: + coeffs, _, _, _ = np.linalg.lstsq(X, T, rcond=None) + fitted_a = float(coeffs[0]) + fitted_b = float(coeffs[1]) + fitted_c = float(coeffs[2]) + except Exception as e: + # Keep a robust fallback for environments where least-squares may + # fail due backend/LAPACK differences. + try: + poly = np.polyfit(L, T, 2) + fitted_a = float(poly[0]) + fitted_b = float(poly[1]) + fitted_c = float(poly[2]) + logger.warning( + "Least-squares fitting failed (%s), fallback to polyfit succeeded.", + e, + ) + except Exception as fallback_error: + logger.warning("Failed to fit quadratic model: %s", fallback_error) + return False + + fitted_a, fitted_b = self.clamp_quadratic_and_linear_if_negative(fitted_a, fitted_b) + + self.quadratic_coeff_a = fitted_a + self.linear_coeff_b = fitted_b + self.constant_coeff_c = fitted_c + + logger.info( + "[ProfilingChunk] Fitted: a=%.2e, b=%.2e, c=%.2e", + fitted_a, + fitted_b, + fitted_c, + ) + return True + + def fit_chunk(self, chunked_data: list) -> bool: + """Fit time with chunks: f(C,H) = a*C(C+H) + b*C + c*H. + + Returns: + True if fitting succeeded, False otherwise + """ + num_points = len(chunked_data) + # experience values, can be tuned. We don't want the online calibration process + # to be too long, so we have limited the amount of data. + # 30 data points are already sufficient. 
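+ # For reference: each accumulated row in ``chunked_data`` (built by
+ # ``ProfilingChunkManager.record_batch_execution_time``) has the layout
+ #     [sum_i C_i*(C_i+H_i), sum_i (C_i+H_i), num_requests, elapsed_ms]
+ # so for a single-request batch the least-squares solve below fits
+ # t ~= a*C*(C+H) + b*(C+H) + c.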
+ MIN_FIT_POINTS_CHUNK = 5 + MAX_FIT_POINTS_CHUNK = 30 + if num_points < MIN_FIT_POINTS_CHUNK: + logger.warning( + "Not enough data points for chunked data fitting (%d < 5)", + num_points, + ) + return False + if num_points > MAX_FIT_POINTS_CHUNK: + self.history_fitted = True + return False + + chunked_data_array = np.array(chunked_data) + execute_time = chunked_data_array[:, -1] + input_x = chunked_data_array[:, :-1] + + try: + params, _, _, _ = np.linalg.lstsq(input_x, execute_time, rcond=None) + fitted_a = float(params[0]) + fitted_b = float(params[1]) + fitted_c = float(params[2]) + except np.linalg.LinAlgError as e: + logger.warning("Failed to fit chunked model: %s", e) + return False + + fitted_a, fitted_b = self.clamp_quadratic_and_linear_if_negative(fitted_a, fitted_b) + + self.quadratic_chunk_a = fitted_a + self.linear_chunk_b = fitted_b + self.constant_chunk_c = fitted_c + + logger.info( + "[ProfilingChunk With History] Fitted: a=%.2e, b=%.2e, c=%.2e", + fitted_a, + fitted_b, + fitted_c, + ) + return True + + def set_target_latency(self, base_chunk_size: int, elapsed_time: float = 0.0) -> None: + """Set target latency based on base chunk size.""" + + def f(seq_lens: float) -> float: + return self.quadratic_coeff_a * seq_lens * seq_lens + self.linear_coeff_b * seq_lens + self.constant_coeff_c + + if elapsed_time > 0: + self.target_latency = elapsed_time + else: + self.target_latency = f(float(base_chunk_size)) - f(0.0) + if self.target_latency <= 0: + self.target_latency = 1.0 + + logger.info( + "[ProfilingChunk] Target latency: %.2f ms (base_chunk=%d)", + self.target_latency, + base_chunk_size, + ) + + def get_time( + self, + query_len: int, + num_computed_tokens: int, + ) -> float: + """Get time T based on current seq_lens, f(l) = al^2 + bl + c, f(L+x) - f(L) = T""" + + def f(seq_lens: float) -> float: + return self.quadratic_coeff_a * seq_lens * seq_lens + self.linear_coeff_b * seq_lens + self.constant_coeff_c + + return f(query_len + num_computed_tokens) - f(num_computed_tokens) + + def get_time_with_history( + self, + query_len: int, + num_computed_tokens: int, + ) -> float: + """Get time T based on current seq_lens, f(C,H) = a*C(C+H) + b*(C+H) + c = T""" + return ( + self.quadratic_chunk_a * query_len * (query_len + num_computed_tokens) + + self.linear_chunk_b * (query_len + num_computed_tokens) + + self.constant_chunk_c + ) + + def predict( + self, + num_computed_tokens: int, + base_chunk_size: int, + page_size: int, + ) -> int | None: + """Predict next chunk size x such that f(L+x) - f(L) = target_latency.""" + if not self.is_ready or self.target_latency is None: + return None + + if self.quadratic_coeff_a <= 0: + return None + + A = self.quadratic_coeff_a + B = 2 * self.quadratic_coeff_a * num_computed_tokens + self.linear_coeff_b + C = -self.target_latency + + discriminant = B * B - 4 * A * C + if discriminant < 0: + return None + + sqrt_disc = math.sqrt(discriminant) + x = (-B + sqrt_disc) / (2 * A) + + if x <= 0: + return None + + smoothed = base_chunk_size + self.smooth_factor * (x - base_chunk_size) + chunk_size = max(int(smoothed), self.min_chunk) + + align = max(page_size, 64) + chunk_size = ((chunk_size + align - 1) // align) * align + if chunk_size < align: + chunk_size = align + + logger.debug("[ProfilingChunk] Predicted chunk_size=%d", chunk_size) + return chunk_size if chunk_size >= align else None + + def predict_with_history( + self, + num_computed_tokens: int, + base_chunk_size: int, + page_size: int, + ) -> int | None: + """Predict next chunk size x using 
the history-aware model + f(C,H) = a*C(C+H) + b*C + c*H.""" + if not self.is_ready or self.target_latency is None: + return None + + if not self.with_history_ready: + return None + + if self.quadratic_chunk_a <= 0: + return None + + # a*C^2 + (a*H + b)*C + b*H + c - T = 0 + A = self.quadratic_chunk_a + B = self.quadratic_chunk_a * num_computed_tokens + self.linear_chunk_b + C = self.linear_chunk_b * num_computed_tokens + self.constant_chunk_c - self.target_latency + + discriminant = B * B - 4 * A * C + if discriminant < 0: + return None + + sqrt_disc = math.sqrt(discriminant) + x = (-B + sqrt_disc) / (2 * A) + + if x <= 0: + return None + + logger.debug("[ProfilingChunk] History-aware raw prediction: %.1f", x) + smoothed = base_chunk_size + self.smooth_factor * (x - base_chunk_size) + chunk_size = max(int(smoothed), self.min_chunk) + + align = max(page_size, 64) + chunk_size = ((chunk_size + align - 1) // align) * align + if chunk_size < align: + chunk_size = align + + return chunk_size if chunk_size >= align else None + + +class ProfilingChunkManager: + """Manager for profiling-based dynamic chunk sizing. + + Handles the profiling process and maintains the ChunkSizePredictor. + """ + + def __init__( + self, + base_chunk_size: int, + page_size: int, + smooth_factor: float = 0.8, + min_chunk: int = 4096, + ): + self.base_chunk_size = base_chunk_size + self.page_size = page_size + self.chunked_fit_data: list = [] + + self.predictor = ChunkSizePredictor(smooth_factor=smooth_factor, min_chunk=min_chunk) + self._profiling_done = False + self._set_time_done = False + + @property + def is_ready(self) -> bool: + return self._profiling_done and self.predictor.is_ready + + @property + def history_ready(self) -> bool: + return self.is_ready and self.predictor.with_history_ready + + def predict_chunk_size(self, num_computed_tokens: int, target_time: float) -> int | None: + """Predict optimal chunk size for given history length.""" + if not self.is_ready: + return None + + self.predictor.target_latency = target_time + + if not self.history_ready: + predict_func = self.predictor.predict + else: + predict_func = self.predictor.predict_with_history + return predict_func( + num_computed_tokens=num_computed_tokens, base_chunk_size=self.base_chunk_size, page_size=self.page_size + ) + + def predict_time(self, num_new_tokens: int, num_computed_tokens: int) -> float: + """Get the consumed time of scheduled reqs for time_budget.""" + if not self.is_ready: + return 0.0 + + if not self.history_ready: + predict_func = self.predictor.get_time + else: + predict_func = self.predictor.get_time_with_history + return predict_func(query_len=num_new_tokens, num_computed_tokens=num_computed_tokens) + + def record_batch_execution_time(self, request_chunks: list, elapsed_time: float) -> bool: + """Record batch execution time for online model refinement. + + Accumulates (x1, x2, x3, time_ms) data points and re-fits the + history-aware model once enough points are collected. 
+ + Args: + request_chunks: List of (chunk_size, num_computed_tokens) per request + elapsed_time: Total elapsed time in seconds + """ + x1 = x2 = x3 = 0 + for chunk, hist in request_chunks: + x1 += (chunk + hist) * chunk + x2 += chunk + hist + x3 += 1 + self.chunked_fit_data.append([x1, x2, x3, elapsed_time * 1000]) + if not self.predictor.fit_chunk(self.chunked_fit_data): + return False + + self.predictor.with_history_ready = True + return True diff --git a/vllm_ascend/core/scheduler_profiling_chunk.py b/vllm_ascend/core/scheduler_profiling_chunk.py new file mode 100644 index 00000000000..01d46ca4ed2 --- /dev/null +++ b/vllm_ascend/core/scheduler_profiling_chunk.py @@ -0,0 +1,723 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Scheduler subclass with profiling-based dynamic chunk sizing. + +Compatible with vLLM v0.15.x scheduler. When the upstream ``schedule()`` +method is refactored, this override should be updated accordingly. +""" + +import inspect +import time + +from vllm.config import VllmConfig +from vllm.logger import logger +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.interface import PauseState +from vllm.v1.core.sched.output import ( + NewRequestData, + SchedulerOutput, +) +from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.engine import EngineCoreEventType +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.request import Request, RequestStatus +from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import record_function_or_nullcontext + +from vllm_ascend.core.profiling_chunk_predictor import ProfilingChunkManager + + +class ProfilingChunkScheduler(Scheduler): + """Scheduler with profiling-based dynamic chunk sizing. + + During initialization, the scheduler profiles prefill latency at various + chunk sizes by calling ``profile_prefill_latency`` on each worker via + ``collective_rpc``. A quadratic latency model is then fitted, and during + scheduling the model predicts the optimal chunk size for each waiting + request based on its ``num_computed_tokens``. 
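+
+    For intuition about the underlying predictor (illustrative numbers only,
+    taken from the synthetic data used in the unit tests): with a fitted
+    latency model ``f(l) = 1e-6*l^2 + 0.01*l + 1.0`` (ms) and
+    ``base_chunk_size = 8192``, the target latency is
+    ``f(8192) - f(0) ~= 149 ms``. For a request that has already computed
+    ``L = 16384`` tokens, the next chunk ``x`` solves
+    ``1e-6*x^2 + (2*1e-6*16384 + 0.01)*x = 149``, giving ``x ~= 3.2k`` tokens,
+    which smoothing (``smooth_factor=0.8``) and 128-token page alignment then
+    turn into a 4352-token chunk.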
+ """ + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + block_size: int, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + super().__init__( + vllm_config, + kv_cache_config, + structured_output_manager, + block_size, + mm_registry=mm_registry, + include_finished_set=include_finished_set, + log_stats=log_stats, + ) + + from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config + + init_ascend_config(vllm_config) + profiling_cfg = get_ascend_config().profiling_chunk_config + base_chunk = self.max_num_scheduled_tokens + + self.profiling_chunk_manager = ProfilingChunkManager( + base_chunk_size=base_chunk, + page_size=self.cache_config.block_size, + smooth_factor=profiling_cfg.smooth_factor, + min_chunk=profiling_cfg.min_chunk, + ) + self._profiling_initialized = False + + logger.info( + "[ProfilingChunk] Scheduler initialized. base_chunk=%d, page_size=%d, smooth_factor=%.2f, min_chunk=%d", + base_chunk, + self.cache_config.block_size, + profiling_cfg.smooth_factor, + profiling_cfg.min_chunk, + ) + + # ------------------------------------------------------------------ + # Profiling initialization + # ------------------------------------------------------------------ + + def run_profiling_chunk_init(self, model_executor) -> None: + """Profile prefill latency using real model forward passes. + + Called by EngineCore after model_executor is ready. Collects latency + samples at different chunk sizes and fits the quadratic model. + """ + if self._profiling_initialized: + return + self._profiling_initialized = True + + if model_executor is None: + logger.warning("[ProfilingChunk] No model_executor provided, skipping profiling") + return + + logger.info("[ProfilingChunk] Running startup profiling with real model forward...") + + seq_lens: list[int] = [] + latencies: list[float] = [] + + base_chunk_size = self.profiling_chunk_manager.base_chunk_size + num_samples = 64 + + # Determine unique_reply_rank for PP setups + rpc_kwargs = self._build_rpc_kwargs(model_executor) + + total_steps = num_samples + 1 + log_interval = max(1, total_steps // 10) + t_start = time.perf_counter() + + for i in range(total_steps): + chunk_size = int(base_chunk_size - (i - 1) * (base_chunk_size / num_samples)) + if chunk_size <= 0: + break + + if i % log_interval == 0 or i == total_steps - 1: + elapsed = time.perf_counter() - t_start + logger.info( + "[ProfilingChunk] Profiling prefill latency: %d/%d samples done (chunk=%d, elapsed=%.1fs)", + max(i - 1, 0), + num_samples, + chunk_size, + elapsed, + ) + + try: + result = model_executor.collective_rpc( + "profile_prefill_latency", + args=(chunk_size,), + **rpc_kwargs, + ) + + # First iteration is warm-up + if i == 0: + continue + + latency_ms = self._extract_latency(result) + if latency_ms is None: + continue + + seq_lens.append(chunk_size) + latencies.append(latency_ms) + + except Exception as e: + logger.debug( + "[ProfilingChunk] Forward failed for chunk=%d: %s", + chunk_size, + e, + ) + continue + + if len(seq_lens) < 8: + logger.warning( + "[ProfilingChunk] Profiling failed: only %d samples collected", + len(seq_lens), + ) + return + + logger.info( + "[ProfilingChunk] Collected %d samples. 
Latency range: [%.2f, %.2f] ms", + len(seq_lens), + min(latencies), + max(latencies), + ) + + predictor = self.profiling_chunk_manager.predictor + if not predictor.fit(seq_lens, latencies): + return + + predictor.set_target_latency(base_chunk_size) + predictor.is_ready = True + self.profiling_chunk_manager._profiling_done = True + + logger.info("[ProfilingChunk] Profiling completed successfully") + + @staticmethod + def _build_rpc_kwargs(model_executor) -> dict: + """Build kwargs for collective_rpc, handling PP unique_reply_rank.""" + kwargs: dict = {} + if not hasattr(model_executor, "collective_rpc"): + return kwargs + + sig = inspect.signature(model_executor.collective_rpc) + if "unique_reply_rank" not in sig.parameters: + return kwargs + + try: + pc = model_executor.vllm_config.parallel_config + output_rank = pc.world_size - pc.tensor_parallel_size * pc.prefill_context_parallel_size + kwargs["unique_reply_rank"] = output_rank + except AttributeError: + pass + + return kwargs + + @staticmethod + def _extract_latency(result) -> float | None: + """Extract latency value from collective_rpc result.""" + if isinstance(result, (int, float)): + return float(result) + if isinstance(result, list) and len(result) > 0: + return float(result[0]) + return None + + # ------------------------------------------------------------------ + # schedule() override + # ------------------------------------------------------------------ + # The method below is based on the upstream Scheduler.schedule() + # with profiling-based chunk sizing applied to both RUNNING requests + # (chunked prefill continuation) and WAITING requests (new prefill). + # Modified sections are marked with ">>> PROFILING CHUNK" comments. + # ------------------------------------------------------------------ + + def schedule(self) -> SchedulerOutput: # noqa: C901 + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + + req_to_new_blocks: dict[str, KVCacheBlocks] = {} + num_scheduled_tokens: dict[str, int] = {} + # >>> PROFILING CHUNK >>> + # NOTE(gjc): We found that the FIA operator has abnormal performance + # when processing multiple request groups in a batch, so the time_budget + # feature is temporarily disabled. It will be enabled again after the + # issues with the FIA operator are resolved. Therefore, in multi-request + # concurrent scenarios, there is still room for performance improvement in CPP. + # time_budget = self.profiling_chunk_manager.predictor.target_latency + time_budget = 0.01 + # <<< PROFILING CHUNK <<< + token_budget = self.max_num_scheduled_tokens + if self._pause_state == PauseState.PAUSED_ALL: + token_budget = 0 + + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_compute_budget = self.max_num_encoder_input_tokens + # Spec decode-related. + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + + # For logging. + scheduled_timestamp = time.monotonic() + + self.kv_cache_manager.new_step_starts() + + # First, schedule the RUNNING requests. 
+ req_index = 0 + # >>> PROFILING CHUNK >>> + while req_index < len(self.running) and token_budget > 0 and time_budget > 0: + # <<< PROFILING CHUNK <<< + request = self.running[req_index] + + if ( + request.num_output_placeholders > 0 + and request.num_computed_tokens + 2 - request.num_output_placeholders + >= request.num_prompt_tokens + request.max_tokens + ): + req_index += 1 + continue + + num_new_tokens = ( + request.num_tokens_with_spec + request.num_output_placeholders - request.num_computed_tokens + ) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold + num_new_tokens = min(num_new_tokens, token_budget) + + # Make sure the input position does not exceed the max model len. + num_new_tokens = min( + num_new_tokens, + self.max_model_len - 1 - request.num_computed_tokens, + ) + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + external_load_encoder_input: list[int] = [] + new_encoder_compute_budget = encoder_compute_budget + if request.has_encoder_inputs: + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + external_load_encoder_input, + ) = self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + shift_computed_tokens=1 if self.use_eagle else 0, + ) + + # >>> PROFILING CHUNK: dynamic chunk sizing for RUNNING >>> + if ( + self.profiling_chunk_manager is not None + and self.profiling_chunk_manager.is_ready + and num_new_tokens > 1 + and request.num_computed_tokens > 0 + ): + predicted_chunk = self.profiling_chunk_manager.predict_chunk_size( + num_computed_tokens=request.num_computed_tokens, + target_time=time_budget, + ) + if predicted_chunk is not None and predicted_chunk > 0: + num_new_tokens = min(predicted_chunk, num_new_tokens) + # <<< PROFILING CHUNK <<< + + if self.need_mamba_block_aligned_split: + num_new_tokens = self._mamba_block_aligned_split(request, num_new_tokens) + + if num_new_tokens == 0: + req_index += 1 + continue + + # Schedule newly needed KV blocks for the request. + with record_function_or_nullcontext("schedule: allocate_slots"): + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, + ) + + if new_blocks is not None: + break + + if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + preempted_req_id = preempted_req.request_id + scheduled_running_reqs.remove(preempted_req) + token_budget += num_scheduled_tokens.pop(preempted_req_id) + req_to_new_blocks.pop(preempted_req_id) + scheduled_spec_decode_tokens.pop(preempted_req_id, None) + preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req_id, None) + if preempted_encoder_inputs: + num_embeds_to_restore = sum( + preempted_req.get_num_encoder_embeds(i) for i in preempted_encoder_inputs + ) + encoder_compute_budget += num_embeds_to_restore + req_index -= 1 + else: + preempted_req = self.running.pop() + + self._preempt_request(preempted_req, scheduled_timestamp) + preempted_reqs.append(preempted_req) + if preempted_req == request: + break + + if new_blocks is None: + break + + # Schedule the request. 
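+ # NOTE: in addition to token_budget, the predicted latency of this chunk
+ # (``ProfilingChunkManager.predict_time``) is subtracted from time_budget below.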
+ scheduled_running_reqs.append(request) + request_id = request.request_id + req_to_new_blocks[request_id] = new_blocks + num_scheduled_tokens[request_id] = num_new_tokens + token_budget -= num_new_tokens + time_budget -= self.profiling_chunk_manager.predict_time(num_new_tokens, request.num_computed_tokens) + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = ( + num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders + ) + if num_scheduled_spec_tokens > 0: + spec_token_ids = request.spec_token_ids + if len(spec_token_ids) > num_scheduled_spec_tokens: + spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens] + scheduled_spec_decode_tokens[request_id] = spec_token_ids + + request.spec_token_ids = [] + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) + encoder_compute_budget = new_encoder_compute_budget + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) + + # Record the LoRAs in scheduled_running_reqs + scheduled_loras: set[int] = set() + if self.lora_config: + scheduled_loras = set( + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) + assert len(scheduled_loras) <= self.lora_config.max_loras + + # Next, schedule the WAITING requests. + if not preempted_reqs and self._pause_state == PauseState.UNPAUSED: + step_skipped_waiting = create_request_queue(self.policy) + + # >>> PROFILING CHUNK >>> + while (self.waiting or self.skipped_waiting) and token_budget > 0 and time_budget > 0: + # <<< PROFILING CHUNK <<< + if len(self.running) == self.max_num_running_reqs: + break + + request_queue = self._select_waiting_queue_for_scheduling() + assert request_queue is not None + + request = request_queue.peek_request() + request_id = request.request_id + + # Try to promote blocked statuses while traversing skipped queue. + if self._is_blocked_waiting_status(request.status) and not self._try_promote_blocked_waiting_request( + request + ): + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request_id, + ) + request_queue.pop_request() + step_skipped_waiting.prepend_request(request) + continue + + # Check that adding the request still respects the max_loras + # constraint. + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): + request_queue.pop_request() + step_skipped_waiting.prepend_request(request) + continue + + num_external_computed_tokens = 0 + load_kv_async = False + connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0 + + # Get already-cached tokens. 
+ if request.num_computed_tokens == 0: + new_computed_blocks, num_new_local_computed_tokens = self.kv_cache_manager.get_computed_blocks( + request + ) + + if self.connector is not None: + ext_tokens, load_kv_async = self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens + ) + + if ext_tokens is None: + request_queue.pop_request() + step_skipped_waiting.prepend_request(request) + continue + + request.num_external_computed_tokens = ext_tokens + num_external_computed_tokens = ext_tokens + + connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens + connector_prefix_cache_hits = num_external_computed_tokens + + num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens + assert num_computed_tokens <= request.num_tokens + else: + new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens + + encoder_inputs_to_schedule = None + external_load_encoder_input = [] + new_encoder_compute_budget = encoder_compute_budget + + if load_kv_async: + assert num_external_computed_tokens > 0 + num_new_tokens = 0 + else: + num_new_tokens = request.num_tokens - num_computed_tokens + threshold = self.scheduler_config.long_prefill_token_threshold + if 0 < threshold < num_new_tokens: + num_new_tokens = threshold + + # >>> PROFILING CHUNK: dynamic chunk sizing >>> + if ( + self.profiling_chunk_manager is not None + and self.profiling_chunk_manager.is_ready + and num_new_tokens > 1 + and request.num_computed_tokens > 0 + ): + predicted_chunk = self.profiling_chunk_manager.predict_chunk_size( + num_computed_tokens=num_computed_tokens, + target_time=time_budget, + ) + if predicted_chunk is not None and predicted_chunk > 0: + num_new_tokens = min(num_new_tokens, predicted_chunk) + # <<< PROFILING CHUNK <<< + + if not self.scheduler_config.enable_chunked_prefill and num_new_tokens > token_budget: + break + + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + external_load_encoder_input, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + shift_computed_tokens=1 if self.use_eagle else 0, + ) + if num_new_tokens == 0: + break + + if self.need_mamba_block_aligned_split: + num_new_tokens = self._mamba_block_aligned_split( + request, + num_new_tokens, + num_new_local_computed_tokens, + num_external_computed_tokens, + ) + if num_new_tokens == 0: + break + + effective_lookahead_tokens = 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens + + # Determine if we need to allocate cross-attention blocks. 
+ num_encoder_tokens = 0 + if self.is_encoder_decoder and request.has_encoder_inputs and encoder_inputs_to_schedule: + num_encoder_tokens = sum(request.get_num_encoder_embeds(i) for i in encoder_inputs_to_schedule) + + if self.scheduler_reserve_full_isl and not self.kv_cache_manager.can_fit_full_sequence( + request, + num_new_computed_tokens=num_new_local_computed_tokens, + new_computed_blocks=new_computed_blocks, + num_external_computed_tokens=num_external_computed_tokens, + num_encoder_tokens=num_encoder_tokens, + ): + if request.has_encoder_inputs: + self.encoder_cache_manager.free(request) + break + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_new_computed_tokens=num_new_local_computed_tokens, + new_computed_blocks=new_computed_blocks, + num_lookahead_tokens=effective_lookahead_tokens, + num_external_computed_tokens=num_external_computed_tokens, + delay_cache_blocks=load_kv_async, + num_encoder_tokens=num_encoder_tokens, + ) + + if new_blocks is None: + if request.has_encoder_inputs: + self.encoder_cache_manager.free(request) + break + + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + self.kv_cache_manager.get_blocks(request_id), + num_external_computed_tokens, + ) + if self.connector_prefix_cache_stats is not None and connector_prefix_cache_queries != 0: + self.connector_prefix_cache_stats.record( + num_tokens=connector_prefix_cache_queries, + num_hits=connector_prefix_cache_hits, + preempted=request.num_preemptions > 0, + ) + + request = request_queue.pop_request() + if load_kv_async: + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + step_skipped_waiting.prepend_request(request) + request.num_computed_tokens = num_computed_tokens + continue + + self.running.append(request) + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError(f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(request_id) + num_scheduled_tokens[request_id] = num_new_tokens + token_budget -= num_new_tokens + time_budget -= self.profiling_chunk_manager.predict_time(num_new_tokens, request.num_computed_tokens) + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) + encoder_compute_budget = new_encoder_compute_budget + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) + + # Re-queue requests skipped in this pass ahead of older skipped items. + if step_skipped_waiting: + self.skipped_waiting.prepend_requests(step_skipped_waiting) + + # Check if the scheduling constraints are satisfied. 
+ total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs) <= len(self.running) + + # Get the longest common prefix among all requests in the running queue. + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) + with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"): + if self.running: + any_request_id = self.running[0].request_id + num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id) + + # Construct the scheduler output. + if self.use_v2_model_runner: + scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs + scheduled_resumed_reqs = [] + new_reqs_data = [ + NewRequestData.from_request( + req, + req_to_new_blocks[req.request_id].get_block_ids(), + req._all_token_ids, + ) + for req in scheduled_new_reqs + ] + else: + new_reqs_data = [ + NewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids()) + for req in scheduled_new_reqs + ] + + with record_function_or_nullcontext("schedule: make_cached_request_data"): + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) + + self.prev_step_scheduled_req_ids.clear() + self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) + + new_block_ids_to_zero = ( + (self.kv_cache_manager.take_new_block_ids() or None) if self.needs_kv_cache_zeroing else None + ) + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=cached_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + preempted_req_ids={req.request_id for req in preempted_reqs}, + finished_req_ids=self.finished_req_ids, + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), + new_block_ids_to_zero=new_block_ids_to_zero, + ) + + if self.connector is not None: + meta = self._build_kv_connector_meta(self.connector, scheduler_output) + scheduler_output.kv_connector_metadata = meta + + if self.ec_connector is not None: + ec_meta = self.ec_connector.build_connector_meta(scheduler_output) + scheduler_output.ec_connector_metadata = ec_meta + + with record_function_or_nullcontext("schedule: update_after_schedule"): + self._update_after_schedule(scheduler_output) + return scheduler_output diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 804c98242c3..2634d6257df 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -191,6 +191,34 @@ # Future Plan: # Remove this patch after the upcoming KV cache spec refactor. # +# ** 9. File: platform/patch_profiling_chunk.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.v1.engine.core.EngineCore.__init__` +# 2. `vllm.v1.engine.core.EngineCoreProc.run_engine_core` +# 3. 
+#   3. `Scheduler.update_from_output` (the scheduler class in use, wrapped
+#      when profiling-based chunk sizing is enabled)
+# Why:
+#   Profiling-based dynamic chunk sizing needs to run a one-shot profiling
+#   pass after `model_executor` is ready, and to feed per-step execution
+#   latency back into `ProfilingChunkManager` so the history-aware chunk
+#   predictor can refine online. In multiprocessing `spawn` mode the child
+#   process starts a fresh interpreter, so monkey-patches applied in the
+#   parent are lost unless the subprocess entry point re-applies them before
+#   any `EngineCore` is created.
+# How:
+#   Replace `EngineCore.__init__` to call `scheduler.run_profiling_chunk_init`
+#   when present, then wrap `scheduler.update_from_output` once per process to
+#   read `model_output.execution_time_ms` and `scheduler_output` token/chunk
+#   metadata and call `ProfilingChunkManager.record_batch_execution_time`
+#   (bootstrapping the target latency from the first chunk when needed).
+#   Replace `EngineCoreProc.run_engine_core` so that importing this module in
+#   the child re-runs the idempotent patch helper before delegating to the
+#   original implementation.
+# Related PR (if no, explain why):
+#   No, vllm-ascend-specific profiling / scheduling integration.
+# Future Plan:
+#   Remove or narrow this patch once upstream exposes stable hooks for
+#   profiling at backend startup and per-step timing callbacks, so that
+#   `EngineCore` and the multiprocess entry point no longer need
+#   monkey-patching.
+#
 # * Worker Patch:
 # ===============
 #
diff --git a/vllm_ascend/patch/platform/patch_profiling_chunk.py b/vllm_ascend/patch/platform/patch_profiling_chunk.py
new file mode 100644
index 00000000000..4d33ec64ffc
--- /dev/null
+++ b/vllm_ascend/patch/platform/patch_profiling_chunk.py
@@ -0,0 +1,179 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Patches for profiling-based dynamic chunk sizing.
+
+This module patches ``EngineCore`` to:
+1. Run profiling at startup (after ``model_executor`` is ready).
+2. Record execution timing after each model step to refine the
+   history-aware chunk prediction model online.
+
+In multiprocessing ``spawn`` mode the child process starts a fresh Python
+interpreter, so class-level monkey-patches applied in the parent are lost.
+To handle this we additionally wrap ``EngineCoreProc.run_engine_core``
+(the subprocess entry point): when pickle resolves the wrapper it triggers
+an import of this module, which re-applies the ``EngineCore.__init__``
+patches inside the child process before any ``EngineCore`` is instantiated.
+"""
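+
+# The re-patching pattern in miniature (sketch only; the names ``Target``,
+# ``_patched_init`` and ``entry_point`` are placeholders, not real APIs):
+#
+#     _applied = False
+#
+#     def _apply():                            # idempotent patch helper
+#         global _applied
+#         if _applied:
+#             return
+#         _applied = True
+#         Target.__init__ = _patched_init      # class-level monkey-patch
+#
+#     _apply()                                 # in-process (parent) path
+#
+#     _orig_entry = Target.entry_point
+#
+#     def _entry(*args, **kwargs):
+#         _apply()                             # spawn path: unpickling _entry
+#         return _orig_entry(*args, **kwargs)  # re-imports this module, which
+#                                              # re-applies the patch first
+#     Target.entry_point = _entry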
+""" + +from vllm.logger import logger +from vllm.v1.engine.core import EngineCore, EngineCoreProc + +_profiling_patches_applied = False +_original_update_from_output = None + + +# --------------------------------------------------------------------------- +# Helper: record execution timing +# --------------------------------------------------------------------------- + + +def _record_execution_timing(scheduler, scheduler_output, model_output): + """Record execution timing for online model refinement. + + Extracts ``execution_time_ms`` (set dynamically by the NPU model runner) + from the model output and feeds it back to the + ``ProfilingChunkManager`` for incremental fitting of the history-aware + latency model. + """ + profiling_mgr = getattr(scheduler, "profiling_chunk_manager", None) + if profiling_mgr is None or not profiling_mgr.is_ready: + return + + elapsed_time_ms = getattr(model_output, "execution_time_ms", 0.0) + if elapsed_time_ms <= 0: + return + elapsed_time = elapsed_time_ms / 1000.0 + + try: + total_tokens = getattr(scheduler_output, "total_num_scheduled_tokens", 0) + if total_tokens <= 0: + return + + num_scheduled_tokens = getattr(scheduler_output, "num_scheduled_tokens", {}) + request_chunks = [] + + total_hist_tokens = 0 + new_reqs = getattr(scheduler_output, "scheduled_new_reqs", []) + for req in new_reqs: + req_id = getattr(req, "request_id", None) or getattr(req, "req_id", None) + if req_id and req_id in num_scheduled_tokens: + chunk_size = num_scheduled_tokens[req_id] + hist_seq_len = getattr(req, "num_computed_tokens", 0) + total_hist_tokens += hist_seq_len + if chunk_size > 0: + request_chunks.append((chunk_size, hist_seq_len)) + + cached_reqs = getattr(scheduler_output, "scheduled_cached_reqs", None) + if cached_reqs is not None: + req_ids = getattr(cached_reqs, "req_ids", []) + computed_tokens_list = getattr(cached_reqs, "num_computed_tokens", []) + for i, req_id in enumerate(req_ids): + if req_id in num_scheduled_tokens: + chunk_size = num_scheduled_tokens[req_id] + hist_seq_len = computed_tokens_list[i] if i < len(computed_tokens_list) else 0 + total_hist_tokens += hist_seq_len + if chunk_size > 0: + request_chunks.append((chunk_size, hist_seq_len)) + + # is first chunk processing + if total_hist_tokens == 0 and not profiling_mgr._set_time_done: + profiling_mgr.predictor.set_target_latency(0, elapsed_time * 1000) + profiling_mgr._set_time_done = True + + if not request_chunks: + request_chunks = [(total_tokens, 0)] + + if not profiling_mgr.predictor.history_fitted: + profiling_mgr.record_batch_execution_time(request_chunks, elapsed_time) + + except (AttributeError, TypeError) as e: + logger.debug("Failed to record execution timing: %s", e) + + +# --------------------------------------------------------------------------- +# Helper: wrap scheduler.update_from_output for timing +# --------------------------------------------------------------------------- + + +def _ensure_update_from_output_wrapped(scheduler): + """Wrap scheduler.update_from_output to record execution timing.""" + global _original_update_from_output + if _original_update_from_output is not None: + return + if not hasattr(scheduler, "profiling_chunk_manager"): + return + + cls = type(scheduler) + _original_update_from_output = cls.update_from_output + + def _wrapped_update_from_output(self, scheduler_output, model_output): + _record_execution_timing(self, scheduler_output, model_output) + return _original_update_from_output(self, scheduler_output, model_output) + + cls.update_from_output = 
+
+
+# ---------------------------------------------------------------------------
+# Core: apply EngineCore.__init__ patches (idempotent)
+# ---------------------------------------------------------------------------
+
+
+def _apply_profiling_patches():
+    """Patch ``EngineCore.__init__`` to trigger profiling and timing hooks.
+
+    Safe to call multiple times; the guard ``_profiling_patches_applied``
+    ensures the patch is applied at most once per process.
+    """
+    global _profiling_patches_applied
+    if _profiling_patches_applied:
+        return
+    _profiling_patches_applied = True
+
+    original_init = EngineCore.__init__
+
+    def _patched_engine_core_init(self, *args, **kwargs):
+        original_init(self, *args, **kwargs)
+
+        if hasattr(self.scheduler, "run_profiling_chunk_init"):
+            logger.info("[ProfilingChunk] Running profiling initialization...")
+            self.scheduler.run_profiling_chunk_init(self.model_executor)
+
+        _ensure_update_from_output_wrapped(self.scheduler)
+
+    EngineCore.__init__ = _patched_engine_core_init
+
+
+# ---------------------------------------------------------------------------
+# 1. Apply patches at module level for the InprocClient (in-process) path.
+# ---------------------------------------------------------------------------
+_apply_profiling_patches()
+
+# ---------------------------------------------------------------------------
+# 2. Wrap EngineCoreProc.run_engine_core so that spawned subprocesses
+#    re-apply the patches. When the child unpickles this wrapper it
+#    imports this module, which triggers _apply_profiling_patches() above,
+#    ensuring EngineCore.__init__ is patched before any instance is created.
+# ---------------------------------------------------------------------------
+_original_run_engine_core = EngineCoreProc.run_engine_core
+
+
+def _patched_run_engine_core(*args, **kwargs):
+    _apply_profiling_patches()
+    return _original_run_engine_core(*args, **kwargs)
+
+
+EngineCoreProc.run_engine_core = _patched_run_engine_core
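+
+# This module is imported purely for its side effects. A sketch of the
+# enabling path (see the platform.py hunk below):
+#
+#     additional_config = {"profiling_chunk_config": {"enabled": True}}
+#     # -> check_and_update_config() points scheduler_cls at
+#     #    ProfilingChunkScheduler and imports this module.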
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index e92f686dc0e..05ce9407998 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -446,6 +446,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             vllm_config.scheduler_config.enable_chunked_prefill = True
             vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch
 
+        # Use the ProfilingChunkScheduler when profiling-based chunk sizing is on.
+        if ascend_config.profiling_chunk_config.enabled:
+            vllm_config.scheduler_config.scheduler_cls = (
+                "vllm_ascend.core.scheduler_profiling_chunk.ProfilingChunkScheduler"
+            )
+            import vllm_ascend.patch.platform.patch_profiling_chunk  # noqa
+
         cp_size = parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size
         if (
             vllm_config.kv_transfer_config is not None
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ddde07b0b63..e5b8850cc15 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -19,6 +19,7 @@
 import math
 import sys
+import time
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext
 from copy import copy, deepcopy
@@ -1244,6 +1245,10 @@ def execute_model(
                 capturer.clear_buffer()
             else:
                 logger.warning("RoutedExpertsCapturer is not initialized.")
+
+        if self.ascend_config.profiling_chunk_config.enabled:
+            self._sync_device()
+            self._execution_start_time = time.perf_counter()
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.")
         # self._draft_token_ids is None when `input_fits_in_drafter=False`
@@ -1720,6 +1725,9 @@ def propose_draft_token_ids(sampled_token_ids):
             ec_connector_output=ec_connector_output if self.supports_mm_inputs else None,
             cudagraph_stats=cudagraph_stats,
         )
+        if self.ascend_config.profiling_chunk_config.enabled and hasattr(self, '_execution_start_time'):
+            self._sync_device()
+            model_runner_output.execution_time_ms = (time.perf_counter() - self._execution_start_time) * 1000.0
         if self.dynamic_eplb:
             with record_function_or_nullcontext("EPLB update"):
@@ -1971,6 +1979,36 @@ def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
             return round_up(num_scheduled_tokens, tp_size)
         return num_scheduled_tokens
 
+    # Overrides the upstream vLLM helper that handles PP+SP: Ascend gates
+    # sequence parallelism via flashcomm1, which differs from vLLM's own SP
+    # check, so the slicing logic must be reimplemented here.
+    def sync_and_slice_intermediate_tensors(
+        self,
+        num_tokens: int,
+        intermediate_tensors: IntermediateTensors | None,
+        sync_self: bool,
+    ) -> IntermediateTensors:
+        assert self.intermediate_tensors is not None
+        tp = self.vllm_config.parallel_config.tensor_parallel_size
+
+        # When sequence parallelism is enabled, the "residual" tensor is sharded
+        # across tensor parallel ranks, so each rank only needs its own slice.
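+        # Example: num_tokens = 10, tp = 4 -> each rank copies / returns
+        # ceil(10 / 4) = 3 rows; with SP disabled every rank keeps all 10.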
+        if sync_self:
+            assert intermediate_tensors is not None
+            for k, v in intermediate_tensors.items():
+                copy_len = (num_tokens + tp - 1) // tp if enable_sp() else num_tokens
+                self.intermediate_tensors[k][:copy_len].copy_(
+                    v[:copy_len], non_blocking=True
+                )
+
+        return IntermediateTensors(
+            {
+                k: v[: (num_tokens + tp - 1) // tp]
+                if enable_sp()
+                else v[:num_tokens]
+                for k, v in self.intermediate_tensors.items()
+            }
+        )
+
     def _sync_batch_across_dp(
         self,
         num_tokens_padded: int | None = None,
@@ -2409,6 +2447,7 @@ def _dummy_run(
         is_graph_capturing: bool = False,
         num_active_loras: int = 0,
         profile_seq_lens: int | None = None,
+        profile_cpp: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # only support eager mode and piecewise graph now
         assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes()
@@ -2439,6 +2478,9 @@ def _dummy_run(
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len
+        elif profile_cpp:
+            num_reqs = 1
+            num_scheduled_tokens_list = [num_tokens] * num_reqs
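+            # A single request carrying all ``num_tokens`` mimics one
+            # monolithic prefill chunk (e.g. num_tokens=8192 -> one request
+            # with 8192 scheduled tokens), so the measured latency maps
+            # directly onto the chunk-size/latency curve being profiled.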
         else:
             num_reqs = min(num_tokens, max_num_reqs)
             min_tokens_per_req = num_tokens // num_reqs
@@ -2461,7 +2503,7 @@ def _dummy_run(
             max_num_scheduled_tokens=max_query_len,
             use_cascade_attn=False,
             allow_microbatching=allow_microbatching,
-            force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
+            force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE) or profile_cpp,
             # `force_uniform_decode` is used for cudagraph capture; because for
             # capturing mixed prefill-decode batches, we sometimes use
             # num_tokens == num_reqs which looks like a uniform decode batch to the
@@ -2532,9 +2574,11 @@ def _dummy_run(
                 num_scheduled_tokens, self.query_pos.np)
             self.query_start_loc.np[1 : num_reqs_padded + 1] = cum_num_tokens
             self.query_start_loc.copy_to_gpu()
-            num_reqs_padded = self._pad_query_start_loc_for_fia(
-                num_tokens_padded, num_reqs_padded, num_reqs, cudagraph_runtime_mode, batch_desc.num_reqs
-            )
+
+            if not profile_cpp:
+                num_reqs_padded = self._pad_query_start_loc_for_fia(
+                    num_tokens_padded, num_reqs_padded, num_reqs, cudagraph_runtime_mode, batch_desc.num_reqs
+                )
 
             pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
             attn_metadata, _ = self._build_attention_metadata(
diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index 50af9526547..90e7ecf1889 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -483,6 +483,68 @@ def _warm_up_atb(self):
     def get_model(self) -> nn.Module:
         return self.model_runner.get_model()
 
+    @torch.inference_mode()
+    def profile_prefill_latency(self, num_tokens: int) -> float:
+        """
+        Profile prefill latency for a given number of tokens.
+
+        Runs a real model forward pass and measures its execution time.
+        Used for profiling-based dynamic chunk sizing.
+
+        In PP (pipeline parallelism) mode:
+        - All workers execute the forward pass to stay synchronized.
+        - Only the timing from PP0 (the first rank) is meaningful for
+          scheduling.
+        - PP0's timing covers all pipeline stages' latency when async
+          scheduling is used.
+
+        Args:
+            num_tokens: Number of tokens to profile.
+
+        Returns:
+            Latency in milliseconds.
+        """
+        import time
+
+        # Clamp to the valid range.
+        num_tokens = min(num_tokens, self.scheduler_config.max_num_batched_tokens)
+        num_tokens = max(num_tokens, 1)
+
+        # Synchronize all devices before timing to ensure a clean
+        # measurement in PP/TP scenarios.
+        torch.npu.synchronize()
+
+        # In PP mode we still run on all ranks to keep them synchronized,
+        # but only the first rank's timing is used for scheduling decisions.
+        is_first_pp_rank = get_pp_group().is_first_rank
+
+        start = time.perf_counter()
+
+        # Run a real model forward with force_attention=True so attention is
+        # actually executed rather than skipped. Without it, attn_metadata
+        # may be None and attention won't run, making the profiling results
+        # inaccurate. _dummy_run handles PP internally (intermediate
+        # tensors, etc.).
+        self.model_runner._dummy_run(
+            num_tokens=num_tokens,
+            force_attention=True,  # Critical: ensure attention is executed
+            profile_cpp=True,
+        )
+
+        # Synchronize after the forward pass so all NPU ops have completed.
+        torch.npu.synchronize()
+
+        latency_ms = (time.perf_counter() - start) * 1000
+
+        # Log for debugging in PP mode.
+        if not is_first_pp_rank:
+            logger.debug(
+                "[ProfilingChunk] PP rank %d: profiled %d tokens, latency=%.2f ms (not used)",
+                get_pp_group().rank_in_group,
+                num_tokens,
+                latency_ms,
+            )
+
+        return latency_ms
+
     def get_kv_connector_handshake_metadata(self) -> dict | None:
         """Get KV connector metadata from this worker if available."""
         if not has_kv_transfer_group():