Skip to content

Commit

Permalink
comb-to-cggi: add plaintext-ciphertext boolean gate ops and dealloc l…
Browse files Browse the repository at this point in the history
…owering

PiperOrigin-RevId: 620338590
  • Loading branch information
asraa authored and copybara-github committed Mar 29, 2024
1 parent fe238e9 commit 503e6e0
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 29 deletions.
77 changes: 48 additions & 29 deletions lib/Conversion/CombToCGGI/CombToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,37 @@ Operation *convertReadOpInterface(Operation *op, SmallVector<Value> indices,
return subViewOp;
}

SmallVector<Value> encodeInputs(Operation *op, ValueRange inputs,
ConversionPatternRewriter &rewriter) {
// Get the ciphertext type.
lwe::LWECiphertextType ctxtTy;
for (auto input : inputs) {
if (isa<lwe::LWECiphertextType>(input.getType())) {
ctxtTy = cast<lwe::LWECiphertextType>(input.getType());
break;
}
}

// Encode any plaintexts in the inputs.
auto encoding = cast<lwe::LWECiphertextType>(ctxtTy).getEncoding();
auto ptxtTy = lwe::LWEPlaintextType::get(rewriter.getContext(), encoding);
return llvm::to_vector(llvm::map_range(inputs, [&](auto input) -> Value {
if (!isa<lwe::LWECiphertextType>(input.getType())) {
IntegerType integerTy = dyn_cast<IntegerType>(input.getType());
assert(integerTy && integerTy.getWidth() == 1 &&
"LUT inputs should be single-bit integers");
return rewriter
.create<lwe::TrivialEncryptOp>(
op->getLoc(), ctxtTy,
rewriter.create<lwe::EncodeOp>(op->getLoc(), ptxtTy, input,
encoding),
lwe::LWEParamsAttr())
.getResult();
}
return input;
}));
}

} // namespace

class SecretTypeConverter : public TypeConverter {
Expand Down Expand Up @@ -293,34 +324,8 @@ class SecretGenericOpLUTConversion
void replaceOp(secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ArrayRef<NamedAttribute> attributes,
ConversionPatternRewriter &rewriter) const override {
// Get the ciphertext type.
lwe::LWECiphertextType ctxtTy;
for (auto input : inputs) {
if (isa<lwe::LWECiphertextType>(input.getType())) {
ctxtTy = cast<lwe::LWECiphertextType>(input.getType());
break;
}
}

// Encode any plaintexts in the inputs.
auto encoding = cast<lwe::LWECiphertextType>(ctxtTy).getEncoding();
auto ptxtTy = lwe::LWEPlaintextType::get(rewriter.getContext(), encoding);
SmallVector<Value> encodedInputs =
llvm::to_vector(llvm::map_range(inputs, [&](auto input) -> Value {
if (!isa<lwe::LWECiphertextType>(input.getType())) {
IntegerType integerTy = dyn_cast<IntegerType>(input.getType());
assert(integerTy && integerTy.getWidth() == 1 &&
"LUT inputs should be single-bit integers");
return rewriter
.create<lwe::TrivialEncryptOp>(
op.getLoc(), ctxtTy,
rewriter.create<lwe::EncodeOp>(op.getLoc(), ptxtTy, input,
encoding),
lwe::LWEParamsAttr())
.getResult();
}
return input;
}));
encodeInputs(op.getOperation(), inputs, rewriter);

// Assemble the lookup table.
comb::TruthTableOp truthOp =
Expand Down Expand Up @@ -358,8 +363,9 @@ class SecretGenericOpGateConversion : public SecretGenericOpConversion<GateOp> {
void replaceOp(secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ArrayRef<NamedAttribute> attributes,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<CGGIGateOp>(op, outputTypes, inputs,
attributes);
rewriter.replaceOpWithNewOp<CGGIGateOp>(
op, outputTypes, encodeInputs(op.getOperation(), inputs, rewriter),
attributes);
}
};

Expand Down Expand Up @@ -469,6 +475,18 @@ class SecretGenericOpMemRefAllocConversion
}
};

class SecretGenericOpMemRefDeallocConversion
: public SecretGenericOpConversion<memref::DeallocOp> {
using SecretGenericOpConversion<memref::DeallocOp>::SecretGenericOpConversion;

void replaceOp(secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ArrayRef<NamedAttribute> attributes,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, outputTypes, inputs,
attributes);
}
};

// ConvertTruthTableOp converts truth table ops with fully plaintext values.
struct ConvertTruthTableOp : public OpConversionPattern<TruthTableOp> {
ConvertTruthTableOp(mlir::MLIRContext *context)
Expand Down Expand Up @@ -542,6 +560,7 @@ struct CombToCGGI : public impl::CombToCGGIBase<CombToCGGI> {

patterns
.add<SecretGenericOpLUTConversion, SecretGenericOpMemRefAllocConversion,
SecretGenericOpMemRefDeallocConversion,
SecretGenericOpMemRefLoadConversion,
SecretGenericOpAffineStoreConversion,
SecretGenericOpAffineLoadConversion,
Expand Down
44 changes: 44 additions & 0 deletions tests/comb_to_cggi/bool_ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: heir-opt --secret-distribute-generic --split-input-file --comb-to-cggi --cse %s | FileCheck %s

// CHECK-NOT: secret
// CHECK: @boolean_gates([[ARG:%.*]]: [[LWET:!lwe.lwe_ciphertext<.*>]]) -> [[LWET]]
func.func @boolean_gates(%arg0: !secret.secret<i1>) -> !secret.secret<i1> {
// CHECK: [[VAL1:%.+]] = cggi.and [[ARG]], [[ARG]]
// CHECK: [[VAL2:%.+]] = cggi.or [[VAL1]], [[ARG]]
// CHECK: [[VAL3:%.+]] = cggi.nand [[VAL2]], [[VAL1]]
// CHECK: [[VAL4:%.+]] = cggi.xor [[VAL3]], [[VAL2]]
// CHECK: [[VAL5:%.+]] = cggi.xnor [[VAL4]], [[VAL3]]
// CHECK: [[VAL6:%.+]] = cggi.nor [[VAL5]], [[VAL4]]
%0 = secret.generic
ins(%arg0: !secret.secret<i1>) {
^bb0(%ARG0: i1) :
%1 = comb.and %ARG0, %ARG0 : i1
%2 = comb.or %1, %ARG0 : i1
%3 = comb.nand %2, %1 : i1
%4 = comb.xor %3, %2 : i1
%5 = comb.xnor %4, %3 : i1
%6 = comb.nor %5, %4 : i1
secret.yield %6 : i1
} -> (!secret.secret<i1>)
// CHECK: return [[VAL6]] : [[LWET]]
func.return %0 : !secret.secret<i1>
}

// -----

// CHECK-NOT: secret
// CHECK: @boolean_gates_partial_secret(
// CHECK-SAME: [[ARG0:%.*]]: [[LWET:!lwe.lwe_ciphertext<.*>]], [[ARG1:%.*]]: i1) -> [[LWET]]
func.func @boolean_gates_partial_secret(%arg0: !secret.secret<i1>, %arg1 : i1) -> !secret.secret<i1> {
// CHECK: [[ENC:%.+]] = lwe.encode [[ARG1]]
// CHECK: [[LWE:%.+]] = lwe.trivial_encrypt [[ENC]]
// CHECK: [[VAL1:%.+]] = cggi.and [[ARG0]], [[LWE]]
%0 = secret.generic
ins(%arg0, %arg1: !secret.secret<i1>, i1) {
^bb0(%ARG0: i1, %ARG1: i1) :
%1 = comb.and %ARG0, %ARG1 : i1
secret.yield %1 : i1
} -> (!secret.secret<i1>)
// CHECK: return [[VAL1]] : [[LWET]]
func.return %0 : !secret.secret<i1>
}

0 comments on commit 503e6e0

Please sign in to comment.