Commit e2d771e

Commit message: upd
1 parent c52c080 commit e2d771e

File tree

1 file changed: +6 -6 lines changed

csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp

Lines changed: 6 additions & 6 deletions
@@ -236,16 +236,16 @@ void silu_and_mul_scaled_nvfp4_experts_quantize(Tensor output, Tensor output_sca
   // 4 means 4 fp8 values are packed into one int32
   TVM_FFI_ICHECK_EQ(output_scale.shape()[1] * 4, padded_k);
 
-  auto in_dtype = input->dtype;
-  const cudaStream_t stream = get_stream(input->device);
+  auto in_dtype = input.dtype();
+  const cudaStream_t stream = get_stream(input.device());
   if (in_dtype == dl_float16) {
     tensorrt_llm::kernels::invokeSiluAndMulNVFP4Quantization<half>(
-        output->data, output_scale->data, input->data, input_global_scale->data, mask->data,
-        use_silu_and_mul, m_topk, k, n_experts, stream);
+        output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), input_global_scale.data_ptr(),
+        mask.data_ptr(), use_silu_and_mul, m_topk, k, n_experts, stream);
   } else if (in_dtype == dl_bfloat16) {
     tensorrt_llm::kernels::invokeSiluAndMulNVFP4Quantization<__nv_bfloat16>(
-        output->data, output_scale->data, input->data, input_global_scale->data, mask->data,
-        use_silu_and_mul, m_topk, k, n_experts, stream);
+        output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), input_global_scale.data_ptr(),
+        mask.data_ptr(), use_silu_and_mul, m_topk, k, n_experts, stream);
   } else {
     TVM_FFI_LOG_AND_THROW(NotImplementedError) << "silu_and_mul_scaled_nvfp4_experts_quantize only "
                                                    "supports input tensor with dtypes fp16/bf16.";
