
[Frontend] Add server load limit with --max-server-load parameter#22805

Closed
scratch-ml wants to merge 8 commits into vllm-project:main from scratch-ml:feature/server-load-limit

Conversation

@scratch-ml
Contributor

@scratch-ml scratch-ml commented Aug 13, 2025

Purpose

This PR adds server load limiting functionality to vLLM's OpenAI API server to prevent server overload in production environments.

Problem: Production vLLM deployments can become overwhelmed with too many concurrent requests, leading to poor performance, resource exhaustion, or server crashes. Currently, there's no built-in mechanism to limit concurrent requests and gracefully handle overload situations.

Solution: Add a new --max-server-load parameter that works with the existing --enable-server-load-tracking feature to gracefully reject requests when the server reaches its capacity limit.

Changes Made:

  • Added max_server_load: Optional[int] = None parameter to FrontendArgs class in cli_args.py
  • Initialize max_server_load state in init_app_state function in api_server.py
  • Enhanced load_aware_call decorator in utils.py with load checking logic
  • Return HTTP 503 with detailed error message when server is overloaded
  • Comprehensive test coverage for all scenarios in test_server_load_limit.py
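The flow the list describes can be sketched roughly as follows. This is a minimal stand-alone illustration, not the PR's actual code: `AppState` stands in for the FastAPI app state, and the error-dict shape is an assumption.

```python
import asyncio
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional


@dataclass
class AppState:
    """Hypothetical stand-in for the app state fields the PR touches."""
    enable_server_load_tracking: bool = True
    max_server_load: Optional[int] = None
    server_load_metrics: int = 0  # number of in-flight requests


def load_aware_call(func):
    """Sketch of the enhanced decorator: reject at capacity, otherwise track load."""
    async def wrapper(state: AppState, *args, **kwargs):
        if not state.enable_server_load_tracking:
            return await func(state, *args, **kwargs)
        # The limit only applies when --max-server-load is set.
        if (state.max_server_load is not None
                and state.server_load_metrics >= state.max_server_load):
            return {"status_code": int(HTTPStatus.SERVICE_UNAVAILABLE),
                    "error": (f"Server overloaded: {state.server_load_metrics} "
                              f"requests in flight, limit is "
                              f"{state.max_server_load}")}
        state.server_load_metrics += 1
        try:
            return await func(state, *args, **kwargs)
        finally:
            state.server_load_metrics -= 1  # decrement even when the handler raises
    return wrapper
```

The `try`/`finally` is what the test plan below refers to when it says the counter stays accurate "during both normal operations and exception scenarios".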

Benefits:

  • Production Safety: Prevents server crashes due to excessive load
  • Resource Protection: Maintains service quality under high demand
  • Monitoring Integration: Works seamlessly with existing load tracking infrastructure
  • User Experience: Provides clear feedback when service is unavailable
  • Zero Breaking Changes: Fully backward compatible

Test Plan

Unit Tests

# Run the comprehensive test suite for the new functionality
python -m pytest tests/entrypoints/openai/test_server_load_limit.py -v

# Run related existing tests to ensure no regression
python -m pytest tests/entrypoints/openai/test_basic.py -v

Test Result

All unit test cases passed successfully, validating the server load limiting functionality. Load counter management works correctly during both normal operations and exception scenarios, ensuring accurate tracking of concurrent requests.

(Optional) Documentation Update

No documentation files require updates as this is an internal server feature. The functionality is self-documenting through:

  • CLI Help: Parameter automatically appears in --help output with descriptive text
  • Error Messages: Clear, actionable error responses returned to clients
  • Code Comments: Comprehensive inline documentation explaining the logic
  • Type Annotations: Full type hints for IDE support and code clarity

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the frontend label Aug 13, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a valuable server load limiting feature to prevent server overload. The implementation is mostly correct and well-tested for single-threaded scenarios. However, I've identified a critical race condition in the load checking logic that could allow the server to exceed the configured maximum load under concurrent requests. Additionally, there's a potential for an AttributeError if the server load metric is not yet initialized. My review comment details these issues and suggests a path to resolution using asyncio.Lock to ensure atomicity.
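The race the review describes (two coroutines passing the load check before either increments the counter) can be closed along the lines it suggests. A minimal sketch, with all names hypothetical:

```python
import asyncio


class LoadLimiter:
    """Sketch of the reviewer's asyncio.Lock suggestion: the load check and
    the counter increment happen atomically with respect to other coroutines."""

    def __init__(self, max_load: int):
        self.max_load = max_load
        self.current_load = 0
        self._lock = asyncio.Lock()

    async def try_acquire(self) -> bool:
        async with self._lock:  # check-then-increment as one critical section
            if self.current_load >= self.max_load:
                return False
            self.current_load += 1
            return True

    async def release(self) -> None:
        async with self._lock:
            self.current_load -= 1
```

Strictly speaking, within a single event loop a check and an increment with no `await` between them already run atomically; the lock guards against the check and the increment being separated by an `await` in a future refactor, which is presumably the reviewer's concern.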

- Add max_server_load parameter to FrontendArgs for setting concurrent request limit
- Initialize max_server_load state in init_app_state function
- Add load checking logic in load_aware_call decorator
- Return HTTP 503 with detailed error message when server is overloaded
- Only effective when --enable-server-load-tracking is enabled
- Add comprehensive tests for new functionality

This feature prevents server overload in production deployments by allowing
administrators to set a maximum number of concurrent requests. When the limit
is exceeded, new requests receive HTTP 503 responses with clear error messages.

Signed-off-by: scratch-ml <limingliang0527@gmail.com>
@scratch-ml scratch-ml force-pushed the feature/server-load-limit branch from 32c80d6 to c8917c2 (August 13, 2025 09:57)
Signed-off-by: scratch-ml <limingliang0527@gmail.com>
Signed-off-by: scratch-ml <limingliang0527@gmail.com>
@DarkLight1337
Member

Production vLLM deployments can become overwhelmed with too many concurrent requests, leading to poor performance, resource exhaustion, or server crashes

Could you elaborate on how this could happen? From my understanding, vLLM queues the excess requests and sends requests into the engine only when the engine can take the next batch, so the engine core should not crash from having too many requests at once.

@scratch-ml
Contributor Author

Could you elaborate on how this could happen? From my understanding, vLLM queues the excess requests and sends requests into the engine only when the engine can take the next batch, so the engine core should not crash from having too many requests at once.

@DarkLight1337 In our multimodal services, we have observed that requests containing multiple images can cause the vLLM inference engine (version 0.8.5) to crash. We implemented a --max-image-num parameter to control the maximum number of images per inference request. With a batch size of 1, a single request with up to 50 images is processed normally; however, when the batch size exceeds 1, the "Engine Dead" issue becomes reproducible. The trace log at the time of failure is as follows:

(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404] Traceback (most recent call last):
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core.py", line 395, in run_engine_core
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     engine_core.run_busy_loop()
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core.py", line 664, in run_busy_loop
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     self._process_engine_step()
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core.py", line 444, in _process_engine_step
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     outputs = self.step_fn()
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core.py", line 203, in step
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     output = self.model_executor.execute_model(scheduler_output)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/executor/abstract.py", line 86, in execute_model
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     output = self.collective_rpc("execute_model",
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/uniproc_executor.py", line 61, in collective_rpc
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     answer = run_method(self.driver_worker, method, args, kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/utils.py", line 2456, in run_method
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return func(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return func(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/worker/gpu_worker.py", line 268, in execute_model
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     output = self.model_runner.execute_model(scheduler_output)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return func(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1053, in execute_model
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     self._execute_mm_encoder(scheduler_output)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 890, in _execute_mm_encoder
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     curr_group_outputs = self.model.get_multimodal_embeddings(
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/mm_step1o.py", line 744, in get_multimodal_embeddings
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     vision_embeddings = self._process_image_input(image_input)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/mm_step1o.py", line 716, in _process_image_input
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     image_features = self._get_vision_model_output(image_input["pixel_values"])
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/mm_step1o.py", line 707, in _get_vision_model_output
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return self.vision_model(input_tensor)[0][:, 4:]
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return self._call_impl(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return forward_call(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/step_encoder.py", line 367, in forward
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return self.vision_model(pixel_values=pixel_values)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return self._call_impl(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return forward_call(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/step_encoder.py", line 340, in forward
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     hidden_states = self.transformer(inputs_embeds=hidden_states)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return self._call_impl(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return forward_call(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/step_encoder.py", line 309, in forward
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     hidden_states = encoder_layer(hidden_states, )
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return self._call_impl(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return forward_call(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/step_encoder.py", line 275, in forward
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     out = h + self.layer_norm2(self.mlp(h))
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return self._call_impl(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return forward_call(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/step_encoder.py", line 239, in forward
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     hidden_states = self.activation_fn(hidden_states)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return self._call_impl(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return forward_call(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/custom_op.py", line 25, in forward
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     return self._forward_method(*args, **kwargs)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/activation.py", line 264, in forward_cuda
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404]     out = torch.empty_like(x)
(EngineCore_5 pid=594) ERROR 08-08 18:02:37 [core.py:404] torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.64 GiB. GPU 5 has a total capacity of 94.99 GiB of which 3.48 GiB is free. Process 2800536 has 320.00 MiB memory in use. Process 1113158 has 91.17 GiB memory in use. Of the allocated memory 83.77 GiB is allocated by PyTorch, with 1.00 GiB allocated in private pools (e.g., CUDA Graphs), and 5.04 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
ERROR 08-08 18:02:37 [async_llm.py:405] AsyncLLM output_handler failed.
ERROR 08-08 18:02:37 [async_llm.py:405] Traceback (most recent call last):
ERROR 08-08 18:02:37 [async_llm.py:405]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/async_llm.py", line 363, in output_handler
ERROR 08-08 18:02:37 [async_llm.py:405]     outputs = await engine_core.get_output_async()
ERROR 08-08 18:02:37 [async_llm.py:405]   File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core_client.py", line 785, in get_output_async
ERROR 08-08 18:02:37 [async_llm.py:405]     raise self._format_exception(outputs) from None
ERROR 08-08 18:02:37 [async_llm.py:405] vllm.v1.engine.exceptions.EngineDeadError: EngineCore encountered an issue. See stack trace (above) for the root cause.
INFO 08-08 18:02:37 [async_llm.py:330] Request chatcmpl-1942086d7c0a58fa53a8f5aeff67d58f.9a57ac0db61d6531615d93fca2d9e389 failed (engine dead).
INFO 08-08 18:02:37 [async_llm.py:330] Request chatcmpl-a3a5453f1c9aad73c212e8012fea0ee3.1ac50344cdff393bbab56fccb44fb50b failed (engine dead).

Therefore, in multimodal scenarios, the length of individual requests needs to be controlled, and in fact, the quantity of requests is also crucial.

Fundamentally, HTTP services require the capability to enforce concurrency limits. In most production environments, sequence length distribution typically fluctuates minimally, making concurrency volume a critical factor affecting service quality. Consequently, when concurrency is high, actively rejecting requests serves to protect the quality of service for in-flight requests on high-load instances. Rejected requests can be promptly rescheduled rather than queuing in a waiting or pending state.

The above are my humble thoughts. If there are any inaccuracies, please feel free to point them out. Thank you for reading.

@DarkLight1337
Member

DarkLight1337 commented Aug 13, 2025

If OOM occurs during inference, it signals that there is a bug inside vLLM. Which model were you using and how did you serve it? The memory profiling has gotten better since v0.8.5 so if you upgrade vLLM, this problem should not occur anymore.

@scratch-ml
Contributor Author

If OOM occurs during inference, it signals that there is a bug inside vLLM. Which model were you using and how did you serve it? The memory profiling has gotten better since v0.8.5 so if you upgrade vLLM, this problem should not occur anymore.

The model is closed-source, and the command we use to start the service is: vllm serve /path-to-model --served-model-name online-service --max-model-len 65536 --max-num-seqs 128 --no-enable-prefix-caching --quantization groupwise-quant --tensor-parallel-size 1 --data-parallel-size 8 --max-num-batched-tokens 8192 --enforce-eager --enable-auto-tool-choice --gpu-memory-utilization 0.95 --trust-remote-code --disable-mm-preprocessor-cache
I will open a new issue to describe this problem in detail and try upgrading the vLLM version.

Setting aside the engine crash issue, I still believe this new --max-server-load parameter is valuable, especially in production environments. Do you also agree with this? Or do you have any suggestions for improvement regarding this PR?

Thank you for your help~

@DarkLight1337
Member

I think it's better to add a layer on top of vLLM that queries /load endpoint in order to perform the load balancing/limiting. Perhaps @robertgshaw2-redhat and @njhill could provide better suggestions regarding this.
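For illustration, such a layer could poll each replica's load and route around saturated ones. A rough sketch; the `/load` response shape (`{"server_load": <int>}` when load tracking is enabled) and the helper names are assumptions:

```python
import json
from urllib.request import urlopen


def pick_backend(backends, max_load, fetch=None):
    """Return the least-loaded backend strictly under max_load, or None
    if every backend is saturated.

    `fetch` maps a base URL to its current load; the default queries the
    replica's GET /load endpoint.
    """
    if fetch is None:
        def fetch(url):
            with urlopen(f"{url}/load") as resp:
                return json.load(resp)["server_load"]
    loads = {b: fetch(b) for b in backends}
    candidates = [b for b, load in loads.items() if load < max_load]
    return min(candidates, key=loads.get) if candidates else None
```

A `None` result would correspond to the whole fleet being at capacity, at which point the layer itself has to shed the request, much as this PR does inside a single server.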

Signed-off-by: scratch-ml <limingliang0527@gmail.com>
@scratch-ml
Contributor Author

scratch-ml commented Aug 13, 2025

I think it's better to add a layer on top of vLLM that queries /load endpoint in order to perform the load balancing/limiting. Perhaps @robertgshaw2-redhat and @njhill could provide better suggestions regarding this.

I appreciate the suggestion to add a layer querying the /load endpoint.

That said, I lean towards active rejection at the entry point instead. Checking /load incurs overhead even during low-traffic periods, while upfront rejection preserves resources by shedding excess requests early. This keeps latency stable for accepted requests and allows rejected ones to be rescheduled efficiently, avoiding internal queue congestion.

We also value input from @robertgshaw2-redhat and @njhill — your insights would certainly strengthen this approach.

@njhill njhill self-requested a review September 15, 2025 18:39
Member

@njhill njhill left a comment


Thanks @scratch-ml, I think this is a nice lightweight approach for basic load-shedding.

Ultimately it would be good to reject inference requests based on an estimated queue waiting time, which would depend on the current contents of the running and waiting queues, i.e. taking into account the input / expected output token counts of running and queued requests, per-token latency, etc.
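As a back-of-envelope illustration of that idea, not anything in vLLM's API (every name and the cost model here are invented): the estimated wait grows with the output tokens still owed by running and queued requests, spread over the available batch slots.

```python
def estimated_wait_s(running_tokens, waiting_tokens,
                     per_token_latency_s, max_batch):
    """Crude wait estimate: total tokens still to generate across the
    running and waiting queues, amortized over the batch slots."""
    total_tokens = sum(running_tokens) + sum(waiting_tokens)
    return total_tokens * per_token_latency_s / max(max_batch, 1)


def should_reject(running_tokens, waiting_tokens, per_token_latency_s,
                  max_batch, wait_budget_s):
    """Reject when the estimated queue wait exceeds the caller's budget."""
    return estimated_wait_s(running_tokens, waiting_tokens,
                            per_token_latency_s, max_batch) > wait_budget_s
```

Unlike a flat concurrency cap, this would let the server accept many short requests while still shedding load when a few very long ones dominate the queues.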

Collaborator

@chaunceyjiang chaunceyjiang left a comment


Compared to #21352, this PR is indeed more lightweight. One question: does it work together with --api-server-count?

@njhill
Member

njhill commented Sep 16, 2025

@chaunceyjiang good question ... no, it wouldn't work precisely with that, but if the value is large you could approximate the behaviour by dividing the max load value by the API server count.

Signed-off-by: scratch-ml <limingliang0527@gmail.com>
@scratch-ml scratch-ml force-pushed the feature/server-load-limit branch 2 times, most recently from af06c3c to 1f1cb85 (September 16, 2025 16:58)
Signed-off-by: scratch-ml <limingliang0527@gmail.com>
@scratch-ml scratch-ml force-pushed the feature/server-load-limit branch from 1f1cb85 to 1dd742f (September 16, 2025 17:05)
Signed-off-by: scratch-ml <limingliang0527@gmail.com>
Signed-off-by: scratch-ml <limingliang0527@gmail.com>
Collaborator

@NickLucche NickLucche left a comment


Thanks a lot for contributing @scratch-ml !

My opinion on this kind of load-based decision is that the policy should sit one level "in front" of vLLM, with traditional proxies.
In particular, I think this change is mostly useful in a single-server scenario and loses much of its value in a distributed setup where a more structured LB is to be used anyway (e.g. the llm-d load-informed scheduler, to name one).

@scratch-ml
Contributor Author

scratch-ml commented Sep 17, 2025

Thanks a lot for contributing @scratch-ml !

My opinion on this kind of load-based decision is that the policy should sit one level "in front" of vLLM, with traditional proxies. In particular, I think this change is mostly useful in a single-server scenario and loses much of its value in a distributed setup where a more structured LB is to be used anyway (e.g. the llm-d load-informed scheduler, to name one).

@NickLucche Thank you for your suggestion. I appreciate your perspective that implementing request rejection at an upper-layer proxy is the more elegant approach. That said, I believe this mechanism can still serve as a valuable fallback for scenarios without sophisticated scheduler infrastructure, for example where users rely on a simple round-robin strategy with retry mechanisms. To some extent, an API server also acts as a request proxy, and it makes sense for proxies at different levels to have varying degrees of rate-limiting capability. Thank you again for everyone's input. Appreciate the discussion!

@scratch-ml
Contributor Author

Hello @njhill @NickLucche @chaunceyjiang

I wanted to check in on this PR. If there's a consensus that this feature isn't needed, I'm ready to close it. Alternatively, if you believe it has value, I'll continue refining it with your feedback until it meets all necessary standards.

Appreciate your input and time.

@chaunceyjiang
Collaborator

In particular, I think this change is mostly useful in a single-server scenario and it loses a lot of its value in a distributed setup where a more structured LB is to be used anyways (eg llm-d load-informed scheduler to name one).

I agree with @NickLucche’s point. In addition, I think that compared to #21352, the use cases for this PR are quite limited. Whether in distributed scenarios or with multiple API servers, I actually lean more toward an implementation based on a waiting queue.

@scratch-ml scratch-ml closed this Sep 26, 2025