-
-
Notifications
You must be signed in to change notification settings - Fork 15k
[Misc][BE] Turn on strict type coverage for vllm/compilation #31756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7194132
3db3762
fd34484
102f7fa
60aedcc
b7f7280
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,19 +30,15 @@ | |
|
|
||
| FP8_DTYPE = current_platform.fp8_dtype() | ||
|
|
||
| flashinfer_comm: ModuleType | None = None | ||
ProExpertProg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if find_spec("flashinfer"): | ||
| try: | ||
| import flashinfer.comm as flashinfer_comm | ||
| import flashinfer.comm as _flashinfer_comm | ||
|
|
||
| flashinfer_comm: ModuleType | None = ( # type: ignore[no-redef] | ||
| flashinfer_comm | ||
| if hasattr(flashinfer_comm, "trtllm_allreduce_fusion") | ||
| else None | ||
| ) | ||
| if hasattr(_flashinfer_comm, "trtllm_allreduce_fusion"): | ||
| flashinfer_comm = _flashinfer_comm | ||
| except ImportError: | ||
| flashinfer_comm = None # type: ignore[assignment] | ||
| else: | ||
| flashinfer_comm = None # type: ignore[assignment] | ||
| pass | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
@@ -441,7 +437,7 @@ def is_applicable_for_range(self, compile_range: Range) -> bool: | |
| ): | ||
| return True | ||
| tp_size = get_tensor_model_parallel_world_size() | ||
| return compile_range.is_single_size() and compile_range.end % tp_size == 0 | ||
| return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. how is this not already bool haha
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. with |
||
|
|
||
| @VllmInductorPass.time_and_log | ||
| def __call__(self, graph: fx.Graph) -> None: | ||
|
|
@@ -516,7 +512,7 @@ def call_trtllm_fused_allreduce_norm( | |
| # Get one shot input size limit for the current world size | ||
| # for the current device capability | ||
| max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( | ||
| device_capability, # type: ignore[arg-type] | ||
| device_capability, # type: ignore[arg-type, unused-ignore] | ||
| {}, | ||
| ).get(world_size, None) | ||
| # Use one shot if no max size is specified | ||
|
|
@@ -666,6 +662,7 @@ def replacement( | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| residual = torch.zeros_like(input) | ||
| rms_result = torch.empty_like(input) | ||
| assert flashinfer_comm is not None, "FlashInfer must be enabled" | ||
| allreduce = auto_functionalized( | ||
| flashinfer_trtllm_fused_allreduce_norm, | ||
| allreduce_in=input, | ||
|
|
@@ -722,6 +719,7 @@ def pattern( | |
| def replacement( | ||
| residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| assert flashinfer_comm is not None, "FlashInfer must be enabled" | ||
| allreduce = auto_functionalized( | ||
| flashinfer_trtllm_fused_allreduce_norm, | ||
| allreduce_in=input, | ||
|
|
@@ -800,6 +798,7 @@ def replacement( | |
| residual = torch.zeros_like(input) | ||
| result_rms = torch.empty_like(input) | ||
| result_quant = torch.empty_like(input, dtype=self.quant_dtype) | ||
| assert flashinfer_comm is not None, "FlashInfer must be enabled" | ||
| allreduce = auto_functionalized( | ||
| flashinfer_trtllm_fused_allreduce_norm, | ||
| allreduce_in=input, | ||
|
|
@@ -875,6 +874,7 @@ def replacement( | |
| scale: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| result_quant = torch.empty_like(input, dtype=self.quant_dtype) | ||
| assert flashinfer_comm is not None, "FlashInfer must be enabled" | ||
| allreduce = auto_functionalized( | ||
| flashinfer_trtllm_fused_allreduce_norm, | ||
| allreduce_in=input, | ||
|
|
@@ -960,6 +960,7 @@ def replacement( | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| residual = torch.zeros_like(input) | ||
| result_rms = torch.empty_like(input) | ||
| assert flashinfer_comm is not None, "FlashInfer must be enabled" | ||
| allreduce = auto_functionalized( | ||
| flashinfer_trtllm_fused_allreduce_norm, | ||
| allreduce_in=input, | ||
|
|
@@ -1055,6 +1056,7 @@ def replacement( | |
| weight: torch.Tensor, | ||
| input_global_scale: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| assert flashinfer_comm is not None, "FlashInfer must be enabled" | ||
| allreduce = auto_functionalized( | ||
| flashinfer_trtllm_fused_allreduce_norm, | ||
| allreduce_in=input, | ||
|
|
@@ -1131,7 +1133,7 @@ def __init__(self, config: VllmConfig) -> None: | |
| ) | ||
|
|
||
| self.ipc_handles, workspace_tensor = ( | ||
| flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( # type: ignore[misc] | ||
| flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( | ||
| tp_rank=rank, | ||
| tp_size=self.tp_size, | ||
| max_token_num=self.max_token_num, | ||
|
|
@@ -1204,7 +1206,7 @@ def is_applicable_for_range(self, compile_range: Range) -> bool: | |
| if self.disabled: | ||
| logger.warning_once("AllReduce fusion pass is disabled.") | ||
| return False | ||
| return compile_range.end <= self.max_token_num | ||
| return bool(compile_range.end <= self.max_token_num) | ||
|
|
||
| @VllmInductorPass.time_and_log | ||
| def __call__(self, graph: fx.Graph) -> None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -201,9 +201,9 @@ def __init__(self, save_format: Literal["binary", "unpacked"]) -> None: | |
|
|
||
| def compute_hash(self, vllm_config: VllmConfig) -> str: | ||
| factors = get_inductor_factors() | ||
| hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[ | ||
| :10 | ||
| ] | ||
| hash_str: str = safe_hash( | ||
| str(factors).encode(), usedforsecurity=False | ||
| ).hexdigest()[:10] | ||
|
Comment on lines
+204
to
+206
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should just use the utils for hashing here that we use for other compute_hash functions (I think sha256)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be the same util I believe ( |
||
| return hash_str | ||
|
|
||
| def initialize_cache( | ||
|
|
@@ -319,9 +319,9 @@ class InductorAdaptor(CompilerInterface): | |
|
|
||
| def compute_hash(self, vllm_config: VllmConfig) -> str: | ||
| factors = get_inductor_factors() | ||
| hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[ | ||
| :10 | ||
| ] | ||
| hash_str: str = safe_hash( | ||
| str(factors).encode(), usedforsecurity=False | ||
| ).hexdigest()[:10] | ||
| return hash_str | ||
|
|
||
| def initialize_cache( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note for reviewer: This is not `strict = True` because this would cause cascading strict imports to other modules from `follow_imports` when checked with pre-commit only. Instead we default to a few sensible options here.