-
-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[lora/moe] Improve fused MoE-LoRA kernel indexing and memory access #32770
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a0119e3
f98d2eb
9816085
2f4df5b
2cfbe19
c83d089
502f11b
8f57d8e
7666091
4da8f6f
7899196
abe6851
d1d5acf
1bc92b2
20f38fc
b982de0
24866b6
ac2c273
b6bbb68
a72e7d0
8b207bd
3edc1c8
c0a8ab1
af07f6f
9a46d4d
5cab5f6
80e6fcd
6b078f5
42b8d02
e01a674
4adb5ec
7ab692f
6a7567c
db01ec1
78750a7
c8b56b7
ba8a3aa
c8b2723
25b5600
9dfb58d
6e96888
830f060
c0d0a6d
04dc150
b24f914
d6518d7
5ca439b
77467d1
7c712dd
135099e
1e8de71
c482308
bc077a8
31d068b
2b541c5
f2b6c80
9ca38c1
7a4b69c
505dc0e
8ddd79e
f3617c8
eda9a3d
04460ed
56eba0b
35bd5e9
b2d44ef
72b043f
0f4fd95
5a17bba
468c45a
4983e12
b6af384
08e06fa
325b7d7
1f3bd57
d7fd61a
4980d0b
c85d817
35a2c9a
6855d37
3aca787
8f46305
518c8c1
aae07c0
6dfd19f
69acfe5
8fa0a36
6804f18
04c168d
2b48388
76fde85
70ca1f0
2c86342
187c02a
f1ff77b
d6612ac
6740169
a1cdaf9
e4b1fc1
ffa2efc
8000ab5
0108f4a
75311f3
3baef7f
fc5ed11
0b063f2
b4263fb
03c481d
ec28238
086239b
9223b58
039d53c
5975c3a
6dcd724
e54514e
77935f8
44e22be
73d908e
fdb1cfb
1af9143
bc2a887
89d4644
173291f
eb596f7
4c3cdc1
bb70650
7203206
94cc8d7
0f21271
3cdaa87
196f8b8
b3deb77
2aac4da
2776a97
a3c086a
67fce76
a680781
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,7 @@ def _fused_moe_lora_kernel( | |
| num_experts, | ||
| lora_ids, | ||
| adapter_enabled, | ||
| max_loras, # <<< PR2: rename, used for masks when grid axis-2 != max_loras | ||
| # The stride variables represent how much to increase the ptr by when | ||
| # moving by 1 element in a particular dimension. E.g. `stride_am` is | ||
| # how much to increase `a_ptr` by to get the element one row down | ||
|
|
@@ -83,6 +84,7 @@ def _fused_moe_lora_kernel( | |
| num_slice_c: tl.constexpr, | ||
| top_k: tl.constexpr, | ||
| MUL_ROUTED_WEIGHT: tl.constexpr, | ||
| USE_B_L2_CACHE: tl.constexpr, # new, enable .ca load for B | ||
| BLOCK_SIZE_M: tl.constexpr, | ||
| BLOCK_SIZE_N: tl.constexpr, | ||
| BLOCK_SIZE_K: tl.constexpr, | ||
|
|
@@ -104,10 +106,13 @@ def _fused_moe_lora_kernel( | |
| if moe_enabled == 0: | ||
| # Early exit for the no moe lora case. | ||
| return | ||
| # The grid size on axis 2 is (max_loras + 1) to handle the no-lora case | ||
| # (lora_id == -1), but sorted_token_ids and expert_ids are allocated with | ||
| # shape (max_loras, ...). Use (num_programs - 1) for correct bounds checking. | ||
| max_loras = tl.num_programs(axis=2) - 1 | ||
| # The grid's axis-2 dimension is max_loras + 1 to accommodate the -1 sentinel. | ||
| # This guard ensures we don't access sorted_token_ids / expert_ids / | ||
| # num_tokens_post_padded beyond their allocated bounds if an invalid | ||
| # lora_id somehow appears. Although the caller should pass correct | ||
| # max_loras, defensive programming prevents accidental out-of-bounds. | ||
| if lora_id >= max_loras: | ||
| return | ||
| grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) | ||
|
|
||
| # calculate pid_m,pid_n | ||
|
|
@@ -136,10 +141,11 @@ def _fused_moe_lora_kernel( | |
| cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty)) | ||
| cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size | ||
|
|
||
| offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N | ||
| # remove modulo wrap-around | ||
| offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32) | ||
| offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) | ||
|
|
||
| offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) | ||
| offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int32) | ||
| token_ind = stride_tl * lora_id + offs_token_id | ||
| offs_token = tl.load( | ||
| sorted_token_ids_ptr + token_ind, | ||
|
|
@@ -176,7 +182,13 @@ def _fused_moe_lora_kernel( | |
| # GDC wait waits for ALL programs in the prior kernel to complete | ||
| # before continuing. | ||
| # pre-fetch lora weight | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why delete these comments?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thank you for pointing this out. This was an oversight on my part—I originally intended to finalize these adjustments in the final step once everything was confirmed ready to merge. I’ll restore the comments right away. Thanks again for the careful review. |
||
| b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) | ||
| # add (offs_bn < N) mask; optional .ca for B | ||
| b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N) | ||
| if USE_B_L2_CACHE: | ||
| b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca") | ||
| else: | ||
| b = tl.load(b_ptrs, mask=b_mask, other=0.0) | ||
|
|
||
| if USE_GDC and not IS_PRIMARY: | ||
| tl.extra.cuda.gdc_wait() | ||
| a = tl.load( | ||
|
|
@@ -276,6 +288,7 @@ def _fused_moe_lora_shrink( | |
| num_experts, | ||
| lora_ids, | ||
| adapter_enabled, | ||
| lora_a_stacked[0].shape[0], | ||
| qcurr_hidden_states.stride(0), | ||
| qcurr_hidden_states.stride(1), | ||
| w1_lora_a_stacked.stride(0), | ||
|
|
@@ -292,6 +305,7 @@ def _fused_moe_lora_shrink( | |
| num_slice_c=num_slices, | ||
| top_k=1 if mul_routed_weight else top_k_num, | ||
| MUL_ROUTED_WEIGHT=False, | ||
| USE_B_L2_CACHE=True, # new | ||
| IS_PRIMARY=True, | ||
| **shrink_config, | ||
| ) | ||
|
|
@@ -377,6 +391,7 @@ def _fused_moe_lora_expand( | |
| num_experts, | ||
| lora_ids, | ||
| adapter_enabled, | ||
| lora_b_stacked[0].shape[0], | ||
| a_intermediate_cache1.stride(0), | ||
| a_intermediate_cache1.stride(1), | ||
| w1_lora_b_stacked.stride(0), | ||
|
|
@@ -393,6 +408,7 @@ def _fused_moe_lora_expand( | |
| num_slice_c=num_slices, | ||
| top_k=1, | ||
| MUL_ROUTED_WEIGHT=mul_routed_weight, | ||
| USE_B_L2_CACHE=True, # new | ||
| IS_PRIMARY=False, | ||
| **expand_config, | ||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add comment on why we need to add this case. Technically since you are passing it correctly it shouldn't go in here but ok to have this check.