-
Notifications
You must be signed in to change notification settings - Fork 45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added all possible elementwise binary ops for #316. #344
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,6 +53,7 @@ enum EWBinOpId : int { | |
NOT_EQUAL, | ||
OR, | ||
POWER, | ||
RSHIFT, | ||
SUBTRACT, | ||
TRUE_DIVIDE, | ||
XOR, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,9 +24,11 @@ | |
#include <mlir/Dialect/Func/Transforms/FuncConversions.h> | ||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h> | ||
#include <mlir/Dialect/Linalg/IR/Linalg.h> | ||
#include <mlir/Dialect/Math/IR/Math.h> | ||
#include <mlir/Dialect/MemRef/IR/MemRef.h> | ||
#include <mlir/Dialect/Shape/IR/Shape.h> | ||
#include <mlir/Dialect/Tensor/IR/Tensor.h> | ||
#include <mlir/Dialect/Tosa/IR/TosaOps.h> | ||
#include <mlir/IR/BuiltinOps.h> | ||
#include <mlir/Pass/Pass.h> | ||
|
||
|
@@ -293,19 +295,63 @@ static BodyType buildTrivial(::mlir::Type typ) { | |
}; | ||
} | ||
|
||
// Builder for TOSA body. | ||
// The builder function takes an extra agrument ::mlir::Type | ||
// Many TOSA functions take two arguments and do the | ||
// same operation on both arguments. | ||
Comment on lines
+300
to
+301
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the functionality we need? Why is TOSA any better than math and arith? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @fschlimb I found those logical ops only in TOSA, is there any other alternative? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so. An alternative is of course to write them manually. |
||
// TODO: | ||
// 1. Find a way to merge this function with trivialBuilder(). | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why would we need buildTrivial if we use TOSA? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @fschlimb That's my thought too, should we just move to TOSA completely? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When I investigated I realized that TOSA does not have unsigned types. I could not figure out if that's a gap or a design decision. If we can assume that TOSA can support unsigned ints then the only reason for not using TOSA might be compile-time. But I tend to think that's probably acceptable at this point. |
||
// 2. Detect if there is rhs2, if yes, build accordingly. | ||
// 3. We need to introduce a Boolean type. | ||
template <typename IOP, typename FOP = void> | ||
static BodyType buildTosa(::mlir::Type typ) { | ||
return [typ](mlir::OpBuilder &builder, ::mlir::Location loc, | ||
::mlir::ValueRange args) -> void { | ||
auto lhs_typ = args[0].getType(); | ||
if (lhs_typ.isIntOrIndex()) { | ||
if constexpr (!std::is_same_v<IOP, void>) { | ||
auto lhs = doSignCast(builder, loc, args[0]); | ||
auto rhs1 = doSignCast(builder, loc, args[1]); | ||
// auto rhs2 = doSignCast(builder, loc, args[2]); // sometimes there can | ||
// be two params | ||
yield(builder, loc, typ, | ||
builder.create<IOP>(loc, typ, lhs, rhs1).getResult()); | ||
return; | ||
} else | ||
assert("Found integer type but binary op not defined for integers" == | ||
nullptr); | ||
} else if (lhs_typ.isIntOrIndexOrFloat()) { | ||
if constexpr (!std::is_same_v<FOP, void>) { | ||
yield(builder, loc, typ, | ||
builder.create<FOP>(loc, typ, args[0], args[1]).getResult()); | ||
return; | ||
} else | ||
assert("Found float type but binary op not defined for floats" == | ||
nullptr); | ||
} else { | ||
assert("Only integers and floats supported for binary ops" == nullptr); | ||
} | ||
}; | ||
} | ||
|
||
// get a body builder for given binary operation and result type | ||
// we accept a result type to insert a cast after the operation if needed | ||
static BodyType getBodyBuilder(::imex::ptensor::EWBinOpId bop, | ||
::mlir::Type typ) { | ||
switch (bop) { | ||
case ptensor::ADD: | ||
return buildTrivial<mlir::arith::AddIOp, mlir::arith::AddFOp>(typ); | ||
// case ptensor::ATAN2] = | ||
case ptensor::ATAN2: | ||
return buildTrivial<void, mlir::math::Atan2Op>(typ); | ||
case ptensor::FLOOR_DIVIDE: | ||
return buildTrivial<mlir::arith::FloorDivSIOp>(typ); | ||
// case ptensor::LOGADDEXP] = | ||
// case ptensor::LSHIFT] = | ||
// case ptensor::MATMUL] = | ||
case ptensor::LSHIFT: | ||
return buildTosa<mlir::tosa::LogicalLeftShiftOp, void>(typ); | ||
case ptensor::RSHIFT: | ||
return buildTosa<mlir::tosa::LogicalRightShiftOp, void>(typ); | ||
case ptensor::MATMUL: | ||
return buildTosa<void, mlir::tosa::MatMulOp>(typ); | ||
case ptensor::MAXIMUM: | ||
return buildTrivial<mlir::arith::MaxSIOp, mlir::arith::MaxFOp>(typ); | ||
case ptensor::MINIMUM: | ||
|
@@ -314,24 +360,33 @@ static BodyType getBodyBuilder(::imex::ptensor::EWBinOpId bop, | |
return buildTrivial<mlir::arith::RemSIOp, mlir::arith::RemFOp>(typ); | ||
case ptensor::MULTIPLY: | ||
return buildTrivial<mlir::arith::MulIOp, mlir::arith::MulFOp>(typ); | ||
// case ptensor::POW] = | ||
case ptensor::POWER: | ||
return buildTrivial<mlir::math::IPowIOp, mlir::math::PowFOp>(typ); | ||
case ptensor::SUBTRACT: | ||
return buildTrivial<mlir::arith::SubIOp, mlir::arith::SubFOp>(typ); | ||
// case ptensor::TRUE_DIVIDE] = | ||
// case ptensor::BITWISE_AND] = | ||
case ptensor::BITWISE_AND: | ||
return buildTosa<mlir::tosa::BitwiseAndOp, void>(typ); | ||
// case ptensor::BITWISE_LEFT_SHIFT] = | ||
// case ptensor::BITWISE_OR] = | ||
case ptensor::BITWISE_OR: | ||
return buildTosa<mlir::tosa::BitwiseOrOp, void>(typ); | ||
// case ptensor::BITWISE_RIGHT_SHIFT] = | ||
// case ptensor::BITWISE_XOR] = | ||
|
||
// case ptensor::EQUAL] = | ||
// case ptensor::GREATER] = | ||
// case ptensor::GREATER_EQUAL] = | ||
case ptensor::BITWISE_XOR: | ||
return buildTosa<mlir::tosa::BitwiseXorOp, void>(typ); | ||
case ptensor::EQUAL: | ||
return buildTosa<mlir::tosa::EqualOp, void>(typ); | ||
case ptensor::GREATER: | ||
return buildTosa<mlir::tosa::GreaterOp, void>(typ); | ||
case ptensor::GREATER_EQUAL: | ||
return buildTosa<mlir::tosa::GreaterEqualOp, void>(typ); | ||
// case ptensor::LESS] = | ||
// case ptensor::LESS_EQUAL] = | ||
// case ptensor::LOGICAL_AND] = | ||
// case ptensor::LOGICAL_OR] = | ||
// case ptensor::LOGICAL_XOR] = | ||
case ptensor::LOGICAL_AND: | ||
return buildTosa<mlir::tosa::LogicalAndOp, void>(typ); | ||
case ptensor::LOGICAL_OR: | ||
return buildTosa<mlir::tosa::LogicalOrOp, void>(typ); | ||
case ptensor::LOGICAL_XOR: | ||
return buildTosa<mlir::tosa::LogicalXorOp, void>(typ); | ||
// case ptensor::NOT_EQUAL] = | ||
default: | ||
assert("unsupported elementwise binary operation" == nullptr); | ||
|
@@ -592,6 +647,8 @@ struct ConvertPTensorToLinalgPass | |
target.addLegalDialect<::mlir::AffineDialect>(); | ||
target.addLegalDialect<::mlir::tensor::TensorDialect>(); | ||
target.addLegalDialect<::mlir::arith::ArithmeticDialect>(); | ||
target.addLegalDialect<::mlir::math::MathDialect>(); | ||
target.addLegalDialect<::mlir::tosa::TosaDialect>(); | ||
target.addLegalDialect<::mlir::shape::ShapeDialect>(); | ||
target.addLegalOp<::mlir::UnrealizedConversionCastOp>(); // FIXME | ||
target.addDynamicallyLegalOp<::mlir::func::FuncOp>( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess LSHIFT and RSHIFT are duplicates of BITWISE_LEFT_SHIFT and BITWISE_RIGHT_SHIFT.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fschlimb Okay, we can just duplicate them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, I actually thought we should not. Any reason for doing so?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, actually there was
LSHIFT
but noRSHIFT
, so I added it. Which ones we should keep?LSHIFT/RSHIFT
orBITWISE_LEFT_SHIFT/BITWISE_RIGHT_SHIFT
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I prefer BITWISE_*_SHIFT.