python/sglang/srt/models/deepseek_v2.py (4 additions, 1 deletion)

@@ -2104,8 +2104,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
                         ):
                             q_a_proj_weight = cached_a_proj[q_a_proj_name]
                             kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            cat_dim = 0
+                            if self.quant_config.get_name() == "awq" or self.quant_config.get_name() == "moe_wna16":
+                                cat_dim = 1
Reviewer comment (severity: medium):

The condition repeats the self.quant_config.get_name() call. Consider extracting it to a variable for better readability and maintainability:

                            is_awq_or_moe_wna16 = self.quant_config.get_name() == "awq" or self.quant_config.get_name() == "moe_wna16"
                            cat_dim = 0
                            if is_awq_or_moe_wna16:
                                cat_dim = 1

                             fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                                [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                             )

                             param_name = name.replace(
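Why the concatenation dimension flips here: an unquantized nn.Linear weight is stored as [out_features, in_features], so fusing q_a_proj and kv_a_proj (which share the same input) concatenates along dim 0, while AWQ-style packed qweight tensors are conventionally stored as [in_features, out_features // pack_factor], putting the output dimension at dim 1. A rough sketch of that shape reasoning; the sizes are illustrative, not the actual DeepSeek dimensions, and the [in, out // 8] layout is the usual AWQ convention rather than something stated in this diff:

    import torch

    hidden = 512                      # illustrative, not the real hidden size
    q_rank, kv_rank = 128, 64         # illustrative per-projection output sizes
    pack = 8                          # AWQ packs eight 4-bit weights per int32

    # Unquantized nn.Linear weights: [out_features, in_features] -> fuse along dim 0.
    q_a = torch.empty(q_rank, hidden)
    kv_a = torch.empty(kv_rank, hidden)
    fused = torch.cat([q_a, kv_a], dim=0)        # [192, 512]

    # AWQ / moe_wna16 packed qweights: [in_features, out_features // pack] -> fuse along dim 1.
    q_a_q = torch.empty(hidden, q_rank // pack, dtype=torch.int32)
    kv_a_q = torch.empty(hidden, kv_rank // pack, dtype=torch.int32)
    fused_q = torch.cat([q_a_q, kv_a_q], dim=1)  # [512, 24]

That is why the patch keeps dim=0 as the default and switches to dim=1 only when quant_config reports awq or moe_wna16.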
sgl-kernel/csrc/gemm/awq_kernel.cu (8 additions, 6 deletions)

@@ -130,11 +130,13 @@ __global__ void __launch_bounds__(256) dequantize_weights(
     int* __restrict__ qzeros,
     OutputT* __restrict__ output,
     int group_size,
-    int qweight_cols) {
+    int qweight_cols,
+    int qweight_rows) {
Reviewer comment (severity: medium):

Adding qweight_rows as an argument to the dequantize_weights kernel is a good step toward correct dequantization. However, it is crucial to verify that qweight_rows is actually used within the kernel to influence the dequantization; if it is not used, it is an unnecessary argument.

 #if CUDA_VERSION >= 12000
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   int row = blockIdx.y * blockDim.y + threadIdx.y;

+  if (col >= qweight_cols || row >= qweight_rows) return;
Reviewer comment (severity: medium):

The check if (col >= qweight_cols || row >= qweight_rows) return; is important to prevent out-of-bounds access. Consider adding an assertion or a log message (if the condition is met unexpectedly) to help debug potential issues with the kernel launch configuration.
   int group_idx = row / group_size;
   int scale_offset = 8 * col + group_idx * qweight_cols * 8;
   uint4 loaded_scale = *(uint4*)(scales + scale_offset);
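On the reviewer's question about whether qweight_rows is used: in the hunk above it only feeds the bounds guard, since the per-element math needs just row / group_size and the column offsets. A rough PyTorch reference of the dequantization the kernel performs, assuming the usual AWQ layout (eight 4-bit weights packed per int32 along the output dimension) and ignoring AWQ's interleaved nibble ordering for brevity; the test file's own reference implementation may differ in these details:

    import torch

    def awq_dequantize_ref(qweight, scales, qzeros, group_size):
        # qweight: [rows, cols] int32; scales: [rows // group_size, cols * 8];
        # qzeros: [rows // group_size, cols] int32, packed like qweight.
        rows, cols = qweight.shape
        shifts = torch.arange(0, 32, 4, device=qweight.device)           # 8 nibbles per int32
        w = (qweight.unsqueeze(-1) >> shifts) & 0xF                      # [rows, cols, 8]
        z = (qzeros.unsqueeze(-1) >> shifts) & 0xF                       # [rows // g, cols, 8]
        group = torch.arange(rows, device=qweight.device) // group_size  # row -> group index
        s = scales.view(rows // group_size, cols, 8)
        out = (w.float() - z.float()[group]) * s[group].float()
        return out.reshape(rows, cols * 8)

Each row contributes one output row regardless of the total row count, so qweight_rows matters only for deciding which (row, col) threads have real work to do, which is exactly what the new early return handles.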
@@ -188,8 +190,8 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:

   int x_num_threads = 16;
   int y_num_threads = 16;
-  int x_blocks = qweight_cols / x_num_threads;
-  int y_blocks = qweight_rows / y_num_threads;
+  int x_blocks = (qweight_cols + x_num_threads - 1) / x_num_threads;
+  int y_blocks = (qweight_rows + y_num_threads - 1) / y_num_threads;
Reviewer comment on lines +193 to +194 (severity: medium):

The calculation of x_blocks and y_blocks uses ceiling division to ensure all elements are processed, which is good for correctness. However, make sure the kernel handles the case where qweight_cols or qweight_rows is not a multiple of x_num_threads or y_num_threads, and consider adding a comment explaining this.


const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));

@@ -207,12 +209,12 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
     auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
     auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
     dequantize_weights<half>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
Reviewer comment (severity: high):

It's good that qweight_rows is being passed to the dequantize_weights kernel. Double-check that the kernel uses this value to ensure that dequantization is performed correctly for the new DeepSeek V2 models.

   } else {
     auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
     auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
     dequantize_weights<__nv_bfloat16>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
Reviewer comment (severity: high):

Same as above: double-check that the kernel uses qweight_rows so that dequantization is performed correctly for this bfloat16 launch as well.

   }

   return output;
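A quick way to exercise the fixed launch path end to end on a shape that is not a multiple of 16. This is a hypothetical smoke test: the Python binding name sgl_kernel.awq_dequantize, its (qweight, scales, qzeros) argument order, and the [rows, cols * 8] output shape are assumptions inferred from the C++ signature above, not confirmed by this diff:

    import torch
    from sgl_kernel import awq_dequantize   # assumed binding; see note above

    group_size = 64
    rows, packed_cols = 128, 24              # 24 columns: deliberately not a multiple of 16
    qweight = torch.randint(0, torch.iinfo(torch.int32).max, (rows, packed_cols),
                            dtype=torch.int32, device="cuda")
    scales = torch.rand(rows // group_size, packed_cols * 8,
                        dtype=torch.float16, device="cuda")
    qzeros = torch.randint(0, torch.iinfo(torch.int32).max, (rows // group_size, packed_cols),
                           dtype=torch.int32, device="cuda")

    out = awq_dequantize(qweight, scales, qzeros)
    print(out.shape)                         # expected torch.Size([128, 192])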
sgl-kernel/tests/test_awq_dequant.py (4 additions, 3 deletions)

@@ -67,8 +67,10 @@ def sglang_awq_dequantize(
     "qweight_row,qweight_col,is_bf16_act",
     list(
         itertools.product(
-            [3584, 18944, 128, 256, 512, 1024],
-            [448, 576, 4736, 16, 32, 64, 128],
+            # [7168, 7168, 7168, 128, 128],
+            # [264, 192, 72, 16, 24],
Reviewer comment (severity: medium):

The commented-out test parameters suggest that larger values were previously used. It would be beneficial to re-enable these larger values in the test suite to ensure that the fix is robust across a wider range of input sizes. If there is a reason they were commented out (e.g., resource constraints), it should be documented.

+            [128, 128, 128, 128],
+            [16, 24, 32, 40],
             [True, False],
         )
     ),
@@ -77,7 +79,6 @@ def test_awq_dequant_compare_implementations(
     qweight_row: int, qweight_col: int, is_bf16_act: bool
 ):
     device = torch.device("cuda")
-
     qweight = torch.randint(
         0,
         torch.iinfo(torch.int32).max,