Skip to content

Commit

Permalink
Merge pull request #1228 from ZenithalHourlyRate:mgmt-bootstrap
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713710443
  • Loading branch information
copybara-github committed Jan 9, 2025
2 parents 3b91726 + a780710 commit 48c8f20
Show file tree
Hide file tree
Showing 13 changed files with 236 additions and 1 deletion.
6 changes: 6 additions & 0 deletions lib/Analysis/LevelAnalysis/LevelAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ LogicalResult LevelAnalysis::visitOperation(
auto level = operandLattice->getValue().getLevel();
propagate(modReduceOp.getResult(), LevelState(level + 1));
})
.Case<mgmt::BootstrapOp>([&](auto bootstrapOp) {
// implicitly ensure that the result is secret
// reset level to 0
// TODO(#1207): reset level to currentLevel - bootstrapDepth
propagate(bootstrapOp.getResult(), LevelState(0));
})
.Default([&](auto &op) {
// condition on result secretness
SmallVector<OpResult> secretResults;
Expand Down
19 changes: 19 additions & 0 deletions lib/Dialect/CKKS/IR/CKKSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,23 @@ def CKKS_RescaleOp : CKKS_Op<"rescale", [Pure]> {
let assemblyFormat = "operands attr-dict `:` qualified(type($input)) `->` qualified(type($output))" ;
}

def CKKS_BootstrapOp : CKKS_Op<"bootstrap", [Pure]> {
let summary = "Bootstrap the ciphertext to reduce noise and refresh its parameters.";

let description = [{
Bootstrapping is a technique used in FHE to reduce the noise in a ciphertext
and refresh its parameters, allowing for further computations on the ciphertext.
}];

let arguments = (ins
NewLWECiphertext:$input
);

let results = (outs
NewLWECiphertext:$output
);

let assemblyFormat = "operands attr-dict `:` qualified(type($input)) `->` qualified(type($output))" ;
}

#endif // LIB_DIALECT_CKKS_IR_CKKSOPS_TD_
23 changes: 22 additions & 1 deletion lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,25 @@ struct ConvertEncodeOp : public OpConversionPattern<lwe::RLWEEncodeOp> {
}
};

struct ConvertBootstrapOp : public OpConversionPattern<ckks::BootstrapOp> {
ConvertBootstrapOp(mlir::MLIRContext *context)
: OpConversionPattern<ckks::BootstrapOp>(context) {}

using OpConversionPattern<ckks::BootstrapOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
ckks::BootstrapOp op, ckks::BootstrapOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result = getContextualCryptoContext(op.getOperation());
if (failed(result)) return result;

Value cryptoContext = result.value();
rewriter.replaceOpWithNewOp<openfhe::BootstrapOp>(
op, op.getOutput().getType(), cryptoContext, adaptor.getInput());
return success();
}
};

struct LWEToOpenfhe : public impl::LWEToOpenfheBase<LWEToOpenfhe> {
void runOnOperation() override {
MLIRContext *context = &getContext();
Expand Down Expand Up @@ -309,7 +328,9 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase<LWEToOpenfhe> {
// Modulus Switch (BGV only)
lwe::ConvertModulusSwitchOp<bgv::ModulusSwitchOp>,
// Rescale (CKKS version of Modulus Switch)
lwe::ConvertModulusSwitchOp<ckks::RescaleOp>
lwe::ConvertModulusSwitchOp<ckks::RescaleOp>,
// Bootstrap (CKKS only)
ConvertBootstrapOp
// End of Pattern List
>(typeConverter, context);

Expand Down
25 changes: 25 additions & 0 deletions lib/Dialect/Mgmt/IR/MgmtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,29 @@ def Mgmt_RelinearizeOp : Mgmt_Op<"relinearize"> {
let assemblyFormat = "operands attr-dict `:` type($output)";
}

def Mgmt_BootstrapOp : Mgmt_Op<"bootstrap"> {
let summary = "Bootstrap the input ciphertext to refresh its noise budget";

let description = [{
This is a scheme-agnostic operation that implies bootstrapping
of the input ciphertext to refresh its noise budget.

Bootstrapping is a technique used in homomorphic encryption to
reduce the noise in a ciphertext, allowing further operations
to be performed on it without decryption.

When further lowered, it could be lowered to bgv.bootstrap
or ckks.bootstrap depending on the scheme.

For the current backend, only ckks.bootstrap is supported.
Further backend may include bgv.bootstrap.
}];

let arguments = (ins
AnyType:$input
);
let results = (outs AnyType:$output);
let assemblyFormat = "operands attr-dict `:` type($output)";
}

#endif // LIB_DIALECT_MGMT_IR_MGMTOPS_TD_
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ struct SecretToCKKS : public impl::SecretToCKKSBase<SecretToCKKS> {
SecretGenericOpCipherConversion<tensor::EmptyOp, tensor::EmptyOp>,
SecretGenericOpRelinearizeConversion<ckks::RelinearizeOp>,
SecretGenericOpModulusSwitchConversion<ckks::RescaleOp>,
SecretGenericOpCipherConversion<mgmt::BootstrapOp, ckks::BootstrapOp>,
SecretGenericTensorExtractConversion,
SecretGenericTensorInsertConversion,
SecretGenericOpRotateConversion<ckks::RotateOp>,
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/SecretInsertMgmt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ cc_library(
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
],
Expand Down
13 changes: 13 additions & 0 deletions lib/Transforms/SecretInsertMgmt/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ def SecretInsertMgmtCKKS : Pass<"secret-insert-mgmt-ckks", "ModuleOp"> {
Check the description of secret-insert-mgmt-bgv. This pass
implements similar strategy, where mgmt.modreduce stands for
ckks.rescale here.

For bootstrap insertion policy, currently a greedy policy is used
where when all levels are consumed then a bootstrap is inserted.

The max level available after bootstrap is controlled by the option
`bootstrap-waterline`.

Number of bootstrap consumed level is not shown here, which is
handled by further lowering.
TODO(#1207): handle it here so parameter selection can depend on it.
TODO(#1207): with this info we can encrypt at max level (with bootstrap consumed level).
}];

let dependentDialects = [
Expand All @@ -97,6 +108,8 @@ def SecretInsertMgmtCKKS : Pass<"secret-insert-mgmt-ckks", "ModuleOp"> {
/*default=*/"false", "Modulus switching right before the first multiplication (default to false)">,
Option<"slotNumber", "slot-number", "int",
/*default=*/"1024", "Default number of slots use for ciphertext space.">,
Option<"bootstrapWaterline", "bootstrap-waterline", "int",
/*default=*/"10", "Waterline for insert bootstrap op">,
];
}

Expand Down
15 changes: 15 additions & 0 deletions lib/Transforms/SecretInsertMgmt/SecretInsertMgmtCKKS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h"
#include "lib/Analysis/MulResultAnalysis/MulResultAnalysis.h"
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Mgmt/Transforms/AnnotateMgmt.h"
#include "lib/Dialect/Mgmt/Transforms/Passes.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
Expand Down Expand Up @@ -129,6 +130,7 @@ struct SecretInsertMgmtCKKS

// re-run analysis as MulResultAnalysis is affected by slot_extract
solver.load<MulResultAnalysis>();
solver.eraseAllStates();
if (failed(solver.initializeAndRun(getOperation()))) {
getOperation()->emitOpError() << "Failed to run the analysis.\n";
signalPassFailure();
Expand Down Expand Up @@ -161,6 +163,19 @@ struct SecretInsertMgmtCKKS
(void)walkAndApplyPatterns(getOperation(),
std::move(patternsMultModReduce));

// insert BootstrapOp after mgmt::ModReduceOp
// This must be run before level mismatch
// NOTE: actually bootstrap before mod reduce is better
// as after modreduce to level `0` there still might be add/sub
// and these op done there could be minimal cost.
// However, this greedy strategy is temporary so not too much
// optimization now
RewritePatternSet patternsBootstrapWaterLine(&getContext());
patternsBootstrapWaterLine.add<BootstrapWaterLine<mgmt::ModReduceOp>>(
&getContext(), getOperation(), &solver, bootstrapWaterline);
(void)walkAndApplyPatterns(getOperation(),
std::move(patternsBootstrapWaterLine));

// when other binary op operands level mismatch
// includeFirstMul not used for these ops
RewritePatternSet patternsAddModReduce(&getContext());
Expand Down
38 changes: 38 additions & 0 deletions lib/Transforms/SecretInsertMgmt/SecretInsertMgmtPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ LogicalResult MultRelinearize<MulOp>::matchAndRewrite(
rewriter.create<mgmt::RelinearizeOp>(mulOp.getLoc(), result);
result.replaceAllUsesExcept(relinearized, {relinearized});

solver->eraseAllStates();
return solver->initializeAndRun(top);
}

Expand Down Expand Up @@ -140,11 +141,46 @@ LogicalResult ModReduceBefore<Op>::matchAndRewrite(
if (inserted) {
// propagateIfChanged only push workitem to the worklist queue
// actually execute the transfer for the new values
solver->eraseAllStates();
return solver->initializeAndRun(top);
}
return success();
}

template <typename Op>
LogicalResult BootstrapWaterLine<Op>::matchAndRewrite(
Op op, PatternRewriter &rewriter) const {
auto levelLattice = solver->lookupState<LevelLattice>(op->getResult(0));
if (!levelLattice->getValue().isInitialized()) {
return failure();
}

auto level = levelLattice->getValue().getLevel();

if (level < waterline) {
return success();
}
if (level > waterline) {
// should never met!
LLVM_DEBUG(llvm::dbgs()
<< "BootstrapWaterLine: met " << op << " with level: " << level
<< " but waterline: " << waterline << "\n");
return failure();
}

// insert mgmt::BootstrapOp after
rewriter.setInsertionPointAfter(op);
auto bootstrap = rewriter.create<mgmt::BootstrapOp>(
op.getLoc(), op->getResultTypes(), op->getResult(0));
op->getResult(0).replaceAllUsesExcept(bootstrap, {bootstrap});

// greedy rewrite! note that we may get undeterministic insertion result
// if we use different order of rewrites
// currently walkAndApplyPatterns is deterministic
solver->eraseAllStates();
return solver->initializeAndRun(top);
}

// for BGV
template struct MultRelinearize<arith::MulIOp>;

Expand All @@ -167,5 +203,7 @@ template struct ModReduceBefore<arith::MulFOp>;
template struct ModReduceBefore<arith::AddFOp>;
template struct ModReduceBefore<arith::SubFOp>;

template struct BootstrapWaterLine<mgmt::ModReduceOp>;

} // namespace heir
} // namespace mlir
22 changes: 22 additions & 0 deletions lib/Transforms/SecretInsertMgmt/SecretInsertMgmtPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Pass/AnalysisManager.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand Down Expand Up @@ -46,6 +47,27 @@ struct ModReduceBefore : public OpRewritePattern<Op> {
DataFlowSolver *solver;
};

// when reached a certain depth (water line), bootstrap
template <typename Op>
struct BootstrapWaterLine : public OpRewritePattern<Op> {
using OpRewritePattern<Op>::OpRewritePattern;

BootstrapWaterLine(MLIRContext *context, Operation *top,
DataFlowSolver *solver, int waterline)
: OpRewritePattern<Op>(context, /*benefit=*/1),
top(top),
solver(solver),
waterline(waterline) {}

LogicalResult matchAndRewrite(Op op,
PatternRewriter &rewriter) const override;

private:
Operation *top;
DataFlowSolver *solver;
int waterline;
};

} // namespace heir
} // namespace mlir

Expand Down
23 changes: 23 additions & 0 deletions tests/Transforms/mlir_to_openfhe_ckks/bootstrap_waterline.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-ckks=bootstrap-waterline=3 --mlir-to-openfhe-ckks %s | FileCheck %s

// CHECK: func.func @bootstrap_waterline
// CHECK: openfhe.bootstrap

func.func @bootstrap_waterline(
%x : f16 {secret.secret}
) -> f16 {
%0 = arith.addf %x, %x : f16
%r0 = mgmt.modreduce %0 : f16
%1 = arith.addf %r0, %r0 : f16
%r1 = mgmt.modreduce %1 : f16
%2 = arith.addf %r1, %r1 : f16
%r2 = mgmt.modreduce %2 : f16
%3 = arith.addf %r2, %r2 : f16
%r3 = mgmt.modreduce %3 : f16
%4 = arith.addf %r3, %r3 : f16
%r4 = mgmt.modreduce %4 : f16
%5 = arith.addf %r4, %r4 : f16
// cross level op
%mixed0 = arith.addf %5, %x : f16
return %mixed0 : f16
}
10 changes: 10 additions & 0 deletions tests/Transforms/secret_insert_mgmt/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(default_applicable_licenses = ["@heir//:license"])

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
41 changes: 41 additions & 0 deletions tests/Transforms/secret_insert_mgmt/bootstrap_waterline.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-ckks=bootstrap-waterline=3 %s | FileCheck %s

// CHECK: func.func @bootstrap_waterline(%arg0: !secret.secret<f16>) -> !secret.secret<f16> {
// CHECK: %0 = secret.generic ins(%[[arg0:.*]] : !secret.secret<f16>) attrs = {arg0 = {mgmt.mgmt = #mgmt.mgmt<level = 3>}} {
// CHECK: (%[[input0:.*]]: f16):
// CHECK: %[[v1:.*]] = arith.addf %[[input0]], %[[input0]] {mgmt.mgmt = #mgmt.mgmt<level = 3>} : f16
// CHECK: %[[v2:.*]] = mgmt.modreduce %[[v1]] {mgmt.mgmt = #mgmt.mgmt<level = 2>} : f16
// CHECK: %[[v3:.*]] = arith.addf %2, %[[v2]] {mgmt.mgmt = #mgmt.mgmt<level = 2>} : f16
// CHECK: %[[v4:.*]] = mgmt.modreduce %[[v3]] {mgmt.mgmt = #mgmt.mgmt<level = 1>} : f16
// CHECK: %[[v5:.*]] = arith.addf %4, %[[v4]] {mgmt.mgmt = #mgmt.mgmt<level = 1>} : f16
// CHECK: %[[v6:.*]] = mgmt.modreduce %[[v5]] {mgmt.mgmt = #mgmt.mgmt<level = 0>} : f16
// CHECK: %[[v7:.*]] = mgmt.bootstrap %[[v6]] {mgmt.mgmt = #mgmt.mgmt<level = 3>} : f16
// CHECK: %[[v8:.*]] = arith.addf %[[v7]], %[[v7]] {mgmt.mgmt = #mgmt.mgmt<level = 3>} : f16
// CHECK: %[[v9:.*]] = mgmt.modreduce %[[v8]] {mgmt.mgmt = #mgmt.mgmt<level = 2>} : f16
// CHECK: %[[v10:.*]] = arith.addf %[[v9]], %[[v9]] {mgmt.mgmt = #mgmt.mgmt<level = 2>} : f16
// CHECK: %[[v11:.*]] = mgmt.modreduce %[[v10]] {mgmt.mgmt = #mgmt.mgmt<level = 1>} : f16
// CHECK: %[[v12:.*]] = arith.addf %[[v11]], %[[v11]] {mgmt.mgmt = #mgmt.mgmt<level = 1>} : f16
// CHECK: %[[v13:.*]] = mgmt.modreduce %[[input0]] {mgmt.mgmt = #mgmt.mgmt<level = 2>} : f16
// CHECK: %[[v14:.*]] = mgmt.modreduce %[[v13]] {mgmt.mgmt = #mgmt.mgmt<level = 1>} : f16
// CHECK: %[[v15:.*]] = arith.addf %[[v12]], %[[v14]] {mgmt.mgmt = #mgmt.mgmt<level = 1>} : f16
// CHECK: secret.yield %[[v15]] : f16


func.func @bootstrap_waterline(
%x : f16 {secret.secret}
) -> f16 {
%0 = arith.addf %x, %x : f16
%r0 = mgmt.modreduce %0 : f16
%1 = arith.addf %r0, %r0 : f16
%r1 = mgmt.modreduce %1 : f16
%2 = arith.addf %r1, %r1 : f16
%r2 = mgmt.modreduce %2 : f16
%3 = arith.addf %r2, %r2 : f16
%r3 = mgmt.modreduce %3 : f16
%4 = arith.addf %r3, %r3 : f16
%r4 = mgmt.modreduce %4 : f16
%5 = arith.addf %r4, %r4 : f16
// cross level op
%mixed0 = arith.addf %5, %x : f16
return %mixed0 : f16
}

0 comments on commit 48c8f20

Please sign in to comment.