1 change: 1 addition & 0 deletions cpp/tensorrt_llm/thop/CMakeLists.txt
@@ -56,6 +56,7 @@ add_library(
fp4Quantize.cpp
fp4BatchedQuantize.cpp
fp8BlockScalingGemm.cpp
fp8RowwiseGemm.cpp
fp8Quantize.cpp
dsv3FusedAGemmOp.cpp
fusedQKNormRopeOp.cpp
192 changes: 192 additions & 0 deletions cpp/tensorrt_llm/thop/fp8RowwiseGemm.cpp
@@ -0,0 +1,192 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "cutlass_extensions/gemm_configs.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm.h"
#include "tensorrt_llm/thop/thUtils.h"
#include "tensorrt_llm/thop/userbuffersTensor.h"

#include <ATen/cuda/EmptyTensor.h>
#include <ATen/native/cuda/Resize.h>

#include <cstddef>
#include <cuda_fp16.h>

#include <cstdint>
#include <functional>
#include <type_traits>
#include <vector>

using tensorrt_llm::kernels::cutlass_kernels::CutlassFp8RowwiseGemmRunner;
using tensorrt_llm::kernels::cutlass_kernels::CutlassFp8RowwiseGemmRunnerInterface;

namespace torch_ext
{

namespace
{
void check_input_dtypes(torch::Tensor const& mat, torch::Tensor const& matScale)
{
TORCH_CHECK(mat.scalar_type() == at::ScalarType::Float8_e4m3fn,
"Matrix dtype must be FP8 (the matrix will be dequantized on the fly).");

CHECK_INPUT(matScale, FP8_ROWWISE_SF_DTYPE);
}
} // namespace

template <typename OutputType>
torch::Tensor fp8_rowwise_gemm_launch(torch::Tensor const& mat1, torch::Tensor const& mat2,
torch::Tensor const& mat1Scale, torch::Tensor const& mat2Scale, bool to_userbuffers = false,
tkc::CutlassGemmConfig const* maybe_config = nullptr)
{
check_input_dtypes(mat1, mat1Scale);
check_input_dtypes(mat2, mat2Scale);

TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(mat1.sizes()[1] == mat2.sizes()[1], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x",
mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_CHECK(mat1.sizes()[0] == mat1Scale.sizes()[0],
"mat1Scale should be per-token scale, but got m=", mat1.sizes()[0], ", scale_dim=", mat1Scale.sizes()[0], ".");
TORCH_CHECK(mat2.sizes()[0] == mat2Scale.sizes()[0],
"mat2Scale should be per-channel scale, but got n=", mat2.sizes()[0], ", scale_dim=", mat2Scale.sizes()[0],
".");

auto const m = mat1.sizes()[0];
auto const n = mat2.sizes()[0];
auto const k = mat1.sizes()[1];

static_assert(std::is_same<OutputType, half>::value || std::is_same<OutputType, __nv_bfloat16>::value,
"Output type must be half or bfloat16");
static constexpr auto outType
= std::is_same<OutputType, half>::value ? at::ScalarType::Half : at::ScalarType::BFloat16;
at::Tensor out;
if (to_userbuffers)
{
out = torch_ext::create_userbuffers_tensor({m, n}, outType).first;
}
else
{
out = at::detail::empty_cuda({m, n}, outType, mat1.device(), std::nullopt);
}

auto stream = at::cuda::getCurrentCUDAStream(mat1.get_device());

auto mGemmRunner = std::make_shared<CutlassFp8RowwiseGemmRunner<OutputType>>();
int64_t const wsSize = mGemmRunner->getWorkspaceSize(m, n, k);
auto gemmConfig = maybe_config ? *maybe_config : mGemmRunner->getConfigs()[0];
at::Tensor workspace = at::detail::empty_cuda({wsSize}, at::ScalarType::Char, torch::kCUDA, std::nullopt);

OutputType* outPtr = reinterpret_cast<OutputType*>(out.data_ptr());
__nv_fp8_e4m3 const* mat1Ptr = reinterpret_cast<__nv_fp8_e4m3 const*>(mat1.data_ptr());
__nv_fp8_e4m3 const* mat2Ptr = reinterpret_cast<__nv_fp8_e4m3 const*>(mat2.data_ptr());
float const* mat1ScalePtr = reinterpret_cast<float const*>(mat1Scale.data_ptr());
float const* mat2ScalePtr = reinterpret_cast<float const*>(mat2Scale.data_ptr());
char* workspacePtr = reinterpret_cast<char*>(workspace.data_ptr());

tensorrt_llm::common::QuantMode quantMode = tensorrt_llm::common::QuantMode::fp8RowWise();
mGemmRunner->gemm(outPtr, mat1Ptr, mat2Ptr, nullptr, quantMode, m, n, k, mat1ScalePtr, mat2ScalePtr, gemmConfig,
workspacePtr, wsSize, stream);

return out;
}

template torch::Tensor fp8_rowwise_gemm_launch<half>(torch::Tensor const& mat1, torch::Tensor const& mat2,
torch::Tensor const& mat1Scale, torch::Tensor const& mat2Scale, bool to_userbuffers = false,
tkc::CutlassGemmConfig const* maybe_config = nullptr);
template torch::Tensor fp8_rowwise_gemm_launch<__nv_bfloat16>(torch::Tensor const& mat1, torch::Tensor const& mat2,
torch::Tensor const& mat1Scale, torch::Tensor const& mat2Scale, bool to_userbuffers = false,
tkc::CutlassGemmConfig const* maybe_config = nullptr);

torch::Tensor fp8_rowwise_gemm_dispatch(torch::Tensor const& mat1, torch::Tensor const& mat2,
torch::Tensor const& mat1Scale, torch::Tensor const& mat2Scale, at::ScalarType outDataType,
bool to_userbuffers = false, tkc::CutlassGemmConfig const* maybe_config = nullptr)
{
// The functional version of this op does not do any profiling; use the profiler class below instead for
// better performance.
// Note that we can still add a heuristic here.
switch (outDataType)
{
case at::ScalarType::Half:
return fp8_rowwise_gemm_launch<half>(mat1, mat2, mat1Scale, mat2Scale, to_userbuffers, maybe_config);
#ifdef ENABLE_BF16
case at::ScalarType::BFloat16:
return fp8_rowwise_gemm_launch<__nv_bfloat16>(mat1, mat2, mat1Scale, mat2Scale, to_userbuffers, maybe_config);
#endif
default: TORCH_CHECK(false, "Unsupported output dtype for FP8 rowwise GEMM");
}
}

class FP8RowwiseGemmRunner : public torch::CustomClassHolder
{
public:
explicit FP8RowwiseGemmRunner(at::ScalarType outputDtype)
: mOutputDtype(outputDtype)
{
if (outputDtype == at::ScalarType::Half)
{
mGemmRunner = std::make_unique<CutlassFp8RowwiseGemmRunner<half>>();
}
#ifdef ENABLE_BF16
else if (outputDtype == at::ScalarType::BFloat16)
{
mGemmRunner = std::make_unique<CutlassFp8RowwiseGemmRunner<__nv_bfloat16>>();
}
#endif
else
{
C10_THROW_ERROR(NotImplementedError, "out_dtype must be one of fp16/bf16.");
}
mConfigs = mGemmRunner->getConfigs();
}

at::Tensor runGemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& mat1Scale,
at::Tensor const& mat2Scale, bool to_userbuffers, int64_t configIdx) const
{
tkc::CutlassGemmConfig const* config = nullptr;
if (configIdx != -1)
{
TORCH_CHECK(configIdx >= 0 && configIdx < getNumConfigs());
config = &mConfigs.at(configIdx);
}
return fp8_rowwise_gemm_dispatch(mat1, mat2, mat1Scale, mat2Scale, mOutputDtype, to_userbuffers, config);
}

at::ScalarType getOutputDtype() const
{
return mOutputDtype;
}

int64_t getNumConfigs() const
{
return static_cast<int64_t>(mConfigs.size());
}

private:
std::shared_ptr<CutlassFp8RowwiseGemmRunnerInterface> mGemmRunner{nullptr};
std::vector<tkc::CutlassGemmConfig> mConfigs;
at::ScalarType mOutputDtype;
};
} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.class_<torch_ext::FP8RowwiseGemmRunner>("FP8RowwiseGemmRunner")
.def(torch::init<at::ScalarType>())
.def("run_gemm", &torch_ext::FP8RowwiseGemmRunner::runGemm)
.def("get_num_configs", &torch_ext::FP8RowwiseGemmRunner::getNumConfigs);
}
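
For reference, here is a minimal PyTorch sketch of the math this extension computes, assuming `mat1` is an `[m, k]` FP8 (e4m3) activation with a per-token float32 scale of shape `[m]` and `mat2` is an `[n, k]` FP8 weight with a per-channel float32 scale of shape `[n]`, matching the shape and dtype checks in `fp8_rowwise_gemm_launch` above. The helper name `ref_fp8_rowwise_gemm` is illustrative only and not part of the extension.

```python
import torch

def ref_fp8_rowwise_gemm(mat1: torch.Tensor,        # [m, k], torch.float8_e4m3fn
                         mat2: torch.Tensor,        # [n, k], torch.float8_e4m3fn
                         mat1_scale: torch.Tensor,  # [m],    torch.float32 (per-token)
                         mat2_scale: torch.Tensor,  # [n],    torch.float32 (per-channel)
                         out_dtype: torch.dtype = torch.half) -> torch.Tensor:
    # Dequantize each operand row-wise, then multiply:
    # out[i, j] = mat1_scale[i] * mat2_scale[j] * sum_k mat1[i, k] * mat2[j, k]
    a = mat1.float() * mat1_scale[:, None]
    b = mat2.float() * mat2_scale[:, None]
    return (a @ b.t()).to(out_dtype)  # [m, n]
```

The `FP8RowwiseGemmRunner` class registered above is the CUDA counterpart of this reference computation; it is driven from Python via `torch.classes.trtllm.FP8RowwiseGemmRunner` in `torch_custom_ops.py` below.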
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/thop/thUtils.h
@@ -62,6 +62,7 @@ constexpr auto FLOAT4_E2M1X2 = torch::ScalarType::Byte; // uint8_t
constexpr auto SF_DTYPE = torch::ScalarType::Byte; // uint8_t

constexpr auto FP8_BLOCK_SCALING_SF_DTYPE = torch::ScalarType::Float;
constexpr auto FP8_ROWWISE_SF_DTYPE = torch::ScalarType::Float;

template <typename T>
inline T* get_ptr(torch::Tensor& t)
9 changes: 6 additions & 3 deletions examples/models/core/qwen/README.md
@@ -656,7 +656,7 @@ trtllm-eval --model=Qwen3-30B-A3B/ --tokenizer=Qwen3-30B-A3B/ --backend=pytorch

```

### Model Quantization to FP4
### Model Quantization

To quantize the Qwen3 model for use with the PyTorch backend, we'll use NVIDIA's Model Optimizer (ModelOpt) tool. Follow these steps:

@@ -669,12 +669,15 @@ pushd TensorRT-Model-Optimizer
pip install -e .

# Quantize the Qwen3-235B-A22B model by nvfp4
# By default, the checkpoint will be stored in `TensorRT-Model-Optimizer/examples/llm_ptq/saved_models_Qwen3-235B-A22B_nvfp4_hf/`.
./examples/llm_ptq/scripts/huggingface_example.sh --model Qwen3-235B-A22B/ --quant nvfp4 --export_fmt hf

# Quantize the Qwen3-32B model by fp8_pc_pt
# By default, the checkpoint will be stored in `TensorRT-Model-Optimizer/examples/llm_ptq/saved_models_Qwen3-32B_fp8_pc_pt_hf/`.
./examples/llm_ptq/scripts/huggingface_example.sh --model Qwen3-32B/ --quant fp8_pc_pt --export_fmt hf
popd
```

By default, the checkpoint will be stored in `TensorRT-Model-Optimizer/examples/llm_ptq/saved_models_Qwen3-235B-A22B_nvfp4_hf/`.

### Benchmark

To run the benchmark, we suggest using the `trtllm-bench` tool. Please refer to the following script on B200:
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -1062,6 +1062,8 @@ def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
self.has_fp8_qdq = self.quant_config.layer_quant_mode.has_fp8_qdq()
self.has_fp8_block_wise = self.quant_config.layer_quant_mode.has_fp8_block_scales(
)
self.has_fp8_rowwise = self.quant_config.layer_quant_mode.has_fp8_rowwise(
)
self.has_nvfp4 = self.quant_config.layer_quant_mode.has_nvfp4()

def get_local_layer_idx(self, metadata: TrtllmAttentionMetadata) -> int:
@@ -1167,8 +1169,8 @@ def forward(
if use_nvfp4_output:
# Use UINT8 as the container dtype for NVFP4.
out_dtype = torch.uint8
elif (self.has_fp8_qdq or self.has_nvfp4
or self.has_fp8_block_wise) and self.has_fp8_kv_cache:
elif (self.has_fp8_qdq or self.has_nvfp4 or self.has_fp8_block_wise
or self.has_fp8_rowwise) and self.has_fp8_kv_cache:
# TODO(qijun): revisit fp8_context_fmha logic
out_dtype = torch.float8_e4m3fn

87 changes: 87 additions & 0 deletions tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -255,6 +255,93 @@ def _(
return [input.new_empty([seq_len, hidden_size], dtype=output_dtype)]


class FP8RowwiseGemmRunner(TunableRunner):
runner_dict = dict()
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(
ConstraintSpec(2, 0, lambda shapes: shapes[0][0]),
ConstraintSpec(3, 0, lambda shapes: shapes[1][0]),
))

def __init__(
self,
to_userbuffers: bool,
output_dtype: torch.dtype,
):
self.to_userbuffers = to_userbuffers
self.output_dtype = output_dtype
instance_key = (output_dtype, )
if instance_key not in FP8RowwiseGemmRunner.runner_dict:
FP8RowwiseGemmRunner.runner_dict[
instance_key] = torch.classes.trtllm.FP8RowwiseGemmRunner(
output_dtype)
self.fp8_rowwise_gemm_runner = FP8RowwiseGemmRunner.runner_dict[
instance_key]

def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
return list(range(self.fp8_rowwise_gemm_runner.get_num_configs()))

def forward(
self,
inputs: List[torch.Tensor],
tactic: int = -1,
) -> torch.Tensor:
mat1, mat2, mat1_scale, mat2_scale = inputs
return self.fp8_rowwise_gemm_runner.run_gemm(
mat1,
mat2,
mat1_scale,
mat2_scale,
self.to_userbuffers,
tactic,
)


@torch.library.custom_op("trtllm::fp8_rowwise_gemm", mutates_args=())
def fp8_rowwise_gemm(
act: torch.Tensor,
weight: torch.Tensor,
act_scale: torch.Tensor,
weight_scale: torch.Tensor,
output_dtype: torch.dtype,
to_userbuffers: bool = False,
) -> torch.Tensor:

tuner = AutoTuner.get()

# allocate workspace for profiling
fp8_rowwise_gemm_runner = FP8RowwiseGemmRunner(to_userbuffers, output_dtype)

_, best_tactic = tuner.choose_one(
"trtllm::fp8_rowwise_gemm::gemm",
[fp8_rowwise_gemm_runner],
FP8RowwiseGemmRunner.tuning_config,
[act, weight, act_scale, weight_scale],
)

return fp8_rowwise_gemm_runner(
inputs=[act, weight, act_scale, weight_scale], tactic=best_tactic)


@fp8_rowwise_gemm.register_fake
def _(
act: torch.Tensor,
weight: torch.Tensor,
act_scale: torch.Tensor,
weight_scale: torch.Tensor,
output_dtype: torch.dtype,
to_userbuffers: bool = False,
) -> torch.Tensor:
return act.new_empty((act.size(0), weight.size(0)), dtype=output_dtype)
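
As an illustration of how the `trtllm::fp8_rowwise_gemm` op defined above might be invoked, here is a hedged sketch. The `quantize_rowwise_fp8` helper and its scale heuristic are assumptions made for the example, not utilities provided by this change; only the op call itself comes from the registration above.

```python
import torch

def quantize_rowwise_fp8(x: torch.Tensor):
    """Illustrative per-row FP8 quantization: one float32 scale per row."""
    amax = x.abs().amax(dim=-1).clamp(min=1e-4)
    scale = amax / 448.0  # 448 is the max finite value of float8_e4m3fn
    q = (x / scale[:, None]).to(torch.float8_e4m3fn)
    return q, scale.float()

m, n, k = 128, 4096, 4096
act = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

act_fp8, act_scale = quantize_rowwise_fp8(act)            # [m, k] fp8, [m] fp32
weight_fp8, weight_scale = quantize_rowwise_fp8(weight)   # [n, k] fp8, [n] fp32

# The op consults the AutoTuner (tuner.choose_one above) to pick a CUTLASS
# tactic before dispatching to the FP8RowwiseGemmRunner custom class.
out = torch.ops.trtllm.fp8_rowwise_gemm(act_fp8, weight_fp8, act_scale,
                                        weight_scale,
                                        output_dtype=torch.bfloat16)  # [m, n]
```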


class FP4GemmRunner(TunableRunner):
runner_dict = dict()
tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/modules/attention.py
@@ -268,7 +268,7 @@ def forward(

out_scale = None
out_scale_sf = None
if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales or self.o_proj.has_fp8_rowwise:
out_scale = self.o_proj.inv_input_scale
if self.o_proj.has_nvfp4 and self.support_nvfp4_output:
out_scale_sf = self.o_proj.input_scale