Skip to content

Commit 744d856

Browse files
jinzhen-linYuqi Zhang
authored andcommitted
[Bugfix] fix an illegal memory access was encountered of marlin kernel + act_order (vllm-project#18245)
Signed-off-by: Yuqi Zhang <[email protected]>
1 parent 7fa41b4 commit 744d856

File tree

3 files changed

+29
-22
lines changed

3 files changed

+29
-22
lines changed

csrc/moe/marlin_moe_wna16/marlin_template.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1767,17 +1767,20 @@ __global__ void Marlin(
17671767

17681768
if constexpr (has_act_order) {
17691769
slice_k_start += tb_k * stages;
1770-
slice_k_start_shared_fetch += tb_k * stages;
1771-
int first_group_id = g_idx[slice_k_start];
1772-
int last_g_idx = slice_k_start + stages * tb_k * 2;
1773-
if (last_g_idx >= prob_k) {
1774-
last_g_idx = prob_k - 1;
1775-
}
1776-
int last_group_id = g_idx[last_g_idx];
1777-
if (last_group_id >= sh_first_group_id + sh_num_groups) {
1778-
fetch_act_order_scales_to_shared(false, first_group_id,
1779-
last_group_id);
1780-
__syncthreads();
1770+
1771+
if (slice_k_start < prob_k) {
1772+
slice_k_start_shared_fetch += tb_k * stages;
1773+
int first_group_id = g_idx[slice_k_start];
1774+
int last_g_idx = slice_k_start + stages * tb_k * 2;
1775+
if (last_g_idx >= prob_k) {
1776+
last_g_idx = prob_k - 1;
1777+
}
1778+
int last_group_id = g_idx[last_g_idx];
1779+
if (last_group_id >= sh_first_group_id + sh_num_groups) {
1780+
fetch_act_order_scales_to_shared(false, first_group_id,
1781+
last_group_id);
1782+
__syncthreads();
1783+
}
17811784
}
17821785
}
17831786
if (slice_iters == 0) {

csrc/quantization/gptq_marlin/marlin_template.h

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,16 +1588,20 @@ __global__ void Marlin(
15881588

15891589
if constexpr (has_act_order) {
15901590
slice_k_start += tb_k * stages;
1591-
slice_k_start_shared_fetch += tb_k * stages;
1592-
int first_group_id = g_idx[slice_k_start];
1593-
int last_g_idx = slice_k_start + stages * tb_k * 2;
1594-
if (last_g_idx >= prob_k) {
1595-
last_g_idx = prob_k - 1;
1596-
}
1597-
int last_group_id = g_idx[last_g_idx];
1598-
if (last_group_id >= sh_first_group_id + sh_num_groups) {
1599-
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
1600-
__syncthreads();
1591+
1592+
if (slice_k_start < prob_k) {
1593+
slice_k_start_shared_fetch += tb_k * stages;
1594+
int first_group_id = g_idx[slice_k_start];
1595+
int last_g_idx = slice_k_start + stages * tb_k * 2;
1596+
if (last_g_idx >= prob_k) {
1597+
last_g_idx = prob_k - 1;
1598+
}
1599+
int last_group_id = g_idx[last_g_idx];
1600+
if (last_group_id >= sh_first_group_id + sh_num_groups) {
1601+
fetch_act_order_scales_to_shared(false, first_group_id,
1602+
last_group_id);
1603+
__syncthreads();
1604+
}
16011605
}
16021606
}
16031607

tests/weight_loading/models.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
22
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
33
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
44
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
5-
#gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
5+
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
66
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
77
gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
88
gptq, TheBloke/Llama-2-7B-GPTQ, main

0 commit comments

Comments
 (0)