diff --git a/lib/Conversion/CombToCGGI/CombToCGGI.cpp b/lib/Conversion/CombToCGGI/CombToCGGI.cpp index 87a14100e4..fb32b2e45e 100644 --- a/lib/Conversion/CombToCGGI/CombToCGGI.cpp +++ b/lib/Conversion/CombToCGGI/CombToCGGI.cpp @@ -190,6 +190,37 @@ Operation *convertReadOpInterface(Operation *op, SmallVector indices, return subViewOp; } +SmallVector encodeInputs(Operation *op, ValueRange inputs, + ConversionPatternRewriter &rewriter) { + // Get the ciphertext type. + lwe::LWECiphertextType ctxtTy; + for (auto input : inputs) { + if (isa(input.getType())) { + ctxtTy = cast(input.getType()); + break; + } + } + + // Encode any plaintexts in the inputs. + auto encoding = cast(ctxtTy).getEncoding(); + auto ptxtTy = lwe::LWEPlaintextType::get(rewriter.getContext(), encoding); + return llvm::to_vector(llvm::map_range(inputs, [&](auto input) -> Value { + if (!isa(input.getType())) { + IntegerType integerTy = dyn_cast(input.getType()); + assert(integerTy && integerTy.getWidth() == 1 && + "LUT inputs should be single-bit integers"); + return rewriter + .create( + op->getLoc(), ctxtTy, + rewriter.create(op->getLoc(), ptxtTy, input, + encoding), + lwe::LWEParamsAttr()) + .getResult(); + } + return input; + })); +} + } // namespace class SecretTypeConverter : public TypeConverter { @@ -293,34 +324,8 @@ class SecretGenericOpLUTConversion void replaceOp(secret::GenericOp op, TypeRange outputTypes, ValueRange inputs, ArrayRef attributes, ConversionPatternRewriter &rewriter) const override { - // Get the ciphertext type. - lwe::LWECiphertextType ctxtTy; - for (auto input : inputs) { - if (isa(input.getType())) { - ctxtTy = cast(input.getType()); - break; - } - } - - // Encode any plaintexts in the inputs. - auto encoding = cast(ctxtTy).getEncoding(); - auto ptxtTy = lwe::LWEPlaintextType::get(rewriter.getContext(), encoding); SmallVector encodedInputs = - llvm::to_vector(llvm::map_range(inputs, [&](auto input) -> Value { - if (!isa(input.getType())) { - IntegerType integerTy = dyn_cast(input.getType()); - assert(integerTy && integerTy.getWidth() == 1 && - "LUT inputs should be single-bit integers"); - return rewriter - .create( - op.getLoc(), ctxtTy, - rewriter.create(op.getLoc(), ptxtTy, input, - encoding), - lwe::LWEParamsAttr()) - .getResult(); - } - return input; - })); + encodeInputs(op.getOperation(), inputs, rewriter); // Assemble the lookup table. comb::TruthTableOp truthOp = @@ -358,8 +363,9 @@ class SecretGenericOpGateConversion : public SecretGenericOpConversion { void replaceOp(secret::GenericOp op, TypeRange outputTypes, ValueRange inputs, ArrayRef attributes, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, outputTypes, inputs, - attributes); + rewriter.replaceOpWithNewOp( + op, outputTypes, encodeInputs(op.getOperation(), inputs, rewriter), + attributes); } }; @@ -469,6 +475,18 @@ class SecretGenericOpMemRefAllocConversion } }; +class SecretGenericOpMemRefDeallocConversion + : public SecretGenericOpConversion { + using SecretGenericOpConversion::SecretGenericOpConversion; + + void replaceOp(secret::GenericOp op, TypeRange outputTypes, ValueRange inputs, + ArrayRef attributes, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, outputTypes, inputs, + attributes); + } +}; + // ConvertTruthTableOp converts truth table ops with fully plaintext values. struct ConvertTruthTableOp : public OpConversionPattern { ConvertTruthTableOp(mlir::MLIRContext *context) @@ -542,6 +560,7 @@ struct CombToCGGI : public impl::CombToCGGIBase { patterns .add]]) -> [[LWET]] +func.func @boolean_gates(%arg0: !secret.secret) -> !secret.secret { + // CHECK: [[VAL1:%.+]] = cggi.and [[ARG]], [[ARG]] + // CHECK: [[VAL2:%.+]] = cggi.or [[VAL1]], [[ARG]] + // CHECK: [[VAL3:%.+]] = cggi.nand [[VAL2]], [[VAL1]] + // CHECK: [[VAL4:%.+]] = cggi.xor [[VAL3]], [[VAL2]] + // CHECK: [[VAL5:%.+]] = cggi.xnor [[VAL4]], [[VAL3]] + // CHECK: [[VAL6:%.+]] = cggi.nor [[VAL5]], [[VAL4]] + %0 = secret.generic + ins(%arg0: !secret.secret) { + ^bb0(%ARG0: i1) : + %1 = comb.and %ARG0, %ARG0 : i1 + %2 = comb.or %1, %ARG0 : i1 + %3 = comb.nand %2, %1 : i1 + %4 = comb.xor %3, %2 : i1 + %5 = comb.xnor %4, %3 : i1 + %6 = comb.nor %5, %4 : i1 + secret.yield %6 : i1 + } -> (!secret.secret) + // CHECK: return [[VAL6]] : [[LWET]] + func.return %0 : !secret.secret +} + +// ----- + +// CHECK-NOT: secret +// CHECK: @boolean_gates_partial_secret( +// CHECK-SAME: [[ARG0:%.*]]: [[LWET:!lwe.lwe_ciphertext<.*>]], [[ARG1:%.*]]: i1) -> [[LWET]] +func.func @boolean_gates_partial_secret(%arg0: !secret.secret, %arg1 : i1) -> !secret.secret { + // CHECK: [[ENC:%.+]] = lwe.encode [[ARG1]] + // CHECK: [[LWE:%.+]] = lwe.trivial_encrypt [[ENC]] + // CHECK: [[VAL1:%.+]] = cggi.and [[ARG0]], [[LWE]] + %0 = secret.generic + ins(%arg0, %arg1: !secret.secret, i1) { + ^bb0(%ARG0: i1, %ARG1: i1) : + %1 = comb.and %ARG0, %ARG1 : i1 + secret.yield %1 : i1 + } -> (!secret.secret) + // CHECK: return [[VAL1]] : [[LWET]] + func.return %0 : !secret.secret +}