From 7311556e83b0a8ca6e04fb6a04b3c12be7946f7f Mon Sep 17 00:00:00 2001 From: WoutLegiest Date: Wed, 25 Dec 2024 01:30:54 +0000 Subject: [PATCH] Quarter Wide Int Arith Pass --- .../Conversions/ArithToCGGI/ArithToCGGI.cpp | 232 ++++++++++++++++++ .../Conversions/ArithToCGGI/ArithToCGGI.h | 16 ++ .../Conversions/ArithToCGGI/ArithToCGGI.td | 14 ++ .../Arith/Conversions/ArithToCGGI/BUILD | 44 ++++ lib/Dialect/Arith/Transforms/BUILD | 5 +- .../CGGIToTfheRust/CGGIToTfheRust.cpp | 30 +++ lib/Dialect/CGGI/IR/CGGIOps.td | 28 ++- lib/Dialect/LWE/IR/LWETypes.td | 5 + .../Arith/Conversions/ArithToCGGI/BUILD | 10 + .../ArithToCGGI/arith-to-cggi.mlir | 71 ++++++ tools/BUILD | 1 + tools/heir-opt.cpp | 4 +- 12 files changed, 455 insertions(+), 5 deletions(-) create mode 100644 lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp create mode 100644 lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h create mode 100644 lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.td create mode 100644 lib/Dialect/Arith/Conversions/ArithToCGGI/BUILD create mode 100644 tests/Dialect/Arith/Conversions/ArithToCGGI/BUILD create mode 100644 tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp new file mode 100644 index 000000000..52ff51c34 --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp @@ -0,0 +1,232 @@ +#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h" + +#include "lib/Dialect/CGGI/IR/CGGIDialect.h" +#include "lib/Dialect/CGGI/IR/CGGIOps.h" +#include "lib/Dialect/LWE/IR/LWEOps.h" +#include "lib/Dialect/LWE/IR/LWETypes.h" +#include "lib/Utils/ConversionUtils.h" +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::heir::arith { + +#define GEN_PASS_DEF_ARITHTOCGGI +#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h.inc" + +static lwe::LWECiphertextType convertArithToCGGIType(IntegerType type, + MLIRContext *ctx) { + return lwe::LWECiphertextType::get(ctx, + lwe::UnspecifiedBitFieldEncodingAttr::get( + ctx, type.getIntOrFloatBitWidth()), + lwe::LWEParamsAttr()); + ; +} + +static Type convertArithLikeToCGGIType(ShapedType type, MLIRContext *ctx) { + if (auto arithType = llvm::dyn_cast(type.getElementType())) { + return type.cloneWith(type.getShape(), + convertArithToCGGIType(arithType, ctx)); + } + return type; +} + +// Remove this class if no type conversions are necessary +class ArithToCGGITypeConverter : public TypeConverter { + public: + ArithToCGGITypeConverter(MLIRContext *ctx) { + addConversion([](Type type) { return type; }); + + // Convert Integer types to LWE ciphertext types + addConversion([ctx](IntegerType type) -> Type { + return convertArithToCGGIType(type, ctx); + }); + + addConversion([ctx](ShapedType type) -> Type { + return convertArithLikeToCGGIType(type, ctx); + }); + } +}; + +class SecretTypeConverter : public TypeConverter { + public: + SecretTypeConverter(MLIRContext *ctx, int minBitWidth) + : minBitWidth(minBitWidth) { + addConversion([](Type type) { return type; }); + } + + int minBitWidth; +}; + +struct ConvertConstantOp : public OpConversionPattern { + ConvertConstantOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (isa(op.getValue().getType())) { + return failure(); + } + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto intValue = cast(op.getValue()).getValue().getSExtValue(); + auto inputValue = mlir::IntegerAttr::get(op.getType(), intValue); + + auto encoding = lwe::UnspecifiedBitFieldEncodingAttr::get( + op->getContext(), op.getValue().getType().getIntOrFloatBitWidth()); + auto lweType = lwe::LWECiphertextType::get(op->getContext(), encoding, + lwe::LWEParamsAttr()); + + auto encrypt = b.create(lweType, inputValue); + + rewriter.replaceOp(op, encrypt); + return success(); + } +}; + +struct ConvertTruncIOp : public OpConversionPattern { + ConvertTruncIOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto outType = convertArithToCGGIType( + cast(op.getResult().getType()), op->getContext()); + auto castOp = b.create(op.getLoc(), outType, adaptor.getIn()); + + rewriter.replaceOp(op, castOp); + return success(); + } +}; + +struct ConvertExtUIOp : public OpConversionPattern { + ConvertExtUIOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::arith::ExtUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto outType = convertArithToCGGIType( + cast(op.getResult().getType()), op->getContext()); + auto castOp = b.create(op.getLoc(), outType, adaptor.getIn()); + + rewriter.replaceOp(op, castOp); + return success(); + } +}; + +struct ConvertShRUIOp : public OpConversionPattern { + ConvertShRUIOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::arith::ShRUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto cteShiftSizeOp = op.getRhs().getDefiningOp(); + + if (cteShiftSizeOp) { + auto outputType = adaptor.getLhs().getType(); + + auto shiftAmount = cast(cteShiftSizeOp.getValue()) + .getValue() + .getSExtValue(); + + auto inputValue = + mlir::IntegerAttr::get(rewriter.getI8Type(), (int8_t)shiftAmount); + auto cteOp = rewriter.create( + op.getLoc(), rewriter.getI8Type(), inputValue); + + auto shiftOp = + b.create(outputType, adaptor.getLhs(), cteOp); + rewriter.replaceOp(op, shiftOp); + + return success(); + } + + cteShiftSizeOp = op.getLhs().getDefiningOp(); + + auto outputType = adaptor.getRhs().getType(); + + auto shiftAmount = + cast(cteShiftSizeOp.getValue()).getValue().getSExtValue(); + + auto inputValue = mlir::IntegerAttr::get(rewriter.getI8Type(), shiftAmount); + auto cteOp = rewriter.create( + op.getLoc(), rewriter.getI8Type(), inputValue); + + auto shiftOp = + b.create(outputType, adaptor.getLhs(), cteOp); + rewriter.replaceOp(op, shiftOp); + rewriter.replaceOp(op.getLhs().getDefiningOp(), cteOp); + + return success(); + } +}; + +struct ArithToCGGI : public impl::ArithToCGGIBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto *module = getOperation(); + ArithToCGGITypeConverter typeConverter(context); + + RewritePatternSet patterns(context); + ConversionTarget target(*context); + target.addLegalDialect(); + target.addIllegalDialect(); + + target.addDynamicallyLegalOp( + [](mlir::arith::ConstantOp op) { + // Allow use of constant if it is used to denote the size of a shift + bool usedByShift = llvm::any_of(op->getUsers(), [&](Operation *user) { + return isa(user); + }); + return (isa(op.getValue().getType()) || (usedByShift)); + }); + + target.addDynamicallyLegalOp< + memref::AllocOp, memref::DeallocOp, memref::StoreOp, memref::SubViewOp, + memref::CopyOp, tensor::FromElementsOp, tensor::ExtractOp, + affine::AffineStoreOp, affine::AffineLoadOp>([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + patterns.add< + ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertShRUIOp, + ConvertBinOp, + ConvertBinOp, + ConvertBinOp, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny >( + typeConverter, context); + + addStructuralConversionPatterns(typeConverter, patterns, target); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace mlir::heir::arith diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h new file mode 100644 index 000000000..2197eed72 --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h @@ -0,0 +1,16 @@ +#ifndef LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_ARITHTOCGGI_H_ +#define LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_ARITHTOCGGI_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir::heir::arith { + +#define GEN_PASS_DECL +#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h.inc" + +} // namespace mlir::heir::arith + +#endif // LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_ARITHTOCGGI_H_ diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.td b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.td new file mode 100644 index 000000000..876e2c589 --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.td @@ -0,0 +1,14 @@ +#ifndef LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_TD_ +#define LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_TD_ + +include "mlir/Pass/PassBase.td" + +def ArithToCGGI : Pass<"arith-to-cggi"> { + let summary = "Lower `arith` to `cggi` dialect."; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::heir::cggi::CGGIDialect", + ]; +} + +#endif // LIB_DIALECT_ARITH_CONVERSIONS_ARITHTOCGGI_ARITHTOCGGI_TD_ diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGI/BUILD b/lib/Dialect/Arith/Conversions/ArithToCGGI/BUILD new file mode 100644 index 000000000..62cf1d599 --- /dev/null +++ b/lib/Dialect/Arith/Conversions/ArithToCGGI/BUILD @@ -0,0 +1,44 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "ArithToCGGI", + srcs = ["ArithToCGGI.cpp"], + hdrs = ["ArithToCGGI.h"], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/CGGI/IR:Dialect", + "@heir//lib/Utils:ConversionUtils", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=ArithToCGGI", + ], + "ArithToCGGI.h.inc", + ), + ( + ["-gen-pass-doc"], + "ArithToCGGI.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ArithToCGGI.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/lib/Dialect/Arith/Transforms/BUILD b/lib/Dialect/Arith/Transforms/BUILD index f74fc9c22..a9b35b87f 100644 --- a/lib/Dialect/Arith/Transforms/BUILD +++ b/lib/Dialect/Arith/Transforms/BUILD @@ -11,6 +11,8 @@ cc_library( deps = [ ":QuarterWideInt", ":pass_inc_gen", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", ], ) @@ -21,16 +23,13 @@ cc_library( deps = [ ":pass_inc_gen", "@heir//lib/Utils:ConversionUtils", - "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorDialect", ], ) diff --git a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp index d9130b84c..7f16fed4a 100644 --- a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp +++ b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp @@ -364,6 +364,36 @@ struct ConvertTrivialEncryptOp } }; +struct ConvertTrivialOp : public OpConversionPattern { + ConvertTrivialOp(mlir::MLIRContext *context) + : OpConversionPattern(context, /*benefit=*/2) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + cggi::CreateTrivialOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualServerKey(op.getOperation()); + if (failed(result)) return result; + + Value serverKey = result.value(); + + auto intValue = op.getValue().getValue().getSExtValue(); + auto inputValue = mlir::IntegerAttr::get(op.getValue().getType(), intValue); + auto constantWidth = op.getValue().getValue().getBitWidth(); + + auto cteOp = rewriter.create( + op.getLoc(), rewriter.getIntegerType(constantWidth), inputValue); + + auto outputType = encrytpedUIntTypeFromWidth(getContext(), constantWidth); + + auto createTrivialOp = rewriter.create( + op.getLoc(), outputType, serverKey, cteOp); + rewriter.replaceOp(op, createTrivialOp); + return success(); + } +}; + struct ConvertEncodeOp : public OpConversionPattern { ConvertEncodeOp(mlir::MLIRContext *context) : OpConversionPattern(context) {} diff --git a/lib/Dialect/CGGI/IR/CGGIOps.td b/lib/Dialect/CGGI/IR/CGGIOps.td index 777ee7287..537670903 100644 --- a/lib/Dialect/CGGI/IR/CGGIOps.td +++ b/lib/Dialect/CGGI/IR/CGGIOps.td @@ -59,7 +59,6 @@ def CGGI_MulOp : CGGI_BinaryOp<"mul"> { }]; } - def CGGI_NotOp : CGGI_Op<"not", [ Pure, Involution, @@ -312,4 +311,31 @@ def CGGI_ShiftLeftOp : CGGI_Op<"shl", [ let summary = "Arithmetic shift to left of a ciphertext by an integer. Note this operations to mirror the TFHE-rs implmementation."; } +// FIXME: Two options: +// 1. Allow arith.constant and use an encryption op to bring it to the ciphertext space. +// 2. Use a trivial op where the constant is embedded in the ciphertext. +def CGGI_CreateTrivialOp : CGGI_Op<"create_trivial", [Pure]> { + let arguments = (ins Builtin_IntegerAttr:$value); + let results = (outs LWECiphertextLike:$output); +} + +def CGGI_CastOp : CGGI_Op<"cast", [Pure, SameOperandsAndResultShape]> { + let summary = "change the plaintext space of a CGGI ciphertext"; + + let description = [{ + "cast" operation to change the plaintext size of a CGGI ciphertext. + Note this operations is not a standard CGGI operation, but an mirror of the cast op implemented in TFHE-rs. + + Examples: + ``` + `cggi.cast %c0 : !lwe.lwe_ciphertext to !lwe.lwe_ciphertext` + ``` + }]; + + let arguments = (ins LWECiphertextLike:$input); + let results = (outs LWECiphertextLike:$output); + + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; +} + #endif // LIB_DIALECT_CGGI_IR_CGGIOPS_TD_ diff --git a/lib/Dialect/LWE/IR/LWETypes.td b/lib/Dialect/LWE/IR/LWETypes.td index 8cf1fc552..5ed5d6d60 100644 --- a/lib/Dialect/LWE/IR/LWETypes.td +++ b/lib/Dialect/LWE/IR/LWETypes.td @@ -74,6 +74,8 @@ def LWEPlaintext : LWE_Type<"LWEPlaintext", "lwe_plaintext"> { let nameSuggestion = "pt"; } +def LWEPlaintextLike : TypeOrContainer; + def RLWEPlaintext : LWE_Type<"RLWEPlaintext", "rlwe_plaintext"> { let summary = "A type for RLWE plaintexts"; @@ -85,4 +87,7 @@ def RLWEPlaintext : LWE_Type<"RLWEPlaintext", "rlwe_plaintext"> { let nameSuggestion = "pt"; } +def RLWEPlaintextLike : TypeOrContainer; + + #endif // LIB_DIALECT_LWE_IR_LWETYPES_TD_ diff --git a/tests/Dialect/Arith/Conversions/ArithToCGGI/BUILD b/tests/Dialect/Arith/Conversions/ArithToCGGI/BUILD new file mode 100644 index 000000000..c571e6fc6 --- /dev/null +++ b/tests/Dialect/Arith/Conversions/ArithToCGGI/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir b/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir new file mode 100644 index 000000000..dfc88b26a --- /dev/null +++ b/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir @@ -0,0 +1,71 @@ +// RUN: heir-opt --arith-to-cggi --split-input-file %s | FileCheck %s --enable-var-scope + +// CHECK-LABEL: @test_lower_add +// CHECK-SAME: (%[[LHS:.*]]: !lwe.lwe_ciphertext, %[[RHS:.*]]: !lwe.lwe_ciphertext) -> [[T:.*]] { +func.func @test_lower_add(%lhs : i32, %rhs : i32) -> i32 { + // CHECK: %[[ADD:.*]] = cggi.add %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.addi %lhs, %rhs : i32 + return %res : i32 +} + +// CHECK-LABEL: @test_lower_add_vec +// CHECK-SAME: (%[[LHS:.*]]: tensor<4x!lwe.lwe_ciphertext>, %[[RHS:.*]]: tensor<4x!lwe.lwe_ciphertext>) -> [[T:.*]] { +func.func @test_lower_add_vec(%lhs : tensor<4xi32>, %rhs : tensor<4xi32>) -> tensor<4xi32> { + // CHECK: %[[ADD:.*]] = cggi.add %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.addi %lhs, %rhs : tensor<4xi32> + return %res : tensor<4xi32> +} + +// CHECK-LABEL: @test_lower_sub_vec +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_sub_vec(%lhs : tensor<4xi32>, %rhs : tensor<4xi32>) -> tensor<4xi32> { + // CHECK: %[[ADD:.*]] = cggi.sub %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.subi %lhs, %rhs : tensor<4xi32> + return %res : tensor<4xi32> +} + +// CHECK-LABEL: @test_lower_sub +// CHECK-SAME: (%[[LHS:.*]]: !lwe.lwe_ciphertext, %[[RHS:.*]]: !lwe.lwe_ciphertext) -> [[T:.*]] { +func.func @test_lower_sub(%lhs : i16, %rhs : i16) -> i16 { + // CHECK: %[[ADD:.*]] = cggi.sub %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.subi %lhs, %rhs : i16 + return %res : i16 +} + +// CHECK-LABEL: @test_lower_mul +// CHECK-SAME: (%[[LHS:.*]]: !lwe.lwe_ciphertext, %[[RHS:.*]]: !lwe.lwe_ciphertext) -> [[T:.*]] { +func.func @test_lower_mul(%lhs : i8, %rhs : i8) -> i8 { + // CHECK: %[[ADD:.*]] = cggi.mul %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.muli %lhs, %rhs : i8 + return %res : i8 +} + +// CHECK-LABEL: @test_lower_mul_vec +// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]], %[[RHS:.*]]: [[T]]) -> [[T]] { +func.func @test_lower_mul_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> { + // CHECK: %[[ADD:.*]] = cggi.mul %[[LHS]], %[[RHS]] : [[T]] + // CHECK: return %[[ADD:.*]] : [[T]] + %res = arith.muli %lhs, %rhs : tensor<4xi8> + return %res : tensor<4xi8> +} + + +// CHECK-LABEL: @test_affine +// CHECK-SAME: (%[[ARG:.*]]: memref<1x1x!lwe.lwe_ciphertext>) -> [[T:.*]] { +func.func @test_affine(%arg0: memref<1x1xi32>) -> memref<1x1xi32> { + // CHECK: return %[[ADD:.*]] : [[T]] + %c429_i32 = arith.constant 429 : i32 + %c33_i8 = arith.constant 33 : i32 + %0 = affine.load %arg0[0, 0] : memref<1x1xi32> + %c0 = arith.constant 0 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1xi32> + %25 = arith.muli %0, %c33_i8 : i32 + %26 = arith.addi %c429_i32, %25 : i32 + affine.store %26, %alloc[0, 0] : memref<1x1xi32> + return %alloc : memref<1x1xi32> +} diff --git a/tools/BUILD b/tools/BUILD index d7f97305e..1bb8d2d94 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -32,6 +32,7 @@ cc_binary( }), includes = ["include"], deps = [ + "@heir//lib/Dialect/Arith/Conversions/ArithToCGGI", "@heir//lib/Dialect/Arith/Conversions/ArithToModArith", "@heir//lib/Dialect/Arith/Transforms", "@heir//lib/Dialect/Arith/Transforms:QuarterWideInt", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 93a4507fd..496fabf99 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -3,6 +3,7 @@ #include #include +#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h" #include "lib/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.h" #include "lib/Dialect/Arith/Transforms/Passes.h" #include "lib/Dialect/BGV/Conversions/BGVToLWE/BGVToLWE.h" @@ -305,7 +306,8 @@ int main(int argc, char **argv) { // Dialect conversion passes in HEIR mod_arith::registerModArithToArithPasses(); - ::mlir::heir::arith::registerArithToModArithPasses(); + mlir::heir::arith::registerArithToModArithPasses(); + mlir::heir::arith::registerArithToCGGIPasses(); mod_arith::registerConvertToMacPass(); bgv::registerBGVToLWEPasses(); bgv::registerBGVToLattigoPasses();