Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ Eliminate DD terminal nodes #381

Merged
merged 14 commits into from
Jul 22, 2023
26 changes: 15 additions & 11 deletions include/dd/DDDefinitions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,25 @@
#include <vector>

namespace dd {
// integer type used for indexing qubits
// needs to be a signed type to encode -1 as the index for the terminal
// std::int8_t can address up to 128 qubits as [0, ..., 127]
using Qubit = std::int8_t;
static_assert(std::is_signed_v<Qubit>, "Type Qubit must be signed.");

// integer type used for specifying numbers of qubits
using QubitCount = std::make_unsigned_t<Qubit>;
/**
* @brief Integer type used for indexing qubits
* @details `std::uint16_t` can address up to 65536 qubits as [0, ..., 65535].
* @note If you need even more qubits, this can be increased to `std::uint32_t`.
* Beware of the increased memory footprint of matrix nodes.
*/
using Qubit = std::uint16_t;

// integer type used for reference counting
// 32bit suffice for a max ref count of around 4 billion
/**
* @brief Integer type used for reference counting
* @details Allows a maximum reference count of roughly 4 billion.
*/
using RefCount = std::uint32_t;
static_assert(std::is_unsigned_v<RefCount>, "RefCount should be unsigned.");

// floating point type to use
/**
* @brief Floating point type to use for computations
* @note Adjusting the precision might lead to unexpected results.
*/
using fp = double;
static_assert(
std::is_floating_point_v<fp>,
Expand Down
14 changes: 13 additions & 1 deletion include/dd/Export.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,12 @@ static void toDot(const Edge& e, std::ostream& os, bool colored = true,
std::unordered_set<decltype(e.p)> nodes{};

auto priocmp = [](const Edge* left, const Edge* right) {
if (left->p == nullptr) {
return right->p != nullptr;
}
if (right->p == nullptr) {
return false;
}
return left->p->v < right->p->v;
};

Expand Down Expand Up @@ -695,7 +701,7 @@ static void toDot(const Edge& e, std::ostream& os, bool colored = true,
modernNode(*node, oss, formatAsPolar);
}

// iterate over edges in reverse to guarantee correct proceossing order
// iterate over edges in reverse to guarantee correct processing order
for (auto i = static_cast<std::int16_t>(node->p->e.size() - 1); i >= 0;
--i) {
auto& edge = node->p->e[static_cast<std::size_t>(i)];
Expand Down Expand Up @@ -936,6 +942,12 @@ template <typename Edge>
static void exportEdgeWeights(const Edge& edge, std::ostream& stream) {
struct Priocmp {
bool operator()(const Edge* left, const Edge* right) {
if (left->p == nullptr) {
return right->p != nullptr;
}
if (right->p == nullptr) {
return false;
}
return left->p->v < right->p->v;
}
};
Expand Down
80 changes: 44 additions & 36 deletions include/dd/Node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,34 @@
#include <utility>

namespace dd {
// NOLINTNEXTLINE(readability-identifier-naming)
struct vNode {

/**
* @brief A vector DD node
* @details Data Layout |24|24|8|4|2| = 62B (space for two more bytes)
*/
struct vNode { // NOLINT(readability-identifier-naming)
std::array<Edge<vNode>, RADIX> e{}; // edges out of this node
vNode* next{}; // used to link nodes in unique table
RefCount ref{}; // reference count
Qubit v{}; // variable index (nonterminal) value (-1 for terminal)

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables
static vNode terminal;
Qubit v{}; // variable index

static constexpr bool isTerminal(const vNode* p) noexcept {
return p == &terminal;
return p == nullptr;
}
static constexpr vNode* getTerminal() noexcept { return &terminal; }
static constexpr vNode* getTerminal() noexcept { return nullptr; }
};
using vEdge = Edge<vNode>;
using vCachedEdge = CachedEdge<vNode>;

// NOLINTNEXTLINE(readability-identifier-naming)
struct mNode {
/**
* @brief A matrix DD node
* @details Data Layout |24|24|24|24|8|4|2|1| = 111B (space for one more byte)
*/
struct mNode { // NOLINT(readability-identifier-naming)
std::array<Edge<mNode>, NEDGE> e{}; // edges out of this node
mNode* next{}; // used to link nodes in unique table
RefCount ref{}; // reference count
Qubit v{}; // variable index (nonterminal) value (-1 for terminal)
Qubit v{}; // variable index
std::uint8_t flags = 0;
// 32 = marks a node with is symmetric.
// 16 = marks a node resembling identity
Expand All @@ -41,45 +45,53 @@ struct mNode {
// 2 = mark first path edge (tmp flag),
// 1 = mark path is conjugated (tmp flag))

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static mNode terminal;
[[nodiscard]] static constexpr bool isTerminal(const mNode* p) noexcept {
return p == nullptr;
}
[[nodiscard]] static constexpr mNode* getTerminal() noexcept {
return nullptr;
}

static constexpr bool isTerminal(const mNode* p) noexcept {
return p == &terminal;
[[nodiscard]] inline bool isSymmetric() const noexcept {
return (flags & static_cast<std::uint8_t>(32U)) != 0;
}
[[nodiscard]] static constexpr bool isSymmetric(const mNode* p) noexcept {
return p == nullptr || p->isSymmetric();
}
inline void setSymmetric(const bool symmetric) noexcept {
if (symmetric) {
flags = (flags | static_cast<std::uint8_t>(32U));
} else {
flags = (flags & static_cast<std::uint8_t>(~32U));
}
}
static constexpr mNode* getTerminal() noexcept { return &terminal; }

[[nodiscard]] inline bool isIdentity() const noexcept {
return (flags & static_cast<std::uint8_t>(16U)) != 0;
}
[[nodiscard]] inline bool isSymmetric() const noexcept {
return (flags & static_cast<std::uint8_t>(32U)) != 0;
[[nodiscard]] static constexpr bool isIdentity(const mNode* p) noexcept {
return p == nullptr || p->isIdentity();
}

inline void setIdentity(const bool identity) noexcept {
if (identity) {
flags = (flags | static_cast<std::uint8_t>(16U));
} else {
flags = (flags & static_cast<std::uint8_t>(~16U));
}
}
inline void setSymmetric(const bool symmetric) noexcept {
if (symmetric) {
flags = (flags | static_cast<std::uint8_t>(32U));
} else {
flags = (flags & static_cast<std::uint8_t>(~32U));
}
}
};
using mEdge = Edge<mNode>;
using mCachedEdge = CachedEdge<mNode>;

// NOLINTNEXTLINE(readability-identifier-naming)
struct dNode {
/**
* @brief A density matrix DD node
* @details Data Layout |24|24|24|24|8|4|2|1| = 111B (space for one more byte)
*/
struct dNode { // NOLINT(readability-identifier-naming)
std::array<Edge<dNode>, NEDGE> e{}; // edges out of this node
dNode* next{}; // used to link nodes in unique table
RefCount ref{}; // reference count
Qubit v{}; // variable index (nonterminal) value (-1 for terminal)
Qubit v{}; // variable index
std::uint8_t flags = 0;
// 32 = marks a node with is symmetric.
// 16 = marks a node resembling identity
Expand All @@ -88,13 +100,10 @@ struct dNode {
// 2 = mark first path edge (tmp flag),
// 1 = mark path is conjugated (tmp flag))

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static dNode terminal;

static constexpr bool isTerminal(const dNode* p) noexcept {
return p == &terminal;
return p == nullptr;
}
static constexpr dNode* getTerminal() noexcept { return &terminal; }
static constexpr dNode* getTerminal() noexcept { return nullptr; }

[[nodiscard]] [[maybe_unused]] static inline bool
tempDensityMatrixFlagsEqual(const std::uint8_t a,
Expand Down Expand Up @@ -192,8 +201,7 @@ static inline dEdge densityFromMatrixEdge(const mEdge& e) {
template <typename Node>
[[nodiscard]] static inline bool
noRefCountingNeeded(const Node* const p) noexcept {
return p == nullptr || Node::isTerminal(p) ||
p->ref == std::numeric_limits<RefCount>::max();
return Node::isTerminal(p) || p->ref == std::numeric_limits<RefCount>::max();
}

/**
Expand Down
Loading