-
-
Notifications
You must be signed in to change notification settings - Fork 16.7k
[Fix] Misc Fixes in ViT CUDA Graph #38040
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b1cbbba
7baa187
761244f
e09ac8c
b3b32b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -72,25 +72,65 @@ def __init__( | |
|
|
||
| multimodal_config = vllm_config.model_config.multimodal_config | ||
|
|
||
| # Invariant: max_batch_size <= min_token_budget. | ||
| # This ensures per_image_output = budget // max_batch_size >= 1 | ||
| # for every captured budget, preventing reshape crashes on empty | ||
| # tensors during CUDA graph capture. Validated/enforced below for | ||
| # each configuration path. | ||
| if user_budgets and user_max_vision_items > 0: | ||
| # Fully user-specified | ||
| # Fully user-specified: validate the invariant. | ||
| self.token_budgets = sorted(user_budgets) | ||
| self.max_batch_size = user_max_vision_items | ||
| min_tok = min(self.token_budgets) | ||
| if self.max_batch_size > min_tok: | ||
| raise ValueError( | ||
| f"encoder_cudagraph_max_vision_items_per_batch " | ||
| f"({self.max_batch_size}) must be <= smallest token " | ||
| f"budget ({min_tok}). With budgets=" | ||
| f"{self.token_budgets}, per_image_output = " | ||
| f"{min_tok} // {self.max_batch_size} = " | ||
| f"{min_tok // self.max_batch_size}, which would cause " | ||
| f"a capture failure. Either increase the smallest " | ||
| f"budget or decrease max_vision_items_per_batch." | ||
| ) | ||
| else: | ||
| # Auto-infer missing values from model | ||
| # Auto-infer missing values from model. | ||
| min_budget, max_budget = model.get_encoder_cudagraph_budget_range( | ||
| vllm_config | ||
| ) | ||
| self.token_budgets = ( | ||
| sorted(user_budgets) | ||
| if user_budgets | ||
| else self._generate_budgets(min_budget, max_budget) | ||
| ) | ||
| self.max_batch_size = ( | ||
| user_max_vision_items | ||
| if user_max_vision_items > 0 | ||
| else max_budget // min_budget | ||
| ) | ||
| if min_budget <= 0 or max_budget <= 0: | ||
| raise ValueError( | ||
| f"Invalid encoder cudagraph budget range: " | ||
| f"min_budget={min_budget}, max_budget={max_budget}. " | ||
| f"Both must be positive." | ||
| ) | ||
| if min_budget > max_budget: | ||
| raise ValueError( | ||
| f"Invalid encoder cudagraph budget range: " | ||
| f"min_budget={min_budget} > max_budget={max_budget}." | ||
| ) | ||
|
|
||
| if user_max_vision_items > 0: | ||
| # User provided max_vision_items only; adjust auto-inferred | ||
| # budgets so min(budgets) >= max_batch_size. | ||
| self.max_batch_size = user_max_vision_items | ||
| effective_min = max(min_budget, user_max_vision_items) | ||
| self.token_budgets = self._generate_budgets(effective_min, max_budget) | ||
| elif user_budgets: | ||
| # User provided budgets only; cap auto-inferred | ||
| # max_batch_size to min(user_budgets). | ||
| self.token_budgets = sorted(user_budgets) | ||
| self.max_batch_size = min( | ||
| max_budget // min_budget, | ||
| min(self.token_budgets), | ||
| ) | ||
|
Comment on lines
+122
to
+126
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Similar to the other block, user-provided budgets should be validated for positivity here, e.g.:

    self.token_budgets = sorted(user_budgets)
    if self.token_budgets[0] <= 0:
        raise ValueError(
            f"Invalid encoder_cudagraph_token_budgets: {user_budgets}. "
            "All budget values must be positive."
        )
    self.max_batch_size = min(
        max_budget // min_budget,
        self.token_budgets[0],
    )
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

This has been addressed by adding the check in `vllm/config/compilation.py`.
||
| else: | ||
| # Fully auto-inferred. | ||
| self.token_budgets = self._generate_budgets(min_budget, max_budget) | ||
| self.max_batch_size = min( | ||
| max_budget // min_budget, | ||
| min(self.token_budgets), | ||
| ) | ||
|
|
||
| assert multimodal_config is not None | ||
| if multimodal_config.get_limit_per_prompt("video") == 0: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
User-provided `encoder_cudagraph_token_budgets` are not validated for positivity. If a user provides a non-positive budget (e.g., `[0, 128]`), `min(self.token_budgets)` could be zero or negative. This can lead to `self.max_batch_size` being set to zero, which will cause a `ZeroDivisionError` later during CUDA graph capture preparation.

You should add validation to ensure all user-provided budgets are positive. A similar check is needed in the `elif user_budgets:` block.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a fair point, but I feel that the better way to handle this is through pydantic (cuz ultimately this is an input validation problem), e.g., in the definition of CompilationConfig:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added the check in `vllm/config/compilation.py`.