Skip to content

Commit

Permalink
SelectVariableNames: use suggested name from dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Jan 7, 2025
1 parent 8dddd59 commit 7956696
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 15 deletions.
2 changes: 2 additions & 0 deletions lib/Analysis/SelectVariableNames/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ cc_library(
srcs = ["SelectVariableNames.cpp"],
hdrs = ["SelectVariableNames.h"],
deps = [
"@heir//lib/Utils/Tablegen:AsmInterfaces",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
Expand Down
61 changes: 52 additions & 9 deletions lib/Analysis/SelectVariableNames/SelectVariableNames.cpp
Original file line number Diff line number Diff line change
@@ -1,30 +1,73 @@
#include "lib/Analysis/SelectVariableNames/SelectVariableNames.h"

#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "lib/Utils/Tablegen/AsmInterfaces.h"
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {

std::string SelectVariableNames::suggestNameForValue(Value value) {
if (auto typeAsmInterface =
mlir::dyn_cast<TypeAsmInterface>(value.getType())) {
return typeAsmInterface.suggestedName();
}
return defaultPrefix;
}

SelectVariableNames::SelectVariableNames(Operation *op) {
int i = 0;
op->walk<WalkOrder::PreOrder>([&](Operation *op) {
std::map<std::string, int> prefixCount;

auto assignName = [&](Value value) {
std::string name = suggestNameForValue(value);
// special case for default prefix, namely v0
if (prefixCount.count(name) == 0 && name != defaultPrefix) {
// for non-default prefix
// the first one is "prefix", the next on is "prefix1"
prefixCount[name] = 1;
variableNames.try_emplace(value, name);
} else if (variableNames.count(value) == 0) {
if (prefixCount.count(name) == 0) {
// for default prefix
// the first one is "v0", the next one is "v1"
prefixCount[name] = 0;
}
variableNames.try_emplace(value,
name + std::to_string(prefixCount[name]++));
}
// unique integer for each value
variableToInteger.try_emplace(value, i++);
};

auto assignForOp = [&](Operation *op) {
for (Value result : op->getResults()) {
variableNames.try_emplace(result, i++);
assignName(result);
}

for (Region &region : op->getRegions()) {
for (Block &block : region) {
for (Value arg : block.getArguments()) {
variableNames.try_emplace(arg, i++);
assignName(arg);
}
}
}
};

op->walk([&](func::FuncOp funcOp) {
// clear the prefix count
// different function has different name space
prefixCount.clear();

return WalkResult::advance();
assignForOp(funcOp);
funcOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
assignForOp(op);
return WalkResult::advance();
});
});
}

Expand Down
12 changes: 7 additions & 5 deletions lib/Analysis/SelectVariableNames/SelectVariableNames.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,21 @@ class SelectVariableNames {
/// tree that this class was constructed with).
std::string getNameForValue(Value value) const {
assert(variableNames.contains(value));
return prefix + std::to_string(variableNames.lookup(value));
return variableNames.lookup(value);
}

// Return the unique integer assigned to a given value.
int getIntForValue(Value value) const {
assert(variableNames.contains(value));
return variableNames.lookup(value);
assert(variableToInteger.contains(value));
return variableToInteger.lookup(value);
}

private:
llvm::DenseMap<Value, int> variableNames;
std::string suggestNameForValue(Value value);

std::string prefix{"v"};
std::string defaultPrefix{"v"};
llvm::DenseMap<Value, std::string> variableNames;
llvm::DenseMap<Value, int> variableToInteger;
};

} // namespace heir
Expand Down
2 changes: 1 addition & 1 deletion tests/Dialect/Jaxite/Emitters/emit_jaxite.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


// CHECK-LABEL: def test_return_multiple_values(
// CHECK-NEXT: [[input:v[0-9]+]]: types.LweCiphertext,
// CHECK-NEXT: [[input:ct]]: types.LweCiphertext,
// CHECK-NEXT: [[v1:.*]]: jaxite_bool.ServerKeySet,
// CHECK-NEXT: [[v2:.*]]: jaxite_bool.Parameters,
// CHECK-NEXT: ) -> (types.LweCiphertext, types.LweCiphertext):
Expand Down

0 comments on commit 7956696

Please sign in to comment.