-
Notifications
You must be signed in to change notification settings - Fork 18
Add AttentionOp #708
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add AttentionOp #708
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -378,4 +378,42 @@ def TensorRT_ScatterMode : TensorRT_I32EnumAttr< | |
| def TensorRT_ScatterModeAttr : TensorRT_EnumAttr<TensorRT_ScatterMode, "scatter_mode">{ | ||
| } | ||
|
|
||
| def TensorRT_AttentionNormalizationOp : TensorRT_I32EnumAttr< | ||
| "AttentionNormalizationOp", "", | ||
| [ | ||
| I32EnumAttrCase<"kNONE", 0>, | ||
| I32EnumAttrCase<"kSOFTMAX", 1> | ||
| ]> | ||
| { | ||
| let cppNamespace = "::mlir::tensorrt"; | ||
| let genSpecializedAttr = 0; | ||
| } | ||
|
|
||
| def TensorRT_AttentionNormalizationOpAttr : TensorRT_EnumAttr<TensorRT_AttentionNormalizationOp, "attention_normalization_op">{ | ||
| } | ||
|
|
||
| def TensorRT_DataType : TensorRT_I32EnumAttr< | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We didn't already have datatype?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We didn't have an op that explicitly requires data type as an input, for example cast op uses MLIR data types
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we do the same here? I'm assuming we could use the same helpers that cast uses?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cast uses output tensor type to indicate the data type. e.g. but in attention op, |
||
| "DataType", "", | ||
| [ | ||
| I32EnumAttrCase<"kFLOAT", 0>, | ||
| I32EnumAttrCase<"kHALF", 1>, | ||
| I32EnumAttrCase<"kINT8", 2>, | ||
| I32EnumAttrCase<"kINT32", 3>, | ||
| I32EnumAttrCase<"kBOOL", 4>, | ||
| I32EnumAttrCase<"kUINT8", 5>, | ||
| I32EnumAttrCase<"kFP8", 6>, | ||
| I32EnumAttrCase<"kBF16", 7>, | ||
| I32EnumAttrCase<"kINT64", 8>, | ||
| I32EnumAttrCase<"kINT4", 9>, | ||
| I32EnumAttrCase<"kFP4", 10>, | ||
| I32EnumAttrCase<"kE8M0", 11> | ||
| ]> | ||
| { | ||
| let cppNamespace = "::mlir::tensorrt"; | ||
| let genSpecializedAttr = 0; | ||
| } | ||
|
|
||
| def TensorRT_DataTypeAttr : TensorRT_EnumAttr<TensorRT_DataType, "data_type">{ | ||
| } | ||
|
|
||
| #endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTENUMS | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -914,3 +914,16 @@ bool tensorrt::ScatterElementsOp::isValidForTensorRTVersion( | |
| return isValidForTensorRTVersionScatterOpImpl( | ||
| trtMajorVersion, dataElementType, indicesElementType); | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // AttentionOp | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| bool tensorrt::AttentionOp::isValidForTensorRTVersion( | ||
| int64_t trtMajorVersion) { | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should also check the minor version here. |
||
| // IAttention layer is only supported in TensorRT >= 10.14.0 | ||
| if (trtMajorVersion < 10) | ||
| return false; | ||
|
|
||
| return true; | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cudnn8 conflicts with cudnn9 in the base container.