Skip to content

Commit

Permalink
Merge pull request #1189 from ZenithalHourlyRate:openfhe-configure-cr…
Browse files Browse the repository at this point in the history
…ypto-context

PiperOrigin-RevId: 706847893
  • Loading branch information
copybara-github committed Dec 16, 2024
2 parents 8c68164 + ca9b464 commit 8ba7a63
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 346 deletions.
19 changes: 0 additions & 19 deletions lib/Analysis/MulDepthAnalysis/BUILD

This file was deleted.

11 changes: 0 additions & 11 deletions lib/Analysis/MulDepthAnalysis/CMakeLists.txt

This file was deleted.

77 changes: 0 additions & 77 deletions lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.cpp

This file was deleted.

123 changes: 0 additions & 123 deletions lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.h

This file was deleted.

2 changes: 1 addition & 1 deletion lib/Dialect/Openfhe/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ cc_library(
],
deps = [
":pass_inc_gen",
"@heir//lib/Analysis/MulDepthAnalysis",
"@heir//lib/Dialect/Openfhe/IR:Dialect",
"@heir//lib/Dialect/RNS/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
Expand Down
63 changes: 26 additions & 37 deletions lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
#include <set>
#include <string>

#include "lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.h"
#include "lib/Dialect/Openfhe/IR/OpenfheOps.h"
#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h"
#include "lib/Dialect/RNS/IR/RNSTypes.h"
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project
Expand All @@ -31,12 +31,11 @@ namespace openfhe {
#define GEN_PASS_DEF_CONFIGURECRYPTOCONTEXT
#include "lib/Dialect/Openfhe/Transforms/Passes.h.inc"

// Helper function to check if the function has MulOp or MulNoRelinOp
bool hasMulOp(func::FuncOp op) {
// Helper function to check if the function has RelinOp
bool hasRelinOp(func::FuncOp op) {
bool result = false;
op.walk<WalkOrder::PreOrder>([&](Operation *op) {
if (isa<openfhe::MulOp>(op) || isa<openfhe::MulNoRelinOp>(op) ||
isa<openfhe::MulPlainOp>(op)) {
if (isa<openfhe::RelinOp>(op)) {
result = true;
return WalkResult::interrupt();
}
Expand Down Expand Up @@ -86,7 +85,8 @@ LogicalResult generateGenFunc(func::FuncOp op, const std::string &genFuncName,
// function that configures the crypto context with proper keygeneration
LogicalResult generateConfigFunc(func::FuncOp op,
const std::string &configFuncName,
bool hasMulOp, SmallVector<int64_t> rotIndices,
bool hasRelinOp,
SmallVector<int64_t> rotIndices,
ImplicitLocOpBuilder &builder) {
Type openfheContextType =
openfhe::CryptoContextType::get(builder.getContext());
Expand All @@ -108,7 +108,7 @@ LogicalResult generateConfigFunc(func::FuncOp op,
Value cryptoContext = configFuncOp.getArgument(0);
Value privateKey = configFuncOp.getArgument(1);

if (hasMulOp) {
if (hasRelinOp) {
builder.create<openfhe::GenMulKeyOp>(cryptoContext, privateKey);
}
if (!rotIndices.empty()) {
Expand All @@ -119,7 +119,7 @@ LogicalResult generateConfigFunc(func::FuncOp op,
return success();
}

LogicalResult convertFunc(func::FuncOp op, int64_t mulDepth) {
LogicalResult convertFunc(func::FuncOp op) {
auto module = op->getParentOfType<ModuleOp>();
std::string genFuncName("");
llvm::raw_string_ostream genNameOs(genFuncName);
Expand All @@ -132,16 +132,30 @@ LogicalResult convertFunc(func::FuncOp op, int64_t mulDepth) {
ImplicitLocOpBuilder builder =
ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());

// get mulDepth from function argument ciphertext type
int64_t mulDepth = 0;
for (auto arg : op.getArguments()) {
if (auto argType = dyn_cast<lwe::NewLWECiphertextType>(
getElementTypeOrSelf(arg.getType()))) {
if (auto rnsType = dyn_cast<rns::RNSType>(
argType.getCiphertextSpace().getRing().getCoefficientType())) {
mulDepth = rnsType.getBasisTypes().size() - 1;
// implicitly assume arguments have the same level
break;
}
}
}

if (failed(generateGenFunc(op, genFuncName, mulDepth, builder))) {
return failure();
}

builder.setInsertionPointToEnd(module.getBody());

bool hasMulOpResult = hasMulOp(op);
bool hasRelinOpResult = hasRelinOp(op);
SmallVector<int64_t> rotIndices = findAllRotIndices(op);
if (failed(generateConfigFunc(op, configFuncName, hasMulOpResult, rotIndices,
builder))) {
if (failed(generateConfigFunc(op, configFuncName, hasRelinOpResult,
rotIndices, builder))) {
return failure();
}
return success();
Expand All @@ -152,35 +166,10 @@ struct ConfigureCryptoContext
using ConfigureCryptoContextBase::ConfigureCryptoContextBase;

void runOnOperation() override {
// Analyse the operations to find the MulDepth
DataFlowSolver solver;
solver.load<dataflow::DeadCodeAnalysis>();
solver.load<dataflow::SparseConstantPropagation>();
solver.load<MulDepthAnalysis>();
if (failed(solver.initializeAndRun(getOperation()))) {
getOperation()->emitOpError() << "Failed to run the analysis.\n";
signalPassFailure();
return;
}
int64_t maxMulDepth = 0;
// walk the operations to find the max MulDepth
getOperation()->walk([&](Operation *op) {
// if the lengths of the operands is 0, then return
if (op->getNumResults() == 0) return WalkResult::advance();
const MulDepthLattice *resultLattice =
solver.lookupState<MulDepthLattice>(op->getResult(0));
if (resultLattice->getValue().isInitialized()) {
maxMulDepth =
std::max(maxMulDepth, resultLattice->getValue().getValue());
}
return WalkResult::advance();
});

auto result =
getOperation()->walk<WalkOrder::PreOrder>([&](func::FuncOp op) {
auto funcName = op.getSymName();
if ((funcName == entryFunction) &&
failed(convertFunc(op, maxMulDepth))) {
if ((funcName == entryFunction) && failed(convertFunc(op))) {
op->emitError("Failed to configure the crypto context for func");
return WalkResult::interrupt();
}
Expand Down
Loading

0 comments on commit 8ba7a63

Please sign in to comment.