Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2239,5 +2239,23 @@ IMPLEMENT_GRADIENT_BUILDER(GetAtanGradient) {
return result;
}

IMPLEMENT_GRADIENT_BUILDER(GetGlobalMaxPoolGradient) {
// For GlobalMaxPool's gradient, a binary mask flags max elements.
// We multiply that mask by the incoming gradient, passing gradients only to maxima.
std::vector<NodeDef> result;
result.push_back(NodeDef("Shape", {I(0)}, {IA("X_shape")}));
result.push_back(NodeDef("Expand", {O(0), IA("X_shape")}, {IA("expanded_Y")}));
result.push_back(NodeDef("Equal", {I(0), IA("expanded_Y")}, {IA("mask")}));
result.push_back(NodeDef("Cast",
{IA("mask")},
{IA("mask_cast")},
{MakeAttribute("to", static_cast<int64_t>(IElemType(0)))}));

result.push_back(NodeDef("Expand", {GO(0), IA("X_shape")}, {IA("expanded_dY")}));
result.push_back(NodeDef("Mul", {IA("mask_cast"), IA("expanded_dY")}, {GI(0)}));

return result;
}

} // namespace training
} // namespace onnxruntime
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient)
DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
DECLARE_GRADIENT_BUILDER(GetResizeGradient)
DECLARE_GRADIENT_BUILDER(GetAtanGradient)
DECLARE_GRADIENT_BUILDER(GetGlobalMaxPoolGradient)

DECLARE_GRADIENT_BUILDER(GetExternalGradient)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient);
REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);
REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient);
REGISTER_GRADIENT_BUILDER("GlobalMaxPool", GetGlobalMaxPoolGradient);

REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
};
Expand Down
23 changes: 23 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3356,6 +3356,29 @@ TEST(GradientCheckerTest, ResizeGrad) {

TEST(GradientCheckerTest, AtanGrad) { UnaryOpGradientTest("Atan"); }

TEST(GradientCheckerTest, GlobalMaxPoolGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{"GlobalMaxPool", kOnnxDomain, 11};
constexpr float error_tolerance = 1e-3f;

// globalmaxpool
{
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {{2, 3, 5, 5}}, {{2, 3, 1, 1}}, &max_error, {},
/*check_not_have_gradient*/ true,
/*check_not_have_shape_inferencing*/ true));
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
}

// globalmaxpool_precomputed
{
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {{2, 1, 3, 3}}, {{2, 1, 1, 1}}, &max_error, {},
/*check_not_have_gradient*/ true,
/*check_not_have_shape_inferencing*/ true));
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
}
}

} // namespace test
} // namespace onnxruntime

Expand Down
Loading