Skip to content

Commit

Permalink
[mlir][complex] Support fast math flag for complex.sign op (#87148)
Browse files Browse the repository at this point in the history
We are going to support the fast math flag given in `complex.sign` op in
the conversion to standard dialect.

See:
https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
  • Loading branch information
Lewuathe authored Apr 6, 2024
1 parent 0ba3e96 commit a522dbb
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
7 changes: 4 additions & 3 deletions mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,7 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
Expand All @@ -928,9 +929,9 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
Value imagIsZero =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
Value realSign = b.create<arith::DivFOp>(real, abs);
Value imagSign = b.create<arith::DivFOp>(imag, abs);
auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
adaptor.getComplex(), sign);
Expand Down
43 changes: 43 additions & 0 deletions mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1880,3 +1880,46 @@ func.func @complex_sin_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]] fastmath<nnan,contract>
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]]

// -----

// CHECK-LABEL: func @complex_sign_with_fmf
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sign_with_fmf(%arg: complex<f32>) -> complex<f32> {
%sign = complex.sign %arg fastmath<nnan,contract> : complex<f32>
return %sign : complex<f32>
}

// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] fastmath<nnan,contract> : f32
// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] fastmath<nnan,contract> : f32
// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] fastmath<nnan,contract> : f32
// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] fastmath<nnan,contract> : f32
// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] fastmath<nnan,contract> : f32
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
// CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>

0 comments on commit a522dbb

Please sign in to comment.