Skip to content

Commit 237bbe9

Browse files
committed
Reverse-mode derivatives for math.atan, math.absf, arith.select
1 parent 6e06a22 commit 237bbe9

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ def : MLIRDerivative<"arith", "DivFOp", (Op $x, $y),
3131
],
3232
(CheckedDivF (SubF (SelectIfActive $x, (MulF (Shadow $x), $y), (ConstantFP<"0","arith", "ConstantOp"> $x)), (SelectIfActive $y, (MulF (Shadow $y), $x), (ConstantFP<"0","arith","ConstantOp"> $y))), (MulF $y, $y))
3333
>;
34+
def : MLIRDerivative<"arith", "SelectOp", (Op $pred, $x, $y),
35+
[
36+
(AssertingInactiveArg),
37+
(Arith_Select $pred, (DiffeRet), (ConstantFP<"0","arith","ConstantOp"> $x)),
38+
(Arith_Select $pred, (ConstantFP<"0","arith","ConstantOp"> $x), (DiffeRet)),
39+
]
40+
>;
3441

3542
def ExtF : ArithInst<"ExtFOp">;
3643
def TruncF : ArithInst<"TruncFOp">;

enzyme/Enzyme/MLIR/Implementations/Common.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def CmpF : ArithInst<"CmpFOp">;
155155
def Arith_Select : ArithInst<"SelectOp">;
156156

157157
def Arith_OEQ : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "arith::CmpFPredicate::OEQ">;
158+
def Arith_OGE : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "arith::CmpFPredicate::OGE">;
158159

159160
def CheckedMulF : SubRoutine<(Op $diffret, $x),
160161
(

enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,14 @@ def : MLIRDerivative<"math", "SqrtOp", (Op $x),
2020
(Arith_Select (CmpF (Arith_OEQ), $x, (ConstantFP<"0","arith","ConstantOp"> $x):$zero), $zero, (DivF (DiffeRet), (MulF (ConstantFP<"2.0","arith","ConstantOp"> $x), (SqrtF $x))))
2121
]
2222
>;
23+
def : MLIRDerivative<"math", "AtanOp", (Op $x),
24+
[
25+
(CheckedMulF (DiffeRet), (DivF (ConstantFP<"1.0","arith","ConstantOp"> $x):$one, (AddF (MulF $x, $x), $one)))
26+
]
27+
>;
28+
def : MLIRDerivative<"math", "AbsFOp", (Op $x),
29+
[
30+
// TODO: handle complex
31+
(Arith_Select (CmpF (Arith_OGE), $x, (ConstantFP<"0","arith","ConstantOp"> $x)), (DiffeRet), (NegF (DiffeRet)))
32+
]
33+
>;

0 commit comments

Comments
 (0)