Fix AWQ Dequant and Weight Loading of DeepSeek V2 #6842
In the CUDA AWQ dequantize kernel, the `dequantize_weights` signature gains a `qweight_rows` parameter, and the kernel gains an explicit bounds check:

```diff
@@ -130,11 +130,13 @@ __global__ void __launch_bounds__(256) dequantize_weights(
     int* __restrict__ qzeros,
     OutputT* __restrict__ output,
     int group_size,
-    int qweight_cols) {
+    int qweight_cols,
+    int qweight_rows) {
 #if CUDA_VERSION >= 12000
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   int row = blockIdx.y * blockDim.y + threadIdx.y;
 
+  if (col >= qweight_cols || row >= qweight_rows) return;
   int group_idx = row / group_size;
   int scale_offset = 8 * col + group_idx * qweight_cols * 8;
   uint4 loaded_scale = *(uint4*)(scales + scale_offset);
```
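Why the guard matters: once the grid is rounded up (see the launch-side change below), some threads land past the matrix edge and must return before touching memory. A minimal standalone sketch of the pattern, with hypothetical names (`guarded_scale_kernel` is not part of this PR):

```cuda
#include <cuda_runtime.h>

// Hypothetical kernel showing the guard pattern added above: with a grid
// rounded up to cover a matrix whose sides are not multiples of the block
// size, edge threads must exit before any load or store.
__global__ void guarded_scale_kernel(const float* in, float* out,
                                     int rows, int cols) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  if (col >= cols || row >= rows) return;  // surplus threads do nothing
  out[row * cols + col] = 2.0f * in[row * cols + col];
}
```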
The host-side grid calculation switches from truncating division to ceiling division:

```diff
@@ -188,8 +190,8 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
   int x_num_threads = 16;
   int y_num_threads = 16;
-  int x_blocks = qweight_cols / x_num_threads;
-  int y_blocks = qweight_rows / y_num_threads;
+  int x_blocks = (qweight_cols + x_num_threads - 1) / x_num_threads;
+  int y_blocks = (qweight_rows + y_num_threads - 1) / y_num_threads;
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
```

Contributor comment on lines +193 to +194: The calculation of …
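The fix above addresses truncating integer division: for dimensions that are not multiples of 16, the old code launched too few blocks, leaving trailing rows and columns undequantized. A sketch of the ceil-div idiom, with an illustrative helper name (`ceil_div` is not the PR's code):

```cuda
// Illustrative helper: round the block count up so partial tiles get a block.
inline int ceil_div(int n, int d) { return (n + d - 1) / d; }

// Worked example with 72 rows (a size from the commented-out test list) and
// 16 threads per block in y:
//   truncating:   72 / 16           = 4 blocks -> rows 64..71 are skipped
//   rounding up: (72 + 16 - 1) / 16 = 5 blocks -> all 72 rows are covered,
//   and the in-kernel bounds check retires the 8 surplus threads.
```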
Both launch sites of `dequantize_weights` now pass the new `qweight_rows` argument:

```diff
@@ -207,12 +209,12 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
     auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
     auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
     dequantize_weights<half>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   } else {
     auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
     auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
     dequantize_weights<__nv_bfloat16>
-        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
+        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   }
 
   return output;
```
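For readers new to this launch pattern, here is a self-contained sketch of dispatching one templated kernel for both half and bfloat16 over a rounded-up grid; every name here is illustrative rather than the PR's actual symbol:

```cuda
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Illustrative templated kernel: one definition serves both output dtypes.
template <typename OutputT>
__global__ void fill_ones(OutputT* out, int rows, int cols) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  if (col >= cols || row >= rows) return;  // bounds guard, as in the PR
  out[row * cols + col] = OutputT(1.0f);
}

// Host-side dtype dispatch mirroring the half/bf16 branches above.
void launch_fill_ones(void* out, int rows, int cols, bool is_bf16,
                      cudaStream_t stream) {
  dim3 threads(16, 16);
  dim3 blocks((cols + 15) / 16, (rows + 15) / 16);  // ceil-div grid
  if (is_bf16) {
    fill_ones<__nv_bfloat16><<<blocks, threads, 0, stream>>>(
        static_cast<__nv_bfloat16*>(out), rows, cols);
  } else {
    fill_ones<half><<<blocks, threads, 0, stream>>>(
        static_cast<half*>(out), rows, cols);
  }
}
```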
In the AWQ dequantize test, the parameter grid is narrowed to small shapes, including column counts (24, 40) that are not multiples of the 16-wide thread block and so exercise the new bounds check:

```diff
@@ -67,8 +67,10 @@ def sglang_awq_dequantize(
     "qweight_row,qweight_col,is_bf16_act",
     list(
         itertools.product(
-            [3584, 18944, 128, 256, 512, 1024],
-            [448, 576, 4736, 16, 32, 64, 128],
+            # [7168, 7168, 7168, 128, 128],
+            # [264, 192, 72, 16, 24],
+            [128, 128, 128, 128],
+            [16, 24, 32, 40],
             [True, False],
         )
     ),
```
The test body also drops a blank line:

```diff
@@ -77,7 +79,6 @@ def test_awq_dequant_compare_implementations(
     qweight_row: int, qweight_col: int, is_bf16_act: bool
 ):
     device = torch.device("cuda")
-
     qweight = torch.randint(
         0,
         torch.iinfo(torch.int32).max,
```
Review comment: The condition `self.quant_config.get_name() == "awq" or self.quant_config.get_name() == "moe_wna16"` is repeated. Consider extracting it into a variable for better readability and maintainability.
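The shape of that suggested refactor is simply hoisting the repeated predicate into one named value. Sketched here in C++ for uniformity with the kernel code above (the file in question is Python, and `is_awq_like` is an invented name):

```cuda
#include <string>

// Invented helper: evaluate the repeated comparison once and reuse the result
// everywhere the condition previously appeared.
inline bool is_awq_like(const std::string& quant_name) {
  return quant_name == "awq" || quant_name == "moe_wna16";
}
```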