Skip to content

Commit

Permalink
[DML EP] Add QuickGelu (#15220)
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola authored Apr 5, 2023
1 parent a96e19a commit 9191e04
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,7 @@ Do not modify directly.*
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
| |
| |
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_dml_rocm_eps));

transformers.emplace_back(std::make_unique<FastGeluFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_cuda_dml_rocm_eps));

transformers.emplace_back(std::make_unique<MatMulScaleFusion>(cpu_cuda_dml_rocm_eps));

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "precomp.h"

namespace Dml
{

class DmlOperatorQuickGelu : public DmlOperator
{
public:
    // Implements the contrib op QuickGelu(X) = X * sigmoid(alpha * X) as a fused
    // DML graph of up to three element-wise nodes: an optional alpha scale
    // (skipped when alpha == 1), a sigmoid, and a final multiply with the input.
    DmlOperatorQuickGelu(const MLOperatorKernelCreationContext& kernelCreationContext)
    :   DmlOperator(kernelCreationContext)
    {
        ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1);
        ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1);
        DmlOperator::Initialize(kernelCreationContext);

        ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs.size() == 1);
        ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs.size() == 1);
        const float alpha = kernelCreationContext.GetAttribute<float>(AttrName::Alpha);

        std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
        std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

        // 1. Scale the input by alpha, expressed as an identity op with a ScaleBias.
        // The desc is only wired up (and the node only added to the graph below)
        // when alpha != 1, since the scale would otherwise be a no-op.
        // NOTE: scaleBias must outlive the DML graph construction at the end of
        // this constructor because mulAlphaDesc holds a pointer to it.
        DML_SCALE_BIAS scaleBias{alpha, 0.0f};
        DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC mulAlphaDesc{};
        if (alpha != 1.0f)
        {
            mulAlphaDesc.InputTensor = &inputDescs[0];
            mulAlphaDesc.OutputTensor = &inputDescs[0];
            mulAlphaDesc.ScaleBias = &scaleBias;
        }
        DML_OPERATOR_DESC dmlMulAlphaDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &mulAlphaDesc };

        // 2. Apply the sigmoid activation function to the (optionally scaled) input.
        DML_ACTIVATION_SIGMOID_OPERATOR_DESC sigmoidDesc{};
        sigmoidDesc.InputTensor = &inputDescs[0];
        sigmoidDesc.OutputTensor = &inputDescs[0];
        DML_OPERATOR_DESC dmlSigmoidDesc = { DML_OPERATOR_ACTIVATION_SIGMOID, &sigmoidDesc };

        // 3. Multiply the sigmoid result with the original (unscaled) input.
        // This is the graph's final node, so its output uses the operator's
        // output tensor desc rather than the input desc (fixes outputDescs
        // being computed but never used).
        DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC multiplyDesc{};
        multiplyDesc.ATensor = &inputDescs[0];
        multiplyDesc.BTensor = &inputDescs[0];
        multiplyDesc.OutputTensor = &outputDescs[0];
        DML_OPERATOR_DESC dmlMultiplyDesc = { DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &multiplyDesc };

        // Node indices within opDescs. The optional alpha node is appended last
        // so that sigmoid and multiply keep stable indices in both graph shapes.
        enum NodeIndex
        {
            sigmoidNodeIndex,
            multiplyNodeIndex,
            mulAlphaNodeIndex,
            nodeCount,
        };

        // Construct the graph
        std::vector<const DML_OPERATOR_DESC*> opDescs;
        opDescs.reserve(nodeCount);
        opDescs.push_back(&dmlSigmoidDesc);
        opDescs.push_back(&dmlMultiplyDesc);

        std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
        inputEdges.reserve(2);

        std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
        intermediateEdges.reserve(2);

        std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
        outputEdges.reserve(1);

        if (alpha != 1.0f)
        {
            // input -> mulAlpha -> sigmoid
            opDescs.push_back(&dmlMulAlphaDesc);

            DML_INPUT_GRAPH_EDGE_DESC inputToMulAlphaEdge{};
            inputToMulAlphaEdge.GraphInputIndex = 0;
            inputToMulAlphaEdge.ToNodeIndex = mulAlphaNodeIndex;
            inputToMulAlphaEdge.ToNodeInputIndex = 0;
            inputEdges.push_back(inputToMulAlphaEdge);

            DML_INTERMEDIATE_GRAPH_EDGE_DESC mulAlphaToSigmoidEdge{};
            mulAlphaToSigmoidEdge.FromNodeIndex = mulAlphaNodeIndex;
            mulAlphaToSigmoidEdge.FromNodeOutputIndex = 0;
            mulAlphaToSigmoidEdge.ToNodeIndex = sigmoidNodeIndex;
            mulAlphaToSigmoidEdge.ToNodeInputIndex = 0;
            intermediateEdges.push_back(mulAlphaToSigmoidEdge);
        }
        else
        {
            // alpha == 1: feed the graph input straight into the sigmoid.
            DML_INPUT_GRAPH_EDGE_DESC inputToSigmoidEdge{};
            inputToSigmoidEdge.GraphInputIndex = 0;
            inputToSigmoidEdge.ToNodeIndex = sigmoidNodeIndex;
            inputToSigmoidEdge.ToNodeInputIndex = 0;
            inputEdges.push_back(inputToSigmoidEdge);
        }

        // The multiply always reads the raw graph input on port 0; the sigmoid
        // result (of the scaled or unscaled input) arrives on port 1.
        DML_INPUT_GRAPH_EDGE_DESC inputToMultiplyEdge{};
        inputToMultiplyEdge.GraphInputIndex = 0;
        inputToMultiplyEdge.ToNodeIndex = multiplyNodeIndex;
        inputToMultiplyEdge.ToNodeInputIndex = 0;
        inputEdges.push_back(inputToMultiplyEdge);

        DML_INTERMEDIATE_GRAPH_EDGE_DESC sigmoidToMultiplyEdge{};
        sigmoidToMultiplyEdge.FromNodeIndex = sigmoidNodeIndex;
        sigmoidToMultiplyEdge.FromNodeOutputIndex = 0;
        sigmoidToMultiplyEdge.ToNodeIndex = multiplyNodeIndex;
        sigmoidToMultiplyEdge.ToNodeInputIndex = 1;
        intermediateEdges.push_back(sigmoidToMultiplyEdge);

        DML_OUTPUT_GRAPH_EDGE_DESC multiplyToOutputEdge{};
        multiplyToOutputEdge.FromNodeIndex = multiplyNodeIndex;
        multiplyToOutputEdge.FromNodeOutputIndex = 0;
        multiplyToOutputEdge.GraphOutputIndex = 0;
        outputEdges.push_back(multiplyToOutputEdge);

        // Hand the assembled node/edge lists to the EP, which compiles them
        // into a single fused DML graph for this kernel.
        MLOperatorGraphDesc operatorGraphDesc = {};
        operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
        operatorGraphDesc.inputEdges = inputEdges.data();
        operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
        operatorGraphDesc.intermediateEdges = intermediateEdges.data();
        operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
        operatorGraphDesc.outputEdges = outputEdges.data();
        operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
        operatorGraphDesc.nodesAsOpDesc = opDescs.data();
        SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
    }
};

DML_OP_DEFINE_CREATION_FUNCTION(QuickGelu, DmlOperatorQuickGelu);

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Shape);
DML_OP_EXTERN_CREATION_FUNCTION(Size);
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
DML_OP_EXTERN_CREATION_FUNCTION(QuickGelu);

DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
Expand Down Expand Up @@ -878,6 +879,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QuerySkipLayerNormalization)},
{REG_INFO_MS( 1, EmbedLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QuickGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)},
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,7 @@ using ShapeInferenceHelper_IsInf = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_Mod = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_BitShift= GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_Round = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_QuickGelu = GetOutputShapeAsInputShapeHelper;

using ShapeInferenceHelper_ReduceSum = ReduceHelper;
using ShapeInferenceHelper_ReduceMean = ReduceHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_SkipLayerNormalization = 1;
static const int sc_sinceVer_EmbedLayerNormalization = 1;
static const int sc_sinceVer_QuickGelu = 1;
static const int sc_sinceVer_GroupNorm = 1;
} // namespace MsftOperatorSet1

Expand Down

0 comments on commit 9191e04

Please sign in to comment.