Commit d2c5e24

Add GlobalMaxPool Gradient (#23502)

### Description
Added gradient computation support for the GlobalMaxPool node.

### Motivation and Context
Improve the training capabilities of ONNX Runtime.

1 parent ded8730 commit d2c5e24

File tree

4 files changed: +43 -0 lines changed
orttraining/orttraining/core/graph/gradient_builder.cc

Lines changed: 18 additions & 0 deletions
@@ -2239,5 +2239,23 @@ IMPLEMENT_GRADIENT_BUILDER(GetAtanGradient) {
   return result;
 }
 
+IMPLEMENT_GRADIENT_BUILDER(GetGlobalMaxPoolGradient) {
+  // For GlobalMaxPool's gradient, a binary mask flags the max elements.
+  // We multiply that mask by the incoming gradient, passing gradients only to the maxima.
+  std::vector<NodeDef> result;
+  result.push_back(NodeDef("Shape", {I(0)}, {IA("X_shape")}));
+  result.push_back(NodeDef("Expand", {O(0), IA("X_shape")}, {IA("expanded_Y")}));
+  result.push_back(NodeDef("Equal", {I(0), IA("expanded_Y")}, {IA("mask")}));
+  result.push_back(NodeDef("Cast",
+                           {IA("mask")},
+                           {IA("mask_cast")},
+                           {MakeAttribute("to", static_cast<int64_t>(IElemType(0)))}));
+
+  result.push_back(NodeDef("Expand", {GO(0), IA("X_shape")}, {IA("expanded_dY")}));
+  result.push_back(NodeDef("Mul", {IA("mask_cast"), IA("expanded_dY")}, {GI(0)}));
+
+  return result;
+}
+
 }  // namespace training
 }  // namespace onnxruntime
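
For intuition, here is a minimal standalone sketch of the arithmetic the NodeDef chain above expresses: the pooled output Y is broadcast back to X's shape, compared element-wise with X to form a 0/1 mask, and the broadcast incoming gradient dY is multiplied by that mask. This is plain illustrative C++, not ONNX Runtime code; the NCHW flattening and the function name are assumptions made for the example.

// Illustrative sketch of the Shape -> Expand -> Equal -> Cast -> Expand -> Mul chain
// built by GetGlobalMaxPoolGradient, written out for an NCHW float tensor.
// dX[n,c,h,w] = (X[n,c,h,w] == Y[n,c]) ? dY[n,c] : 0
#include <cstddef>
#include <vector>

std::vector<float> GlobalMaxPoolGradSketch(const std::vector<float>& X,   // N*C*H*W input
                                           const std::vector<float>& Y,   // N*C pooled maxima
                                           const std::vector<float>& dY,  // N*C incoming gradient
                                           std::size_t N, std::size_t C,
                                           std::size_t H, std::size_t W) {
  std::vector<float> dX(X.size(), 0.0f);
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t c = 0; c < C; ++c) {
      const float max_val = Y[n * C + c];  // the value "expanded_Y" broadcasts over H*W
      const float grad = dY[n * C + c];    // the value "expanded_dY" broadcasts over H*W
      for (std::size_t hw = 0; hw < H * W; ++hw) {
        const std::size_t idx = (n * C + c) * H * W + hw;
        // Equal produces the mask, Cast makes it numeric, Mul applies it to dY.
        dX[idx] = (X[idx] == max_val) ? grad : 0.0f;
      }
    }
  }
  return dX;
}

One consequence of the Equal-based mask, visible in the sketch: if several elements tie for the channel maximum, each of them receives the full incoming gradient.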

orttraining/orttraining/core/graph/gradient_builder.h

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient)
 DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
 DECLARE_GRADIENT_BUILDER(GetResizeGradient)
 DECLARE_GRADIENT_BUILDER(GetAtanGradient)
+DECLARE_GRADIENT_BUILDER(GetGlobalMaxPoolGradient)
 
 DECLARE_GRADIENT_BUILDER(GetExternalGradient)

orttraining/orttraining/core/graph/gradient_builder_registry.cc

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
   REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient);
   REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);
   REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient);
+  REGISTER_GRADIENT_BUILDER("GlobalMaxPool", GetGlobalMaxPoolGradient);
 
   REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
 };
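
Registration is a mapping from the ONNX op type string to the gradient builder. As a rough idea of the pattern behind REGISTER_GRADIENT_BUILDER (a hedged sketch only; the class and type names below are invented for illustration and are not ORT's actual GradientBuilderRegistry), it behaves like a name-to-factory map:

// Generic op-type -> gradient-builder-factory registry; illustrative names only.
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct GradientBuilderSketch {  // stand-in for a gradient builder
  virtual ~GradientBuilderSketch() = default;
};

class GradientRegistrySketch {
 public:
  using Factory = std::function<std::unique_ptr<GradientBuilderSketch>()>;

  void Register(const std::string& op_type, Factory factory) {
    factories_[op_type] = std::move(factory);  // e.g. "GlobalMaxPool" -> its builder
  }

  std::unique_ptr<GradientBuilderSketch> Create(const std::string& op_type) const {
    auto it = factories_.find(op_type);
    return it == factories_.end() ? nullptr : it->second();
  }

 private:
  std::unordered_map<std::string, Factory> factories_;
};

With the new entry in place, code that differentiates a training graph containing a GlobalMaxPool node can resolve a builder for it by op type.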

orttraining/orttraining/test/gradient/gradient_ops_test.cc

Lines changed: 23 additions & 0 deletions
@@ -3356,6 +3356,29 @@ TEST(GradientCheckerTest, ResizeGrad) {
 
 TEST(GradientCheckerTest, AtanGrad) { UnaryOpGradientTest("Atan"); }
 
+TEST(GradientCheckerTest, GlobalMaxPoolGrad) {
+  float max_error;
+  GradientChecker<float, float, float> gradient_checker;
+  OpDef op_def{"GlobalMaxPool", kOnnxDomain, 11};
+  constexpr float error_tolerance = 1e-3f;
+
+  // globalmaxpool
+  {
+    ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {{2, 3, 5, 5}}, {{2, 3, 1, 1}}, &max_error, {},
+                                                           /*check_not_have_gradient*/ true,
+                                                           /*check_not_have_shape_inferencing*/ true));
+    EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
+  }
+
+  // globalmaxpool_precomputed
+  {
+    ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {{2, 1, 3, 3}}, {{2, 1, 1, 1}}, &max_error, {},
+                                                           /*check_not_have_gradient*/ true,
+                                                           /*check_not_have_shape_inferencing*/ true));
+    EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
+  }
+}
+
 }  // namespace test
 }  // namespace onnxruntime
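
Conceptually, the GradientChecker in these tests compares the analytic gradient produced by the new builder against numerical gradients from finite differences and reports the largest discrepancy as max_error. A small self-contained illustration of that idea for a single-channel global max follows; this is not the ORT GradientChecker, and the input values, epsilon, and printout are assumptions made for the example.

// Conceptual check: central finite differences vs. the mask-based analytic gradient
// for y = max(x). Illustrative only; not the ORT GradientChecker.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {0.1f, 0.7f, -0.3f, 0.4f};  // one channel, 2x2 spatial
  const float dy = 1.0f;                             // upstream gradient for the pooled output
  const float eps = 1e-3f;

  const float y = *std::max_element(x.begin(), x.end());
  float max_error = 0.0f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    // Analytic gradient from the mask rule: dy where x[i] is the max, else 0.
    const float analytic = (x[i] == y) ? dy : 0.0f;

    // Numerical gradient of y = max(x) by central differences.
    std::vector<float> xp = x, xm = x;
    xp[i] += eps;
    xm[i] -= eps;
    const float numeric = (*std::max_element(xp.begin(), xp.end()) -
                           *std::max_element(xm.begin(), xm.end())) / (2.0f * eps);

    max_error = std::max(max_error, std::fabs(analytic - numeric));
  }
  std::printf("max_error = %g\n", max_error);  // expected to be tiny away from ties
  return 0;
}

The two test cases above do the analogous comparison at graph level for inputs of shape {2, 3, 5, 5} and {2, 1, 3, 3}, asserting that max_error stays below the 1e-3 tolerance.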

0 commit comments
