4 changes: 2 additions & 2 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -9,8 +9,8 @@ set(contrib_ops_excluded_files
"bert/attention.h"
"bert/attention_impl.cu"
"bert/attention_softmax.h"
"bert/cross_attention.cc"
"bert/cross_attention.h"
"bert/multihead_attention.cc"
"bert/multihead_attention.h"
"bert/embed_layer_norm.cc"
"bert/embed_layer_norm.h"
"bert/embed_layer_norm_impl.cu"
104 changes: 52 additions & 52 deletions docs/ContribOperators.md
@@ -17,7 +17,6 @@ Do not modify directly.*
* <a href="#com.microsoft.ComplexMulConj">com.microsoft.ComplexMulConj</a>
* <a href="#com.microsoft.ConvTransposeWithDynamicPads">com.microsoft.ConvTransposeWithDynamicPads</a>
* <a href="#com.microsoft.CropAndResize">com.microsoft.CropAndResize</a>
* <a href="#com.microsoft.CrossAttention">com.microsoft.CrossAttention</a>
* <a href="#com.microsoft.DecoderAttention">com.microsoft.DecoderAttention</a>
* <a href="#com.microsoft.DequantizeBFP">com.microsoft.DequantizeBFP</a>
* <a href="#com.microsoft.DequantizeLinear">com.microsoft.DequantizeLinear</a>
@@ -42,6 +41,7 @@ Do not modify directly.*
* <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
* <a href="#com.microsoft.MaxpoolWithMask">com.microsoft.MaxpoolWithMask</a>
* <a href="#com.microsoft.MulInteger">com.microsoft.MulInteger</a>
* <a href="#com.microsoft.MultiHeadAttention">com.microsoft.MultiHeadAttention</a>
* <a href="#com.microsoft.MurmurHash3">com.microsoft.MurmurHash3</a>
* <a href="#com.microsoft.NGramRepeatBlock">com.microsoft.NGramRepeatBlock</a>
* <a href="#com.microsoft.NhwcConv">com.microsoft.NhwcConv</a>
@@ -955,57 +955,6 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.CrossAttention"></a><a name="com.microsoft.crossattention">**com.microsoft.CrossAttention**</a>

Multi-Head Cross Attention. Bias from input projection is included.

The key padding mask is optional. When its shape is (batch_size, kv_sequence_length), value 0
means padding or 1 otherwise. When key has right-side padding, its shape could be (batch_size): it is actual length of
each key sequence excluding paddings.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads</dd>
</dl>

#### Inputs (4 - 5)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size) when weights is not available.</dd>
<dt><tt>key</tt> : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, hidden_size)</dd>
<dt><tt>value</tt> : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, v_hidden_size)</dd>
<dt><tt>bias</tt> : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
<dd>Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain mask to integer types</dd>
</dl>


### <a name="com.microsoft.DecoderAttention"></a><a name="com.microsoft.decoderattention">**com.microsoft.DecoderAttention**</a>

This DecoderAttention supports self attention and cross attention, key and value cache, and key_padding_mask. The attention mask is not supported at the moment.
@@ -2156,6 +2105,57 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.MultiHeadAttention"></a><a name="com.microsoft.multiheadattention">**com.microsoft.MultiHeadAttention**</a>

Multi-Head Self/Cross Attention. Bias from input projection is included.

The key padding mask is optional. When its shape is (batch_size, kv_sequence_length), value 0
means padding or 1 otherwise. When key has right-side padding, its shape could be (batch_size): it is actual length of
each key sequence excluding paddings.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads</dd>
</dl>

#### Inputs (4 - 5)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size) when weights is not available.</dd>
<dt><tt>key</tt> : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, hidden_size)</dd>
<dt><tt>value</tt> : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, v_hidden_size)</dd>
<dt><tt>bias</tt> : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
<dd>Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain mask to integer types</dd>
</dl>


### <a name="com.microsoft.MurmurHash3"></a><a name="com.microsoft.murmurhash3">**com.microsoft.MurmurHash3**</a>

The underlying implementation is MurmurHash3_x86_32 generating low latency 32bits hash suitable for implementing lookup tables, Bloom filters, count min sketch or feature hashing.
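To make the renamed operator's spec concrete, here is a minimal sketch that builds a one-node model around `com.microsoft.MultiHeadAttention` and runs it. Everything in it is derived from the spec above; it assumes an onnxruntime build that registers the kernel (per the kernel table in docs/OperatorKernels.md below, CUDA only), and the tensor contents are arbitrary.

```python
# Sketch only: builds and runs a single com.microsoft.MultiHeadAttention node.
# Assumes an onnxruntime build with the CUDA execution provider available.
import numpy as np
import onnx
from onnx import TensorProto, helper
import onnxruntime as ort

batch, seq_len, kv_len = 2, 4, 6
hidden, v_hidden, num_heads = 64, 64, 8

node = helper.make_node(
    "MultiHeadAttention",
    inputs=["query", "key", "value", "bias", "key_padding_mask"],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=num_heads,
)
graph = helper.make_graph(
    [node], "mha_example",
    inputs=[
        helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch, seq_len, hidden]),
        helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch, kv_len, hidden]),
        helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch, kv_len, v_hidden]),
        helper.make_tensor_value_info("bias", TensorProto.FLOAT, [2 * hidden + v_hidden]),
        helper.make_tensor_value_info("key_padding_mask", TensorProto.INT32, [batch, kv_len]),
    ],
    outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, seq_len, v_hidden])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider"])
mask = np.ones((batch, kv_len), dtype=np.int32)
mask[:, -2:] = 0  # 0 marks padded key positions, 1 marks valid ones
feeds = {
    "query": np.random.randn(batch, seq_len, hidden).astype(np.float32),
    "key": np.random.randn(batch, kv_len, hidden).astype(np.float32),
    "value": np.random.randn(batch, kv_len, v_hidden).astype(np.float32),
    "bias": np.zeros(2 * hidden + v_hidden, dtype=np.float32),
    "key_padding_mask": mask,
}
print(sess.run(None, feeds)[0].shape)  # expected: (2, 4, 64)
```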
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -784,7 +784,6 @@ Do not modify directly.*
|ComplexMul|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ComplexMulConj|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
-|CrossAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|DecoderAttention|*in* query:**T**<br> *in* key:**T**<br> *in* q_weight:**T**<br> *in* kv_weight:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**B**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* static_kv:**B**<br> *in* use_past:**B**<br> *in* has_layer_state:**B**<br> *in* has_key_padding_mask:**B**<br> *out* output:**T**<br> *out* new_key_cache:**T**<br> *out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(float16)|
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
@@ -798,6 +797,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* extra_add:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
onnxruntime/contrib_ops/cpu/bert/{cross_attention_helper.h → multihead_attention_helper.h}
@@ -9,7 +9,7 @@

namespace onnxruntime {
namespace contrib {
-namespace cross_attention_helper {
+namespace multihead_attention_helper {

template <typename T>
Status CheckInputs(const T* query,
@@ -114,6 +114,6 @@ Status CheckInputs(const T* query,
return Status::OK();
}

-} // namespace cross_attention_helper
+} // namespace multihead_attention_helper
} // namespace contrib
} // namespace onnxruntime
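The body of `CheckInputs` is collapsed in this view, but the constraints it must enforce follow directly from the operator spec. Below is a hypothetical Python mirror of those checks, derived from the spec rather than translated from the C++, for readers who want the shape contract in one place.

```python
def check_mha_inputs(query, key, value, bias, key_padding_mask=None):
    """Hypothetical mirror of multihead_attention_helper::CheckInputs,
    derived from the operator spec, not from the (collapsed) C++ body."""
    batch, _seq_len, hidden = query.shape
    if key.shape[0] != batch or key.shape[2] != hidden:
        raise ValueError("key must be (batch_size, kv_sequence_length, hidden_size)")
    kv_len = key.shape[1]
    if value.shape[:2] != (batch, kv_len):
        raise ValueError("value must be (batch_size, kv_sequence_length, v_hidden_size)")
    v_hidden = value.shape[2]
    if bias.shape != (hidden + hidden + v_hidden,):
        raise ValueError("bias must be (hidden_size + hidden_size + v_hidden_size,)")
    if key_padding_mask is not None and key_padding_mask.shape not in ((batch,), (batch, kv_len)):
        raise ValueError("mask must be (batch_size,) or (batch_size, kv_sequence_length)")
```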
onnxruntime/contrib_ops/cuda/bert/{cross_attention.cc → multihead_attention.cc}
@@ -4,8 +4,8 @@
#include "core/providers/cuda/cuda_common.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/cross_attention.h"
#include "contrib_ops/cpu/bert/cross_attention_helper.h"
#include "contrib_ops/cuda/bert/multihead_attention.h"
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
@@ -17,20 +17,20 @@ namespace cuda {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
-CrossAttention, \
+MultiHeadAttention, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
-CrossAttention<T>);
+MultiHeadAttention<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)

template <typename T>
-CrossAttention<T>::CrossAttention(const OpKernelInfo& info) : CudaKernel(info) {
+MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : CudaKernel(info) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int>(num_heads);
@@ -43,7 +43,7 @@ CrossAttention<T>::CrossAttention(const OpKernelInfo& info) : CudaKernel(info) {
}

template <typename T>
-Status CrossAttention<T>::ComputeInternal(OpKernelContext* context) const {
+Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* query = context->Input<Tensor>(0);
const Tensor* key = context->Input<Tensor>(1);
const Tensor* value = context->Input<Tensor>(2);
@@ -52,7 +52,7 @@ Status CrossAttention<T>::ComputeInternal(OpKernelContext* context) const {

auto& device_prop = GetDeviceProp();
AttentionParameters parameters;
-ORT_RETURN_IF_ERROR(cross_attention_helper::CheckInputs<Tensor>(query,
+ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
key,
value,
bias,
@@ -84,7 +84,7 @@ Status CrossAttention<T>::ComputeInternal(OpKernelContext* context) const {
enable_flash_attention_, false);

if (use_fused_runner) {
-// Here we assume that num_heads and head_size does not change for an CrossAttention node.
+// Here we assume that num_heads and head_size do not change for a MultiHeadAttention node.
if (nullptr == fused_fp16_runner_.get()) {
constexpr bool is_unidirectional = false;
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(
onnxruntime/contrib_ops/cuda/bert/{cross_attention.h → multihead_attention.h}
@@ -14,9 +14,9 @@ namespace cuda {
using namespace onnxruntime::cuda;

template <typename T>
-class CrossAttention final : public CudaKernel {
+class MultiHeadAttention final : public CudaKernel {
public:
-CrossAttention(const OpKernelInfo& info);
+MultiHeadAttention(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;

protected:
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -62,8 +62,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, CrossAttention);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, CrossAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
@@ -183,8 +183,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, CrossAttention)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, CrossAttention)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>,
14 changes: 7 additions & 7 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -125,7 +125,7 @@ void RestorePaddingTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx)
}
}

-void CrossAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
+void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
// Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
// Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size)
// Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size)
@@ -258,18 +258,18 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
AttentionTypeAndShapeInference(ctx, past_input_index);
}));

-constexpr const char* CrossAttention_ver1_doc = R"DOC(
-Multi-Head Cross Attention. Bias from input projection is included.
+constexpr const char* MultiHeadAttention_ver1_doc = R"DOC(
+Multi-Head Self/Cross Attention. Bias from input projection is included.

The key padding mask is optional. When its shape is (batch_size, kv_sequence_length), a value of 0
marks a padded position and 1 a valid one. When key has right-side padding, its shape could be (batch_size):
each element is the actual length of the corresponding key sequence, excluding padding.
)DOC";

ONNX_MS_OPERATOR_SET_SCHEMA(
-CrossAttention, 1,
+MultiHeadAttention, 1,
OpSchema()
-.SetDoc(CrossAttention_ver1_doc)
+.SetDoc(MultiHeadAttention_ver1_doc)
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
.Input(0,
"query",
@@ -299,7 +299,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-CrossAttentionTypeAndShapeInference(ctx);
+MultiHeadAttentionTypeAndShapeInference(ctx);
}));

constexpr const char* Longformer_Attention_doc = R"DOC(
@@ -436,7 +436,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
schema.BuildFunction(functionProto);
return true;
}));

ONNX_MS_OPERATOR_SET_SCHEMA(
RelativePositionBias, 1,
OpSchema()
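The shape contract that `MultiHeadAttentionTypeAndShapeInference` implements is small: per the comments above, the output keeps batch_size and sequence_length from query and takes v_hidden_size from value. A standalone illustration of that arithmetic (the function name here is made up, not from the source):

```python
def infer_mha_output_shape(query_shape, value_shape):
    # query: (batch_size, sequence_length, hidden_size)
    # value: (batch_size, kv_sequence_length, v_hidden_size)
    batch_size, sequence_length, _hidden = query_shape
    v_hidden_size = value_shape[2]
    # output: (batch_size, sequence_length, v_hidden_size)
    return [batch_size, sequence_length, v_hidden_size]

assert infer_mha_output_shape([2, 4, 64], [2, 6, 128]) == [2, 4, 128]
```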
4 changes: 2 additions & 2 deletions onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -56,7 +56,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, ComplexMul);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, ComplexMulConj);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, ConvTransposeWithDynamicPads);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CropAndResize);
-class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CrossAttention);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, EmbedLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, ExpandDims);
@@ -142,7 +142,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, ComplexMulConj)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, ConvTransposeWithDynamicPads)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CropAndResize)>());
-fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CrossAttention)>());
+fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, EmbedLayerNormalization)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, ExpandDims)>());
4 changes: 2 additions & 2 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -190,7 +190,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
# contrib ops:
"Attention": self._infer_Attention,
"BiasGelu": self._infer_BiasGelu,
"CrossAttention": self._infer_CrossAttention,
"MultiHeadAttention": self._infer_MultiHeadAttention,
"EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
"FastGelu": self._infer_FastGelu,
"Gelu": self._infer_Gelu,
@@ -1994,7 +1994,7 @@ def _infer_Attention(self, node):
def _infer_BiasGelu(self, node):
self._propagate_shape_and_type(node)

-def _infer_CrossAttention(self, node):
+def _infer_MultiHeadAttention(self, node):
# Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
# Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size)
# Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size)
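With the dispatch entry renamed, models containing the new op can be run through the symbolic shape inference tool. A usage sketch, assuming a `model.onnx` (placeholder name) that contains a MultiHeadAttention node and that your onnxruntime version exposes the tool at this import path:

```python
# Sketch: run symbolic shape inference over a model that uses
# com.microsoft.MultiHeadAttention. "model.onnx" is a placeholder; the import
# path of the tool may differ across onnxruntime versions.
import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("model.onnx")
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)

# The MultiHeadAttention output should now carry the symbolic shape
# (batch_size, sequence_length, v_hidden_size).
for vi in inferred.graph.value_info:
    dims = [d.dim_param or d.dim_value for d in vi.type.tensor_type.shape.dim]
    print(vi.name, dims)
```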