diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index d7550bf11b7..401e0b27af0 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -143,47 +143,25 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "Non-MLA LLMs forcibly disable the chunked prefill feature,"
                 "as the performance of operators supporting this feature "
                 "functionality is currently suboptimal.")
-            if vllm_version_is("0.11.0"):
-                if not model_config.is_multimodal_model and \
-                    structured_outputs_config.backend == "auto" and \
-                    not scheduler_config.send_delta_data and \
-                    not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \
-                    scheduler_config.policy == "fcfs":
-                    ascend_scheduler_config.enabled = True
-                    chunked_prefill_enabled_in_ascend_scheduler = getattr(
-                        ascend_scheduler_config, "enable_chunked_prefill",
-                        False)
-                    if chunked_prefill_enabled_in_ascend_scheduler:
-                        logger.warning(
-                            "Chunked prefill feature is enabled in ascend_scheduler,"
-                            "but note that the operator supporting this feature "
-                            "would lead to performance degradation.")
-                    # In this situation, max_num_batched_tokens would have been rewritten.
-                    # So we must make sure max_num_batched_tokens is not smaller than max_model_len.
-                    if (scheduler_config.max_num_batched_tokens
-                            < scheduler_config.max_model_len and
-                            not chunked_prefill_enabled_in_ascend_scheduler):
-                        scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len
-            else:
-                if not model_config.is_multimodal_model and \
-                    structured_outputs_config.backend == "auto" and \
-                    not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \
-                    scheduler_config.policy == "fcfs":
-                    ascend_scheduler_config.enabled = True
-                    chunked_prefill_enabled_in_ascend_scheduler = getattr(
-                        ascend_scheduler_config, "enable_chunked_prefill",
-                        False)
-                    if chunked_prefill_enabled_in_ascend_scheduler:
-                        logger.warning(
-                            "Chunked prefill feature is enabled in ascend_scheduler,"
-                            "but note that the operator supporting this feature "
-                            "would lead to performance degradation.")
-                    # In this situation, max_num_batched_tokens would have been rewritten.
-                    # So we must make sure max_num_batched_tokens is not smaller than max_model_len.
-                    if (scheduler_config.max_num_batched_tokens
-                            < scheduler_config.max_model_len and
-                            not chunked_prefill_enabled_in_ascend_scheduler):
-                        scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len
+            if not model_config.is_multimodal_model and \
+                structured_outputs_config.backend == "auto" and \
+                not getattr(scheduler_config, "send_delta_data", False) and \
+                not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \
+                scheduler_config.policy == "fcfs":
+                ascend_scheduler_config.enabled = True
+                chunked_prefill_enabled_in_ascend_scheduler = getattr(
+                    ascend_scheduler_config, "enable_chunked_prefill", False)
+                if chunked_prefill_enabled_in_ascend_scheduler:
+                    logger.warning(
+                        "Chunked prefill feature is enabled in ascend_scheduler,"
+                        "but note that the operator supporting this feature "
+                        "would lead to performance degradation.")
+                # In this situation, max_num_batched_tokens would have been rewritten.
+                # So we must make sure max_num_batched_tokens is not smaller than max_model_len.
+                if (scheduler_config.max_num_batched_tokens
+                        < scheduler_config.max_model_len
+                        and not chunked_prefill_enabled_in_ascend_scheduler):
+                    scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len
 
         kv_cache_dtype = vllm_config.additional_config.get(
             "kv_cache_dtype", None)
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 74e2917806b..c1be6ed65b7 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -63,9 +63,8 @@ def __init__(self,
             == CompilationMode.VLLM_COMPILE
             and not self.vllm_config.model_config.enforce_eager)
 
-        self.cudagraph_batch_sizes = list(
-            reversed(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+        self.cudagraph_batch_sizes = sorted(
+            self.vllm_config.compilation_config.cudagraph_capture_sizes)
 
         # persistent buffers for cuda graph
         self.input_ids = torch.zeros(
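For context on the consolidated condition in the platform.py hunk: `getattr(scheduler_config, "send_delta_data", False)` lets a single branch cover vLLM versions that do and do not define `send_delta_data` on the scheduler config, rather than switching on `vllm_version_is("0.11.0")`. Below is a minimal, illustrative sketch of that pattern; the config classes are hypothetical stand-ins, not vLLM's real objects.

# Sketch of version-tolerant attribute checks via getattr defaults.
# OldSchedulerCfg / NewSchedulerCfg are hypothetical stand-ins for scheduler
# configs with and without the `send_delta_data` attribute.
from dataclasses import dataclass

@dataclass
class OldSchedulerCfg:
    policy: str = "fcfs"
    send_delta_data: bool = False

@dataclass
class NewSchedulerCfg:
    policy: str = "fcfs"  # no send_delta_data attribute in newer versions

def ascend_scheduler_eligible(scheduler_config) -> bool:
    # Missing attributes fall back to defaults, so one code path serves both
    # config shapes without raising AttributeError.
    return (not getattr(scheduler_config, "send_delta_data", False)
            and not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0
            and scheduler_config.policy == "fcfs")

print(ascend_scheduler_eligible(OldSchedulerCfg()))  # True
print(ascend_scheduler_eligible(NewSchedulerCfg()))  # True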