@@ -247,6 +247,7 @@ OpBuilderRegistrations::OpBuilderRegistrations() {

{
CreateGroupNormOpBuilder("GroupNormalization", *this);
CreateGroupNormOpBuilder("GroupNorm", *this);
}

{
@@ -227,6 +227,7 @@ class BaseOpBuilder : public IOpBuilder {
{"RMSNormalization", QNN_OP_RMS_NORM},
{"SimplifiedLayerNormalization", QNN_OP_RMS_NORM},
{"GroupNormalization", QNN_OP_GROUP_NORM},
{"GroupNorm", QNN_OP_GROUP_NORM},

{"LRN", QNN_OP_LRN},

@@ -52,40 +52,52 @@ Ort::Status GroupNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper
const std::vector<uint32_t>& input_shape = input_info.shape;
const size_t input_rank = input_shape.size();

if (input_rank <= 2) {
return MAKE_EP_FAIL("QNN GroupNorm only supports input ranks greater than 2.");
}
RETURN_IF(input_rank <= 2, "QNN GroupNorm only supports input ranks greater than 2.");

// Handle layout transformation - check if already transformed to NHWC
const uint32_t num_channels = (node_unit.Domain() == kMSInternalNHWCDomain) ? input_shape.back() : input_shape[1];
OrtNodeAttrHelper node_helper(node_unit);
uint32_t num_channels;
if (node_unit.Domain() == kMSDomain) {
// Handle channels_last attribute for com.microsoft.GroupNorm
const int64_t channels_last = node_helper.Get("channels_last", static_cast<int64_t>(1));
num_channels = (channels_last == 1) ? input_shape.back() : input_shape[1];
} else {
// Handle layout transformation - check if already transformed to NHWC
num_channels = (node_unit.Domain() == kMSInternalNHWCDomain) ? input_shape.back() : input_shape[1];
}

TensorInfo scale_info{};
RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], scale_info));
const std::vector<uint32_t>& scale_shape = scale_info.shape;
if (scale_shape.size() != 1 || scale_shape[0] != num_channels) {
return MAKE_EP_FAIL("QNN GroupNorm input 1 (scale) must have 1D shape [channel].");
}
RETURN_IF(scale_shape.size() != 1 || scale_shape[0] != num_channels,
("QNN GroupNorm input 1 (scale/gamma) must have 1D shape [" + std::to_string(num_channels) + "].").c_str());

TensorInfo bias_info{};
RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[2], bias_info));
const std::vector<uint32_t>& bias_shape = bias_info.shape;
if (bias_shape.size() != 1 || bias_shape[0] != num_channels) {
return MAKE_EP_FAIL("QNN GroupNorm input 2 (bias) must have 1D shape [channel].");
}
RETURN_IF(bias_shape.size() != 1 || bias_shape[0] != num_channels,
("QNN GroupNorm input 2 (bias/beta) must have 1D shape [" + std::to_string(num_channels) + "].").c_str());

OrtNodeAttrHelper node_helper(node_unit);
const float epsilon = node_helper.Get("epsilon", 1e-05f);
if (epsilon <= 0.0f) {
return MAKE_EP_FAIL("QNN GroupNorm epsilon must be greater than 0.0");
}
RETURN_IF(epsilon <= 0.0f, "QNN GroupNorm epsilon must be greater than 0.0");

const int64_t num_groups = node_helper.Get("num_groups", static_cast<int64_t>(1));
if (num_groups <= 0) {
return MAKE_EP_FAIL("QNN GroupNorm num_groups must be greater than 0");
// Support both "num_groups" (ONNX GroupNormalization) and "groups" (com.microsoft.GroupNorm)
// Note: we cannot use node_unit.Domain() because the op domain may have been transformed to kMSInternalNHWCDomain.
int64_t num_groups;
if (node_unit.OpType() == "GroupNormalization") {
num_groups = node_helper.Get("num_groups", static_cast<int64_t>(1));
} else {
num_groups = node_helper.Get("groups", static_cast<int64_t>(1));
}
RETURN_IF(num_groups <= 0, "QNN GroupNorm num_groups/groups must be greater than 0");

if (num_channels % static_cast<uint32_t>(num_groups) != 0) {
return MAKE_EP_FAIL("QNN GroupNorm requires num_channels to be divisible by num_groups");
RETURN_IF(num_channels % static_cast<uint32_t>(num_groups) != 0,
"QNN GroupNorm requires num_channels to be divisible by num_groups/groups");

// Check activation attribute for com.microsoft.GroupNorm
if (node_unit.OpType() == "GroupNorm") {
const int64_t activation = node_helper.Get("activation", static_cast<int64_t>(0));
RETURN_IF(activation != 0 && activation != 1,
"QNN GroupNorm only supports activation=0 (None) or activation=1 (SiLU)");
}

// Continue Op validation if it's NHWC transformed
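The checks above hinge on two details of the contrib op: ONNX GroupNormalization spells the group-count attribute "num_groups" while com.microsoft.GroupNorm spells it "groups", and the contrib op's channels_last attribute (default 1, i.e. NHWC) decides which dimension holds the channel count that must divide evenly by the group count. A minimal standalone sketch of that selection and divisibility check is below; the shapes and group counts are illustrative only and not taken from this change.

#include <cstdint>
#include <iostream>
#include <vector>

// Channel-dimension selection as in IsOpSupported: last dim for NHWC
// (channels_last == 1), dim 1 for NCHW (channels_last == 0).
static uint32_t PickNumChannels(const std::vector<uint32_t>& shape, int64_t channels_last) {
  return (channels_last == 1) ? shape.back() : shape[1];
}

int main() {
  const std::vector<uint32_t> nhwc = {1, 64, 64, 320};
  const std::vector<uint32_t> nchw = {1, 320, 64, 64};
  const uint32_t c0 = PickNumChannels(nhwc, /*channels_last=*/1);  // 320
  const uint32_t c1 = PickNumChannels(nchw, /*channels_last=*/0);  // 320

  // Divisibility constraint: num_channels % groups == 0.
  for (int64_t groups : {32, 7}) {
    const bool ok = (c0 % static_cast<uint32_t>(groups)) == 0;
    std::cout << "channels=" << c0 << ", groups=" << groups
              << " -> " << (ok ? "supported" : "rejected") << "\n";
  }
  return (c0 == c1) ? 0 : 1;
}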
@@ -119,6 +131,53 @@ Ort::Status GroupNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn
OrtNodeAttrHelper node_helper(node_unit);
std::vector<std::string> param_tensor_names;

const auto& inputs = node_unit.Inputs();
TensorInfo input_info = {};
RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input_info));

TensorInfo scale_info = {};
RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], scale_info));

TensorInfo bias_info = {};
RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[2], bias_info));

// Check if we need to cast scale and bias to match input dtype
std::string scale_input_name = input_names[1];
std::string bias_input_name = input_names[2];

if (scale_info.qnn_data_type != input_info.qnn_data_type) {
// Create Cast node for scale
std::string casted_scale_name = utils::GetUniqueName(node_unit.Name() + "_scale_cast");
RETURN_IF_ERROR(qnn_model_wrapper.AddCastNode(casted_scale_name,
scale_input_name,
casted_scale_name,
QNN_TENSOR_TYPE_NATIVE,
input_info.qnn_data_type,
QnnQuantParamsWrapper(),
std::move(scale_info.shape),
do_op_validation));

scale_input_name = casted_scale_name;
}

if (bias_info.qnn_data_type != input_info.qnn_data_type) {
// Create Cast node for bias
std::string casted_bias_name = utils::GetUniqueName(node_unit.Name() + "_bias_cast");
RETURN_IF_ERROR(qnn_model_wrapper.AddCastNode(casted_bias_name,
bias_input_name,
casted_bias_name,
QNN_TENSOR_TYPE_NATIVE,
input_info.qnn_data_type,
QnnQuantParamsWrapper(),
std::move(bias_info.shape),
do_op_validation));

bias_input_name = casted_bias_name;
}

// Update input_names with potentially casted scale and bias
std::vector<std::string> group_norm_input_names = {input_names[0], scale_input_name, bias_input_name};

const float epsilon = node_helper.Get("epsilon", 1e-05f);
Qnn_Scalar_t epsilon_param = QNN_SCALAR_INIT;
epsilon_param.dataType = QNN_DATATYPE_FLOAT_32;
@@ -130,7 +189,14 @@ Ort::Status GroupNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn
param_tensor_names.push_back(epsilon_param_wrapper.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(epsilon_param_wrapper));

const int64_t num_groups = node_helper.Get("num_groups", static_cast<int64_t>(1));
// Support both "num_groups" (ONNX GroupNormalization) and "groups" (com.microsoft.GroupNorm)
// Note: we cannot use node_unit.Domain() because the op domain may have been transformed to kMSInternalNHWCDomain.
int64_t num_groups;
if (node_unit.OpType() == "GroupNormalization") {
num_groups = node_helper.Get("num_groups", static_cast<int64_t>(1));
} else {
num_groups = node_helper.Get("groups", static_cast<int64_t>(1));
}
Qnn_Scalar_t num_groups_param = QNN_SCALAR_INIT;
num_groups_param.dataType = QNN_DATATYPE_UINT_32;
num_groups_param.uint32Value = static_cast<uint32_t>(num_groups);
@@ -141,10 +207,82 @@ Ort::Status GroupNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn
param_tensor_names.push_back(num_groups_param_wrapper.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(num_groups_param_wrapper));

return ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),
logger, do_op_validation, GetQnnOpType(node_unit.OpType()));
// Check if we need to add SiLU activation (activation=1)
const int64_t activation = node_helper.Get("activation", static_cast<int64_t>(0));

if (activation == 0) {
// No activation, just process outputs normally
return ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(group_norm_input_names),
std::move(param_tensor_names),
logger, do_op_validation, GetQnnOpType(node_unit.OpType()));
}

// activation == 1: Add SiLU activation (x * sigmoid(x))
const auto& outputs = node_unit.Outputs();
const std::string& final_output_name = outputs[0].name;

// Create intermediate output for GroupNorm
std::string group_norm_output_name = utils::GetUniqueName(node_unit.Name() + "_group_norm_out");
QnnTensorWrapper group_norm_output_tensor(group_norm_output_name,
QNN_TENSOR_TYPE_NATIVE,
input_info.qnn_data_type,
QnnQuantParamsWrapper(),
std::vector<uint32_t>(input_info.shape));
RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(group_norm_output_tensor)),
"Failed to add group_norm_output tensor.");

// Create GroupNorm node
RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_group_norm"),
QNN_OP_PACKAGE_NAME_QTI_AISW,
GetQnnOpType(node_unit.OpType()),
std::move(group_norm_input_names),
{group_norm_output_name},
std::move(param_tensor_names),
do_op_validation),
"Failed to create GroupNorm node.");

// Create Sigmoid output tensor
std::string sigmoid_output_name = utils::GetUniqueName(node_unit.Name() + "_sigmoid_out");
QnnTensorWrapper sigmoid_output_tensor(sigmoid_output_name,
QNN_TENSOR_TYPE_NATIVE,
input_info.qnn_data_type,
QnnQuantParamsWrapper(),
std::vector<uint32_t>(input_info.shape));
RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(sigmoid_output_tensor)),
"Failed to add sigmoid_output tensor.");

// Create Sigmoid node
RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_sigmoid"),
QNN_OP_PACKAGE_NAME_QTI_AISW,
QNN_OP_SIGMOID,
{group_norm_output_name},
{sigmoid_output_name},
{},
do_op_validation),
"Failed to create Sigmoid node.");

// Create final output tensor
Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(final_output_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
QnnTensorWrapper final_output_tensor(final_output_name,
tensor_type,
input_info.qnn_data_type,
input_info.quant_param.Copy(),
std::vector<uint32_t>(input_info.shape));
RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(final_output_tensor)),
"Failed to add final output tensor.");

// Create ElementWiseMul node for x * sigmoid(x)
RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_silu_mul"),
QNN_OP_PACKAGE_NAME_QTI_AISW,
QNN_OP_ELEMENT_WISE_MULTIPLY,
{group_norm_output_name, sigmoid_output_name},
{final_output_name},
{},
do_op_validation),
"Failed to create SiLU multiply node.");

return Ort::Status();
}
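For reference, the activation=1 path above lowers SiLU as two QNN ops, a Sigmoid followed by an element-wise multiply, rather than a single fused activation. The decomposition is mathematically exact, as the small standalone check below illustrates; the values are arbitrary and the helpers are only for this sketch.

#include <cassert>
#include <cmath>

// SiLU written directly vs. the Sigmoid + Mul decomposition used in the
// builder: silu(x) = x / (1 + e^-x) = x * sigmoid(x).
static float SiluDirect(float x) { return x / (1.0f + std::exp(-x)); }
static float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
static float SiluDecomposed(float x) { return x * Sigmoid(x); }

int main() {
  for (float x : {-4.0f, -0.5f, 0.0f, 0.5f, 4.0f}) {
    assert(std::fabs(SiluDirect(x) - SiluDecomposed(x)) < 1e-6f);
  }
  return 0;
}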

void CreateGroupNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -1838,6 +1838,11 @@ OrtStatus* ORT_API_CALL QnnEp::ShouldConvertDataLayoutForOpImpl(_In_ OrtEp* this
*should_convert = 1;
}

if (std::string(domain) == kMSDomain && std::string(op_type) == "GroupNorm") {
// com.microsoft.GroupNorm is translated to QNN's GroupNorm, which requires the NHWC layout for processing.
*should_convert = 1;
}

return nullptr;
}
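As a side note on what flagging the conversion implies: a 4-D NCHW input is rearranged to NHWC, which corresponds to a {0, 2, 3, 1} permutation of the shape. The snippet below only illustrates that convention with a hypothetical shape; it is not code from the provider.

#include <array>
#include <cstdint>
#include <iostream>

// Illustrative NCHW -> NHWC shape permutation (perm = {0, 2, 3, 1}).
int main() {
  const std::array<uint32_t, 4> nchw = {1, 320, 64, 64};
  const std::array<size_t, 4> perm = {0, 2, 3, 1};
  std::array<uint32_t, 4> nhwc{};
  for (size_t i = 0; i < 4; ++i) nhwc[i] = nchw[perm[i]];
  for (uint32_t d : nhwc) std::cout << d << " ";  // prints: 1 64 64 320
  std::cout << "\n";
  return 0;
}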
