Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/user_guide/configuration/additional_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ The details of each config option are as follows:
| `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine|
| `enable_pd_transfer` | bool | `False` | Whether to enable pd transfer. When using it, decode is started only when prefill of all requests is done. This option only takes effects on offline inference. |
| `decode_max_num_seqs` | int | `0` | Whether to change max_num_seqs of decode phase when enable pd transfer. This option only takes effects when enable_pd_transfer is True. |
| `max_long_partial_prefills` | Union[int, float] | `float('inf')` | the maximum number of prompts longer than long_prefill_token_threshold that will be prefilled concurrently. |
| `long_prefill_token_threshold` | Union[int, float] | `float('inf')` | a request is considered long if the prompt is longer than this number of tokens. |
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these two config is supporeted by vLLM by default. So we don't need to add them here. See L66


ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.

Expand All @@ -79,6 +81,8 @@ An example of additional configuration is as follows:
"ascend_scheduler_config": {
"enabled": True,
"enable_chunked_prefill": True,
"max_long_partial_prefills": 1,
"long_prefill_token_threshold": 4096,
},
"multistream_overlap_shared_expert": True,
"refresh": False,
Expand Down
4 changes: 4 additions & 0 deletions tests/ut/core/test_schedule_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def test_initialize_from_config_with_override(self):
scheduler_cls="vllm_ascend.core.scheduler.AscendScheduler",
max_num_batched_tokens=2048,
max_model_len=2048,
max_long_partial_prefills=1,
long_prefill_token_threshold=512,
),
)
self.assertEqual(ascend_config.enable_chunked_prefill, False)
Expand All @@ -59,6 +61,8 @@ def test_initialize_from_config_with_override(self):
"vllm_ascend.core.scheduler.AscendScheduler")
self.assertEqual(ascend_config.max_num_batched_tokens, 2048)
self.assertEqual(ascend_config.encoder_cache_size, 2048)
self.assertEqual(ascend_config.max_long_partial_prefills, 1)
self.assertEqual(ascend_config.long_prefill_token_threshold, 512)

def test_not_implemented_policy(self):
with self.assertRaises(NotImplementedError) as context:
Expand Down
23 changes: 22 additions & 1 deletion tests/ut/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test_get_num_unfinished_requests(self):
len(requests) - i - 1)

def test_schedule(self):
'''Test scheduling.
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = self.create_scheduler()
Expand Down Expand Up @@ -279,6 +279,27 @@ def test_schedule_multimodal_requests(self):
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)

def test_concurrent_partial_prefills_schedule(self):
'''Test concurrent partial prefills scheduling.
total requests = 10, every request has 10 token.
while set long_prefill_token_threshold = 1, scheduler can
only schedule max_long_partial_prefills long request.
'''
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = False
scheduler.scheduler_config.max_long_partial_prefills = 2
scheduler.scheduler_config.long_prefill_token_threshold = 1
requests = create_requests(num_requests=10, num_tokens=20)
for request in requests:
scheduler.add_request(request)

# Test initial scheduling
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs),
scheduler.scheduler_config.max_long_partial_prefills)
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)

def test_schedule_enable_prefix_caching(self):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
Expand Down
28 changes: 28 additions & 0 deletions vllm_ascend/core/schedule_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@

from vllm.config import SchedulerConfig

MAX_INT = 2147483647


@dataclass
class AscendSchedulerConfig(SchedulerConfig):
enable_chunked_prefill: bool = False
max_long_partial_prefills: int = MAX_INT
long_prefill_token_threshold: int = MAX_INT
policy: str = "fcfs"
scheduler_cls: Union[str, Type[object]] = (
"vllm_ascend.core.scheduler.AscendScheduler")
Expand All @@ -42,6 +46,8 @@ def initialize_from_config(
}
# Override default values into original SchedulerConfig
scheduler_config["enable_chunked_prefill"] = False
scheduler_config["max_long_partial_prefills"] = None
scheduler_config["long_prefill_token_threshold"] = None
scheduler_config["policy"] = "fcfs"
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.scheduler.AscendScheduler")
Expand All @@ -67,6 +73,28 @@ def __post_init__(self) -> None:
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len.")
# concurrent partial prefills. Default is inf
if self.max_long_partial_prefills is None:
self.max_long_partial_prefills = MAX_INT
self.long_prefill_token_threshold = MAX_INT

if self.long_prefill_token_threshold is None or \
self.long_prefill_token_threshold <= 0:
if self.max_model_len is None:
self.long_prefill_token_threshold = MAX_INT
else:
self.long_prefill_token_threshold = \
max(1, int(self.max_model_len * 0.04))

if self.max_long_partial_prefills < 0:
raise ValueError(
f"max_long_partial_prefills must be non-negative, but got "
f"{self.max_long_partial_prefills}")
if self.long_prefill_token_threshold < 0:
raise ValueError(
f"long_prefill_token_threshold must be non-negative, but got "
f"{self.long_prefill_token_threshold}")

if self.policy != "fcfs":
raise NotImplementedError(
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
Expand Down
15 changes: 15 additions & 0 deletions vllm_ascend/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ def schedule(self) -> SchedulerOutput:
# all request prefilled, change phase to decode
if not self.waiting and not self.running:
self.phase = "decode"
# Skip long prompt requests in prefill stage.
# long_prefill_budget is float('inf') if not use.
if self.vllm_config.scheduler_config.long_prefill_token_threshold == 0:
long_prefill_budget = float('inf')
long_prefill_token_threshold = float('inf')
else:
long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills
long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold

# Schedule prefill requests first.
while self.waiting and token_budget > 0:
Expand Down Expand Up @@ -217,6 +225,11 @@ def skip_cur_request():
skip_cur_request()
continue

if num_new_tokens > long_prefill_token_threshold \
and long_prefill_budget <= 0:
skip_cur_request()
continue

new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
Expand Down Expand Up @@ -268,6 +281,8 @@ def skip_cur_request():
# Update request info.
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
if num_new_tokens > long_prefill_token_threshold:
long_prefill_budget -= 1
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
Expand Down
Loading