Skip to content

Commit 48259af

Browse files
committed
Improve
Signed-off-by: Shu Wang. <[email protected]>
1 parent d972fbe commit 48259af

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,15 @@ void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* i
299299
void* input_global_scale, void* mask, bool use_silu_and_mul,
300300
int m_topk, int k, int n_experts, cudaStream_t stream) {
301301
int device;
302-
cudaGetDevice(&device);
302+
TLLM_CUDA_CHECK(cudaGetDevice(&device));
303303
int multiProcessorCount;
304-
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
304+
TLLM_CUDA_CHECK(
305+
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device));
305306

306307
// Grid, Block size.
307308
// Each thread converts 8 values.
308-
int const workSizePerRow = k / CVT_ELTS_PER_THREAD;
309+
TLLM_CHECK_WITH_INFO(k > 0, "k must be > 0");
310+
int const workSizePerRow = max(1, k / CVT_ELTS_PER_THREAD);
309311
int const totalWorkSize = m_topk * workSizePerRow;
310312
dim3 block(std::min(workSizePerRow, 512));
311313
// Get number of blocks per SM (assume we can fully utilize the SM).
@@ -320,6 +322,7 @@ void invokeSiluAndMulNVFP4Quantization(void* output, void* output_scale, void* i
320322
// TODO(kaixih@nvidia): Should relax this to allow any grid size.
321323
// [email protected]: only deal with mask case
322324
TLLM_CHECK_WITH_INFO(mask != nullptr, "mask must be non-null for expert NVFP4 path");
325+
TLLM_CHECK_WITH_INFO(n_experts > 0, "n_experts must be > 0");
323326
grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
324327
cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
325328
m_topk, k, reinterpret_cast<T*>(input), reinterpret_cast<float*>(input_global_scale),

csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,11 @@ void silu_and_mul_scaled_nvfp4_experts_quantize(Tensor output, Tensor output_sca
196196
CHECK_CUDA(input);
197197
CHECK_CUDA(input_global_scale);
198198
CHECK_CUDA(mask);
199+
CHECK_CONTIGUOUS(output);
200+
CHECK_CONTIGUOUS(output_scale);
201+
CHECK_CONTIGUOUS(input);
202+
CHECK_CONTIGUOUS(input_global_scale);
203+
CHECK_CONTIGUOUS(mask);
199204

200205
TVM_FFI_ICHECK_EQ(mask.ndim(), 1);
201206
TVM_FFI_ICHECK_EQ(output.ndim(), 2);
@@ -210,14 +215,15 @@ void silu_and_mul_scaled_nvfp4_experts_quantize(Tensor output, Tensor output_sca
210215
CHECK_INPUT_TYPE(output, uint8_dtype);
211216
CHECK_INPUT_TYPE(output_scale, int32_dtype);
212217

213-
const int BLOCK_SIZE = 16;
218+
constexpr int BLOCK_SIZE = 16;
214219
auto m_topk = input.shape()[0];
215220
auto k_by_2 = input.shape()[1];
216221
auto k = k_by_2;
217222
if (use_silu_and_mul) {
218223
TVM_FFI_ICHECK_EQ(k_by_2 % 2, 0) << "k must be a multiple of 2";
219224
k = k_by_2 / 2;
220225
}
226+
TVM_FFI_ICHECK_EQ(k % BLOCK_SIZE, 0) << "k must be a multiple of 16";
221227
auto n_experts = input_global_scale.shape()[0];
222228
TVM_FFI_ICHECK_EQ(mask.shape()[0], n_experts);
223229
TVM_FFI_ICHECK_EQ(output.shape()[0], m_topk);

0 commit comments

Comments
 (0)