-
-
Notifications
You must be signed in to change notification settings - Fork 15.6k
[Fix Bug] num_active_loras always equals to zero
#34119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c6a3955
f7477ae
1c23b9a
00eb8a7
6e3c850
5b896a2
cee88f5
3cd4e95
ce52765
633703f
929ee75
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -138,7 +138,7 @@ def _lora_expand( | |
| lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] | ||
| lora_ids: torch.Tensor, # shape [max-loras + 1] | ||
| no_lora_flag_cpu: torch.Tensor, # shape [1] | ||
| num_active_loras: int, # number of active LoRAs (unused here, for API compat) | ||
| num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: can you add
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comments are added. |
||
| offset_start: int = 0, | ||
| add_inputs: bool = False, | ||
| ) -> None: | ||
|
|
@@ -235,7 +235,7 @@ def _lora_expand( | |
| grid = ( | ||
| triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), | ||
| NUM_SLICES, | ||
| num_active_loras, | ||
| num_active_loras.item(), | ||
| ) | ||
| # We disable PDL temporarily because LoRA kernels are not launching back-to-back, | ||
| # making PDL invalid and affecting the kernel performance. | ||
|
|
@@ -289,7 +289,7 @@ def _lora_expand_fake( | |
| lora_token_start_loc: torch.Tensor, | ||
| lora_ids: torch.Tensor, | ||
| no_lora_flag_cpu: torch.Tensor, | ||
| num_active_loras: int, | ||
| num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs | ||
| offset_start: int = 0, | ||
| add_inputs: bool = False, | ||
| ) -> None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,9 +29,16 @@ class LoRAKernelMeta: | |
| # to early exit from inside the lora_expand / lora_shrink torch operation. | ||
| no_lora_flag_cpu: torch.Tensor | ||
|
|
||
| # Number of active LoRAs (unique non-(-1) values in token_lora_mapping) | ||
| # Stored as a Python int to avoid GPU->CPU sync during forward pass | ||
| num_active_loras: int = 0 | ||
| # Number of active LoRAs (unique non-(-1) values in token_lora_mapping). | ||
| # Stored as a CPU tensor (not a Python int) so that torch.compile treats | ||
| # it as a dynamic value rather than baking it as a constant at trace time. | ||
| # This follows the same pattern as no_lora_flag_cpu above. | ||
| num_active_loras_cpu: torch.Tensor | ||
|
|
||
| # Default num_active_loras value (max_loras + 1) as a CPU tensor, | ||
| # used when specialize_active_lora is False to avoid allocating a | ||
| # new tensor on every meta_args() call. | ||
| default_num_active_loras_cpu: torch.Tensor | ||
|
|
||
| # Captured LoRA counts for cudagraph specialization (sorted list). | ||
| # When specialize_active_lora is enabled, num_active_loras is rounded up | ||
|
|
@@ -73,13 +80,20 @@ def make( | |
|
|
||
| no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu") | ||
|
|
||
| num_active_loras_cpu = torch.tensor([0], dtype=torch.int32, device="cpu") | ||
| default_num_active_loras_cpu = torch.tensor( | ||
| [max_loras + 1], dtype=torch.int32, device="cpu" | ||
| ) | ||
|
|
||
| return LoRAKernelMeta( | ||
| token_lora_mapping=token_lora_mapping, | ||
| token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, | ||
| active_lora_ids=active_lora_ids, | ||
| num_tokens_per_lora=num_tokens_per_lora, | ||
| lora_token_start_loc=lora_token_start_loc, | ||
| no_lora_flag_cpu=no_lora_flag_cpu, | ||
| num_active_loras_cpu=num_active_loras_cpu, | ||
| default_num_active_loras_cpu=default_num_active_loras_cpu, | ||
| captured_lora_counts=sorted(captured_lora_counts) | ||
| if captured_lora_counts | ||
| else [], | ||
|
|
@@ -90,8 +104,7 @@ def _reset(self): | |
| self.num_tokens_per_lora.fill_(0) | ||
| self.lora_token_start_loc.fill_(0) | ||
| self.no_lora_flag_cpu.fill_(False) | ||
| self.num_active_loras = 0 | ||
| self.captured_lora_counts = [] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes @RunkaiTao don't we need this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| self.num_active_loras_cpu.fill_(0) | ||
|
|
||
| def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: | ||
| """ | ||
|
|
@@ -137,14 +150,16 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: | |
| num_tokens_per_lora, non_blocking=True | ||
| ) | ||
|
|
||
| self.num_active_loras = lora_ids.size(0) | ||
| num_active_loras = lora_ids.size(0) | ||
|
|
||
| # Round up num_active_loras to match cudagraph capture keys. | ||
| # This ensures the kernel grid dimension matches the captured graph. | ||
| if self.captured_lora_counts and self.num_active_loras > 0: | ||
| idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras) | ||
| if self.captured_lora_counts and num_active_loras > 0: | ||
| idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras) | ||
| if idx < len(self.captured_lora_counts): | ||
| self.num_active_loras = self.captured_lora_counts[idx] | ||
| num_active_loras = self.captured_lora_counts[idx] | ||
|
|
||
| self.num_active_loras_cpu[0] = num_active_loras | ||
|
|
||
| # lora_token_start_loc | ||
| lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) | ||
|
|
@@ -163,7 +178,7 @@ def meta_args( | |
| torch.Tensor, | ||
| torch.Tensor, | ||
| torch.Tensor, | ||
| int, | ||
| torch.Tensor, | ||
| ]: | ||
| """ | ||
| This function returns the kernel metadata required for the current | ||
|
|
@@ -175,13 +190,16 @@ def meta_args( | |
| token_nums (int): Number of input tokens in the current forward | ||
| pass of the kernel. | ||
| """ | ||
| max_loras = self.active_lora_ids.size(0) - 1 | ||
| if specialize_active_lora: | ||
| num_active_loras = self.num_active_loras_cpu | ||
| else: | ||
| num_active_loras = self.default_num_active_loras_cpu | ||
| return ( | ||
| self.token_lora_mapping[:token_nums], | ||
| self.token_indices_sorted_by_lora_ids[:token_nums], | ||
| self.num_tokens_per_lora, | ||
| self.lora_token_start_loc, | ||
| self.active_lora_ids, | ||
| self.no_lora_flag_cpu, | ||
| self.num_active_loras if specialize_active_lora else max_loras + 1, | ||
| num_active_loras, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think OOM is due to this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes @RunkaiTao we should add back this argument
cudagraph_specialize_lora=False — thanks! There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made it false now