mlir-tensorrt/build_tools/docker/Dockerfile (1 addition, 1 deletion)
@@ -35,7 +35,7 @@ case "${LINUX_DISTRO}" in
dnf install -y \
which wget gcc zlib-devel bzip2 bzip2-devel readline-devel sqlite \
sqlite-devel xz xz-devel libffi-devel curl git ncurses-devel \
-    openssh-clients libcudnn8-devel zip jq \

Collaborator Author: cudnn8 conflicts with cudnn9 in the base container.

+    openssh-clients zip jq \
protobuf-compiler autoconf automake libtool dnf-plugins-core cmake
dnf config-manager --set-enabled powertools
dnf -y install gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
@@ -378,4 +378,42 @@ def TensorRT_ScatterMode : TensorRT_I32EnumAttr<
def TensorRT_ScatterModeAttr : TensorRT_EnumAttr<TensorRT_ScatterMode, "scatter_mode">{
}

def TensorRT_AttentionNormalizationOp : TensorRT_I32EnumAttr<
"AttentionNormalizationOp", "",
[
I32EnumAttrCase<"kNONE", 0>,
I32EnumAttrCase<"kSOFTMAX", 1>
]>
{
let cppNamespace = "::mlir::tensorrt";
let genSpecializedAttr = 0;
}

def TensorRT_AttentionNormalizationOpAttr : TensorRT_EnumAttr<TensorRT_AttentionNormalizationOp, "attention_normalization_op">{
}

Collaborator: We didn't already have datatype?

Collaborator Author: We didn't have an op that explicitly requires data type as an input; for example, the cast op uses MLIR data types.

Collaborator: Can we do the same here? I'm assuming we could use the same helpers that cast uses?

Collaborator Author: cast uses the output tensor type to indicate the data type, e.g.

tensorrt.cast %arg0 : tensor<3xf16> to tensor<3xf32>

but in the attention op, the normalization_quantize_to_type parameter is i) optional and ii) represents an intermediate quantization data type.

def TensorRT_DataType : TensorRT_I32EnumAttr<
"DataType", "",
[
I32EnumAttrCase<"kFLOAT", 0>,
I32EnumAttrCase<"kHALF", 1>,
I32EnumAttrCase<"kINT8", 2>,
I32EnumAttrCase<"kINT32", 3>,
I32EnumAttrCase<"kBOOL", 4>,
I32EnumAttrCase<"kUINT8", 5>,
I32EnumAttrCase<"kFP8", 6>,
I32EnumAttrCase<"kBF16", 7>,
I32EnumAttrCase<"kINT64", 8>,
I32EnumAttrCase<"kINT4", 9>,
I32EnumAttrCase<"kFP4", 10>,
I32EnumAttrCase<"kE8M0", 11>
]>
{
let cppNamespace = "::mlir::tensorrt";
let genSpecializedAttr = 0;
}

def TensorRT_DataTypeAttr : TensorRT_EnumAttr<TensorRT_DataType, "data_type">{
}
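
The distinction discussed in the review thread above, shown side by side as an illustrative sketch (%q, %k, %v, and %scale are placeholder values; the attention form mirrors the quantization example in the op documentation later in this PR):

```mlir
// Cast: the target type is carried by the result tensor type.
%y = tensorrt.cast %arg0 : tensor<3xf16> to tensor<3xf32>

// Attention: the intermediate quantization type cannot be read off any operand
// or result type, so it is carried by an explicit data_type enum attribute.
%out = tensorrt.attention {
  normalization_quantize_to_type = #tensorrt.data_type<kFP8>
} ins(%q, %k, %v, normalization_quantize_scale = %scale :
  tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
  tensor<2x8x128x64xf16>, tensor<f32>)
  -> tensor<2x8x128x64xf16>
```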

#endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTENUMS
@@ -4432,4 +4432,172 @@ def TensorRT_ScatterElementsOp : TensorRT_Op<"scatter_elements",
}];
}

//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//

def TensorRT_AttentionOp : TensorRT_Op<"attention",
[Pure, AttrSizedOperandSegments, TensorRTInferTensorResultTypes,
AllElementTypesMatch<["query", "key", "value"]>,
AllRanksMatch<["query", "key", "value"]>]>{
let summary = "TensorRT attention (IAttention) operation";
let description = [{
The `tensorrt.attention` operation implements a fused attention mechanism
that consumes query, key, and value tensors. The operation implicitly includes
two matrix multiplication layers (BMM1 and BMM2) and a normalization operation
(typically softmax).

By default, TensorRT tries to use a single fused kernel for better efficiency.
Setting `decomposable` to true allows the operation to be decomposed into
multiple kernels when no fused kernel is available.

#### Architecture:

```
Query      Key      Value    Mask (optional)      NormalizationQuantizeScale (optional)
  |         |         |        |                       |
  |     Transpose     |        |                       |
  |         |         |        |                       |
  ----BMM1----        |        |                       |
       |              |        |                       |
       *------------------------                       |
       |              |                                |
 Normalization        |                                |
       |              |                                |
       *------------------------------------------------
       |              |
       -------BMM2------
              |
            Output
```

#### Inputs:

- Query: tensor of type f32, f16, or bf16 with shape
[batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]
- Key: tensor of type f32, f16, or bf16 with shape
[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
- Value: tensor of type f32, f16, or bf16 with shape
[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
- Mask (optional): tensor of type i1, or of the same type as the BMM1 output, with shape
[batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue],
where batchSize and numHeadsQuery are broadcastable. For an i1 mask, true
indicates that the position is allowed to attend. For other types, the mask
values are added to the BMM1 output.
- NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16
with rank 0 or 1, used for quantizing the normalization output.

#### Attributes:

- normalization_operation: The normalization operation to use (default: kSOFTMAX)
- causal: Whether to use causal masking (default: false). Cannot be used with mask input.
- decomposable: Whether the operation can be decomposed (default: false)
- normalization_quantize_to_type: Optional output type for quantized normalization.
When specified, must be one of kFP8 or kINT8. Requires normalization_quantize_scale input to be provided.

#### Constraints:

- All query, key, and value tensors must be rank 4 with shape [batchSize, numHeads, sequenceLength, dimHead]
- Query, key, and value must have the same element type (f32, f16, or bf16)
- If normalization_quantize_to_type is specified:
* It must be kFP8 or kINT8
* normalization_quantize_scale input must be provided
- Cannot use both mask input and causal=true simultaneously

#### Examples:

Basic attention:
```mlir
%output = tensorrt.attention ins(%query, %key, %value :
tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
-> tensor<2x8x128x64xf16>
```

Causal attention:
```mlir
%output_causal = tensorrt.attention {causal = true} ins(%query, %key, %value :
tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
-> tensor<2x8x128x64xf16>
```

Attention with quantization:
```mlir
%scale = tensorrt.constant dense<1.0> : tensor<f32>
%output_quant = tensorrt.attention {
normalization_quantize_to_type = #tensorrt.data_type<kFP8>
} ins(%query, %key, %value,
normalization_quantize_scale = %scale :
tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
tensor<2x8x128x64xf16>, tensor<f32>)
-> tensor<2x8x128x64xf16>
```
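
Attention with a boolean mask (an illustrative sketch; the mask shape follows the
[batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue] constraint above):
```mlir
%output_masked = tensorrt.attention ins(%query, %key, %value, mask = %mask :
  tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
  tensor<2x8x128x128xi1>)
  -> tensor<2x8x128x64xf16>
```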
}];

let arguments = (ins
TensorRT_RankedTensorOf<[F16, BF16, F32]>:$query,
TensorRT_RankedTensorOf<[F16, BF16, F32]>:$key,
TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value,
Optional<TensorRT_Tensor>:$mask,
Optional<TensorRT_RankedTensorOf<[F16, BF16, F32]>>:$normalization_quantize_scale,
OptionalAttr<TensorRT_AttentionNormalizationOpAttr>:$normalization_operation,
DefaultValuedAttr<BoolAttr, "false">:$causal,
DefaultValuedAttr<BoolAttr, "false">:$decomposable,
OptionalAttr<TensorRT_DataTypeAttr>:$normalization_quantize_to_type
);

let results = (outs TensorRT_RankedTensorOf<[F16, BF16, F32]>:$result);

let assemblyFormat = [{
attr-dict `ins` `(` $query `,` $key `,` $value
(`,` `mask` `=` $mask^)?
(`,` `normalization_quantize_scale` `=` $normalization_quantize_scale^)?
`:` type($query) `,` type($key) `,` type($value)
(`,` type($mask)^)?
(`,` type($normalization_quantize_scale)^)?
`)` `->` type($result)
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns true if the created op is valid for the given TensorRT major version.
bool isValidForTensorRTVersion(int64_t trtMajorVersion);
}] # baseClassDeclaration;

let trtLayerAdd = [{
// Get normalization operation, default to kSOFTMAX
nvinfer1::AttentionNormalizationOp normOp = $normalization_operation
? *$normalization_operation
: nvinfer1::AttentionNormalizationOp::kSOFTMAX;

nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, normOp, $causal);
if (!layer)
return failure();

if ($mask)
layer->setMask(*$mask);

layer->setDecomposable($decomposable);

if ($normalization_quantize_scale) {
layer->setNormalizationQuantizeScale(*$normalization_quantize_scale);
}

if ($normalization_quantize_to_type) {
layer->setNormalizationQuantizeToType(*$normalization_quantize_to_type);
}

if (!$e.isStronglyTyped()){
FailureOr<nvinfer1::DataType> outputTrtType = getNvInferDataType($op.getLoc(),
$op.getType().getElementType());
if (failed(outputTrtType))
return failure();
layer->setOutputType(0, *outputTrtType);
}

$results.push_back(layer->getOutput(0));
$e.setMetadata(layer, $op);
}];
}
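
The `hasVerifier = 1` flag above refers to a verifier implemented in the C++ sources of this patch. As a rough sketch only, the checks implied by the constraints in the description could look like the following (accessor names are the standard ODS-generated ones; this is not the actual implementation):

```cpp
// Hypothetical sketch of the verifier checks implied by the op constraints;
// the real verifier in this patch may differ and perform additional checks.
LogicalResult tensorrt::AttentionOp::verify() {
  // A mask input and causal masking are mutually exclusive.
  if (getMask() && getCausal())
    return emitOpError("cannot use both a mask input and causal = true");
  // Quantized normalization requires a supported type and a scale input.
  if (std::optional<DataType> quantType = getNormalizationQuantizeToType()) {
    if (*quantType != DataType::kFP8 && *quantType != DataType::kINT8)
      return emitOpError(
          "normalization_quantize_to_type must be kFP8 or kINT8");
    if (!getNormalizationQuantizeScale())
      return emitOpError("normalization_quantize_scale must be provided when "
                         "normalization_quantize_to_type is set");
  }
  return success();
}
```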

#endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTOPS_TD
@@ -914,3 +914,16 @@ bool tensorrt::ScatterElementsOp::isValidForTensorRTVersion(
return isValidForTensorRTVersionScatterOpImpl(
trtMajorVersion, dataElementType, indicesElementType);
}

//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//

bool tensorrt::AttentionOp::isValidForTensorRTVersion(
int64_t trtMajorVersion) {
  // IAttention layer is only supported in TensorRT >= 10.14.0.
  if (trtMajorVersion < 10)
    return false;

  return true;
}

Collaborator Author: We should also check the minor version here.
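
A sketch of the stricter check that comment suggests, assuming the minor version could also be threaded through; the existing hook only receives the major version, so the helper below uses a hypothetical signature:

```cpp
// Hypothetical helper: IAttention requires TensorRT >= 10.14.0, so a 10.x
// build is accepted only when the minor version is at least 14. The current
// isValidForTensorRTVersion hook only receives the major version.
static bool isAttentionSupported(int64_t trtMajorVersion,
                                 int64_t trtMinorVersion) {
  if (trtMajorVersion > 10)
    return true;
  return trtMajorVersion == 10 && trtMinorVersion >= 14;
}
```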