Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2253,6 +2253,66 @@ IMPLEMENT_GRADIENT_BUILDER(GetGlobalMaxPoolGradient) {

result.push_back(NodeDef("Expand", {GO(0), IA("X_shape")}, {IA("expanded_dY")}));
result.push_back(NodeDef("Mul", {IA("mask_cast"), IA("expanded_dY")}, {GI(0)}));
return result;
}

// Gradient builder for ONNX ReduceMax.
// dX = dY routed back to the element positions where X equals the forward
// output O(0): a boolean mask is recomputed with Equal and multiplied into
// the (shape-aligned) incoming gradient GO(0).
// NOTE(review): when several elements tie for the maximum, every tied
// position receives the full incoming gradient (the mask is not normalized
// by the tie count) — confirm this matches the intended subgradient
// convention for this op.
IMPLEMENT_GRADIENT_BUILDER(GetReduceMaxGradient) {
  std::vector<NodeDef> result;
  auto attributes = SrcNodeAttributes();
  bool keepdims = true;  // ONNX default for ReduceMax is keepdims=1.

  // Check the "keepdims" attribute
  if (attributes.find("keepdims") != attributes.end() &&
      attributes.at("keepdims").has_i()) {
    keepdims = static_cast<bool>(attributes.at("keepdims").i());
  }

  // When keepdims=0 these are rebound below to unsqueezed intermediates so
  // that they broadcast correctly against the original input shape.
  ArgDef grad = GO(0);
  ArgDef reduced_output = O(0);

  if (!keepdims) {
    size_t numInputs = GetSrcNodeInputSize();
    ArgDef unsqueeze_axes_arg;
    bool axes_provided = false;

    // Handle "axes" as attribute or input
    // (older ReduceMax opsets carry axes as an attribute; newer ones pass
    // them as a second input — presumably opset >= 18, TODO confirm).
    if (attributes.find("axes") != attributes.end()) {
      axes_provided = true;
      std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
      if (SrcNodeOpsetVersion() >= 13) {
        // Opset-13 Unsqueeze takes axes as an input tensor, so materialize
        // the attribute values as a constant node.
        NodeDef axes_values_node = ConstantVectorNode(axes_values, Name("axes_values"));
        result.push_back(axes_values_node);
        unsqueeze_axes_arg = axes_values_node.output_args[0];
      }
    } else if (numInputs == 2) {
      // axes arrived as the op's second input: feed it to Unsqueeze as-is.
      axes_provided = true;
      unsqueeze_axes_arg = I(1);
    }
    // If axes were not provided at all, the reduction collapsed every axis
    // and the scalar gradient broadcasts in the Mul below without any
    // Unsqueeze.

    if (axes_provided) {
      grad = IA("Unsqueezed_Grad");
      reduced_output = IA("Unsqueezed_Output");
      if (SrcNodeOpsetVersion() < 13 && attributes.find("axes") != attributes.end()) {
        // Pre-opset-13 Unsqueeze expects axes as an attribute.
        std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
        result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
        result.push_back(NodeDef("Unsqueeze", {O(0)}, {reduced_output}, {MakeAttribute("axes", axes_values)}));
      } else {
        // Opset-13 (or later) Unsqueeze: axes come in as the second input.
        result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {GO(0), unsqueeze_axes_arg}, {grad}));
        result.push_back(NodeDef(OpDef{"Unsqueeze", kOnnxDomain, 13}, {O(0), unsqueeze_axes_arg}, {reduced_output}));
      }
    }
  }

  // Step 1: Recreate the boolean mask tensor indicating max positions
  result.push_back(NodeDef("Shape", {I(0)}, {IA("Shaped_X")}));
  result.push_back(NodeDef("Expand", {reduced_output, IA("Shaped_X")}, {IA("Expanded_Output")}));
  result.push_back(NodeDef("Equal", {I(0), IA("Expanded_Output")}, {IA("Mask")}));
  // Step 2: Convert the boolean mask to a float tensor (0.0 and 1.0)
  // (cast target is the forward output's element type, via OElemType(0)).
  result.push_back(NodeDef("Cast", {IA("Mask")}, {IA("Mask_Float")}, {MakeAttribute("to", static_cast<int64_t>(OElemType(0)))}));
  // Step 3: Multiply the input gradient by the mask
  result.push_back(NodeDef("Mul", {grad, IA("Mask_Float")}, {IA("Masked_Grad")}));
  // Step 4: Ensure the output gradient has the same shape as the input
  result.push_back(NodeDef("Expand", {IA("Masked_Grad"), IA("Shaped_X")}, {GI(0)}));

  return result;
}
Expand Down
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
DECLARE_GRADIENT_BUILDER(GetResizeGradient)
DECLARE_GRADIENT_BUILDER(GetAtanGradient)
DECLARE_GRADIENT_BUILDER(GetGlobalMaxPoolGradient)
DECLARE_GRADIENT_BUILDER(GetReduceMaxGradient)

DECLARE_GRADIENT_BUILDER(GetExternalGradient)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);
REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient);
REGISTER_GRADIENT_BUILDER("GlobalMaxPool", GetGlobalMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("ReduceMax", GetReduceMaxGradient);

REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
};
Expand Down
24 changes: 24 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3379,6 +3379,30 @@ TEST(GradientCheckerTest, GlobalMaxPoolGrad) {
}
}

// Gradient-checker coverage for ReduceMax across the opset versions that
// changed its contract: axes moved from attribute to input at opset 18,
// and negative axes values are supported from opset 11.
TEST(GradientCheckerTest, ReduceMaxGrad) {
  struct OpsetConfig {
    int opset_version;        // ONNX opset to test against.
    bool axes_as_input;       // true once axes became an op input (>= 18).
  };
  constexpr OpsetConfig kConfigs[] = {
      {11, false},
      {12, false},
      {13, false},
      {18, true},
      {20, true},
  };

  for (const auto& config : kConfigs) {
    OpDef op_def{"ReduceMax", kOnnxDomain, config.opset_version};
    RunReductionTests(op_def, config.axes_as_input, true);
  }
}

} // namespace test
} // namespace onnxruntime

Expand Down
Loading