Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

arith-to-cggi conversion pass #1248

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 232 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h"

#include "lib/Dialect/CGGI/IR/CGGIDialect.h"
#include "lib/Dialect/CGGI/IR/CGGIOps.h"
#include "lib/Dialect/LWE/IR/LWEOps.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Utils/ConversionUtils.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir::heir::arith {

#define GEN_PASS_DEF_ARITHTOCGGI
#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h.inc"

static lwe::LWECiphertextType convertArithToCGGIType(IntegerType type,
MLIRContext *ctx) {
return lwe::LWECiphertextType::get(ctx,
lwe::UnspecifiedBitFieldEncodingAttr::get(
ctx, type.getIntOrFloatBitWidth()),
lwe::LWEParamsAttr());
;
}

static Type convertArithLikeToCGGIType(ShapedType type, MLIRContext *ctx) {
if (auto arithType = llvm::dyn_cast<IntegerType>(type.getElementType())) {
return type.cloneWith(type.getShape(),
convertArithToCGGIType(arithType, ctx));
}
return type;
}

// Remove this class if no type conversions are necessary
class ArithToCGGITypeConverter : public TypeConverter {
public:
ArithToCGGITypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });

// Convert Integer types to LWE ciphertext types
addConversion([ctx](IntegerType type) -> Type {
return convertArithToCGGIType(type, ctx);
});

addConversion([ctx](ShapedType type) -> Type {
return convertArithLikeToCGGIType(type, ctx);
});
}
};

class SecretTypeConverter : public TypeConverter {
public:
SecretTypeConverter(MLIRContext *ctx, int minBitWidth)
: minBitWidth(minBitWidth) {
addConversion([](Type type) { return type; });
}

int minBitWidth;
};

struct ConvertConstantOp : public OpConversionPattern<mlir::arith::ConstantOp> {
ConvertConstantOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ConstantOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (isa<IndexType>(op.getValue().getType())) {
return failure();
}
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto intValue = cast<IntegerAttr>(op.getValue()).getValue().getSExtValue();
auto inputValue = mlir::IntegerAttr::get(op.getType(), intValue);

auto encoding = lwe::UnspecifiedBitFieldEncodingAttr::get(
op->getContext(), op.getValue().getType().getIntOrFloatBitWidth());
auto lweType = lwe::LWECiphertextType::get(op->getContext(), encoding,
lwe::LWEParamsAttr());

auto encrypt = b.create<cggi::CreateTrivialOp>(lweType, inputValue);

rewriter.replaceOp(op, encrypt);
return success();
}
};

struct ConvertTruncIOp : public OpConversionPattern<mlir::arith::TruncIOp> {
ConvertTruncIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::TruncIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::TruncIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto outType = convertArithToCGGIType(
cast<IntegerType>(op.getResult().getType()), op->getContext());
auto castOp = b.create<cggi::CastOp>(op.getLoc(), outType, adaptor.getIn());

rewriter.replaceOp(op, castOp);
return success();
}
};

struct ConvertExtUIOp : public OpConversionPattern<mlir::arith::ExtUIOp> {
ConvertExtUIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ExtUIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::ExtUIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto outType = convertArithToCGGIType(
cast<IntegerType>(op.getResult().getType()), op->getContext());
auto castOp = b.create<cggi::CastOp>(op.getLoc(), outType, adaptor.getIn());

rewriter.replaceOp(op, castOp);
return success();
}
};

struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
ConvertShRUIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ShRUIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::ShRUIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cteShiftSizeOp = op.getRhs().getDefiningOp<mlir::arith::ConstantOp>();

if (cteShiftSizeOp) {
auto outputType = adaptor.getLhs().getType();

auto shiftAmount = cast<IntegerAttr>(cteShiftSizeOp.getValue())
.getValue()
.getSExtValue();

auto inputValue =
mlir::IntegerAttr::get(rewriter.getI8Type(), (int8_t)shiftAmount);
auto cteOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI8Type(), inputValue);

auto shiftOp =
b.create<cggi::ShiftRightOp>(outputType, adaptor.getLhs(), cteOp);
rewriter.replaceOp(op, shiftOp);

return success();
}

cteShiftSizeOp = op.getLhs().getDefiningOp<mlir::arith::ConstantOp>();

auto outputType = adaptor.getRhs().getType();

auto shiftAmount =
cast<IntegerAttr>(cteShiftSizeOp.getValue()).getValue().getSExtValue();

auto inputValue = mlir::IntegerAttr::get(rewriter.getI8Type(), shiftAmount);
auto cteOp = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI8Type(), inputValue);

auto shiftOp =
b.create<cggi::ShiftRightOp>(outputType, adaptor.getLhs(), cteOp);
rewriter.replaceOp(op, shiftOp);
rewriter.replaceOp(op.getLhs().getDefiningOp(), cteOp);

return success();
}
};

struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto *module = getOperation();
ArithToCGGITypeConverter typeConverter(context);

RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<cggi::CGGIDialect>();
target.addIllegalDialect<mlir::arith::ArithDialect>();

target.addDynamicallyLegalOp<mlir::arith::ConstantOp>(
[](mlir::arith::ConstantOp op) {
// Allow use of constant if it is used to denote the size of a shift
bool usedByShift = llvm::any_of(op->getUsers(), [&](Operation *user) {
return isa<cggi::ShiftRightOp>(user);
});
return (isa<IndexType>(op.getValue().getType()) || (usedByShift));
});

target.addDynamicallyLegalOp<
memref::AllocOp, memref::DeallocOp, memref::StoreOp, memref::SubViewOp,
memref::CopyOp, tensor::FromElementsOp, tensor::ExtractOp,
affine::AffineStoreOp, affine::AffineLoadOp>([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});

patterns.add<
ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertShRUIOp,
ConvertBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
ConvertAny<memref::DeallocOp>, ConvertAny<memref::StoreOp>,
ConvertAny<memref::SubViewOp>, ConvertAny<memref::CopyOp>,
ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::ExtractOp>,
ConvertAny<affine::AffineStoreOp>, ConvertAny<affine::AffineLoadOp> >(
typeConverter, context);

addStructuralConversionPatterns(typeConverter, patterns, target);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
return signalPassFailure();
}
}
};

} // namespace mlir::heir::arith
16 changes: 16 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_ARITHTOCGGI_H_
#define LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_ARITHTOCGGI_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir::heir::arith {

#define GEN_PASS_DECL
#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h.inc"

} // namespace mlir::heir::arith

#endif // LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_ARITHTOCGGI_H_
14 changes: 14 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_TD_
#define LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_TD_

include "mlir/Pass/PassBase.td"

def ArithToCGGI : Pass<"arith-to-cggi"> {
let summary = "Lower `arith` to `cggi` dialect.";
let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::heir::cggi::CGGIDialect",
];
}

#endif // LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_ARITHTOCGGI_TD_
44 changes: 44 additions & 0 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "ArithToCGGI",
srcs = ["ArithToCGGI.cpp"],
hdrs = ["ArithToCGGI.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/CGGI/IR:Dialect",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=ArithToCGGI",
],
"ArithToCGGI.h.inc",
),
(
["-gen-pass-doc"],
"ArithToCGGI.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ArithToCGGI.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
5 changes: 2 additions & 3 deletions lib/Dialect/Arith/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ cc_library(
deps = [
":QuarterWideInt",
":pass_inc_gen",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
],
)

Expand All @@ -21,16 +23,13 @@ cc_library(
deps = [
":pass_inc_gen",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
],
)

Expand Down
30 changes: 30 additions & 0 deletions lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,36 @@ struct ConvertTrivialEncryptOp
}
};

struct ConvertTrivialOp : public OpConversionPattern<cggi::CreateTrivialOp> {
ConvertTrivialOp(mlir::MLIRContext *context)
: OpConversionPattern<cggi::CreateTrivialOp>(context, /*benefit=*/2) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
cggi::CreateTrivialOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result = getContextualServerKey(op.getOperation());
if (failed(result)) return result;

Value serverKey = result.value();

auto intValue = op.getValue().getValue().getSExtValue();
auto inputValue = mlir::IntegerAttr::get(op.getValue().getType(), intValue);
auto constantWidth = op.getValue().getValue().getBitWidth();

auto cteOp = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerType(constantWidth), inputValue);

auto outputType = encrytpedUIntTypeFromWidth(getContext(), constantWidth);

auto createTrivialOp = rewriter.create<tfhe_rust::CreateTrivialOp>(
op.getLoc(), outputType, serverKey, cteOp);
rewriter.replaceOp(op, createTrivialOp);
return success();
}
};

struct ConvertEncodeOp : public OpConversionPattern<lwe::EncodeOp> {
ConvertEncodeOp(mlir::MLIRContext *context)
: OpConversionPattern<lwe::EncodeOp>(context) {}
Expand Down
Loading
Loading