diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 76fe0ee91d4c6..43835f07c4b40 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -2227,5 +2227,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetResizeGradient) { SrcNodeAttributes())}; } +IMPLEMENT_GRADIENT_BUILDER(GetAtanGradient) { + // dl/dx = dl/dy * (1/(1+x^2)) + NodeDef one_const_node = OneConstantNode(IElemType(0)); + ArgDef one = one_const_node.output_args[0]; + std::vector result; + result.push_back(one_const_node); + result.push_back(NodeDef("Mul", {I(0), I(0)}, {IA("Square_I0")})); + result.push_back(NodeDef("Add", {IA("Square_I0"), one}, {IA("One_Plus_Square_I0")})); + result.push_back(NodeDef("Div", {GO(0), IA("One_Plus_Square_I0")}, {GI(0)})); + return result; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 92bfae9cd83a4..2b40754b6261f 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -93,6 +93,7 @@ DECLARE_GRADIENT_BUILDER(GetReciprocalGradient) DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient) DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient) DECLARE_GRADIENT_BUILDER(GetResizeGradient) +DECLARE_GRADIENT_BUILDER(GetAtanGradient) DECLARE_GRADIENT_BUILDER(GetExternalGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index ea56be9e6dfa3..9c9884c5d3865 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -125,6 +125,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient); REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient); REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient); + REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient); REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient); }; diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 94ca96c68f2ce..b81a08e23e3cf 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -3352,6 +3352,8 @@ TEST(GradientCheckerTest, ResizeGrad) { #endif // USE_CUDA +TEST(GradientCheckerTest, AtanGrad) { UnaryOpGradientTest("Atan"); } + } // namespace test } // namespace onnxruntime