-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Lower RAM implementation of slice_columns for BRWT #226
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -3,6 +3,10 @@ | |||||
#include <queue> | ||||||
#include <numeric> | ||||||
|
||||||
#include <omp.h> | ||||||
|
||||||
#include <tsl/hopscotch_map.h> | ||||||
|
||||||
#include "common/algorithms.hpp" | ||||||
#include "common/serialization.hpp" | ||||||
|
||||||
|
@@ -189,6 +193,99 @@ std::vector<BRWT::Column> BRWT::slice_rows(const std::vector<Row> &row_ids) cons | |||||
return slice; | ||||||
} | ||||||
|
||||||
// Invokes `callback` once for each requested column of this BRWT node, passing | ||||||
// the column id together with a bitmap for that column. | ||||||
// | ||||||
// A leaf node (no child nodes) hands out a copy of `nonzero_rows_` for every | ||||||
// requested id. An inner node groups the requested ids by child node, recurses | ||||||
// into each child, and maps the child's column index back to this node's column | ||||||
// id with `assignments_.get()` before forwarding the result to `callback`. | ||||||
// | ||||||
// Returns without invoking `callback` when `column_ids` is empty or when | ||||||
// `nonzero_rows_` has no set bits. | ||||||
// | ||||||
// NOTE(review): all groups but the first are processed in OpenMP tasks, so | ||||||
// `callback` may be invoked concurrently when this runs inside a parallel | ||||||
// region - presumably it has to be thread-safe; confirm against the callers. | ||||||
void BRWT::slice_columns(const std::vector<Column> &column_ids, | ||||||
const ColumnCallback &callback) const { | ||||||
if (column_ids.empty()) | ||||||
return; | ||||||
|
||||||
auto num_nonzero_rows = nonzero_rows_->num_set_bits(); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
// check if the column is empty | ||||||
// NOTE(review): no callback is issued on this path, so the caller receives | ||||||
// nothing at all for the requested ids. | ||||||
if (!num_nonzero_rows) | ||||||
return; | ||||||
Comment on lines
+203
to
+205
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Even if they are empty, you still need to call them. Add unit tests? |
||||||
|
||||||
// check whether it is a leaf | ||||||
if (!child_nodes_.size()) { | ||||||
// return the index column | ||||||
// every requested id receives its own freshly made copy of nonzero_rows_ | ||||||
for (size_t k = 0; k < column_ids.size(); ++k) { | ||||||
callback(column_ids[k], std::move(*nonzero_rows_->copy())); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better call a const reference, so the column can be copied by the caller if it's needed, and otherwise, there is no overhead.
Suggested change
Comment on lines
+210
to
+211
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not range-based loop? |
||||||
} | ||||||
|
||||||
return; | ||||||
} | ||||||
|
||||||
// Group the requested ids by child node: the key is assignments_.group(id) | ||||||
// and the stored value is assignments_.rank(id), presumably the column's | ||||||
// index inside that child (it is mapped back with assignments_.get() below). | ||||||
tsl::hopscotch_map<uint32_t, std::vector<Column>> child_columns_map; | ||||||
for (size_t i = 0; i < column_ids.size(); ++i) { | ||||||
assert(column_ids[i] < num_columns()); | ||||||
auto child_node = assignments_.group(column_ids[i]); | ||||||
auto child_column = assignments_.rank(column_ids[i]); | ||||||
|
||||||
auto it = child_columns_map.find(child_node); | ||||||
if (it == child_columns_map.end()) | ||||||
it = child_columns_map.emplace(child_node, std::vector<Column>{}).first; | ||||||
|
||||||
it.value().push_back(child_column); | ||||||
} | ||||||
|
||||||
// Slices the columns of one child node and forwards them to `callback`. | ||||||
auto process = [&](auto child_node, auto *child_columns_ptr) { | ||||||
// num_set_bits() equals size(): the child's bitmap is forwarded unchanged | ||||||
// and only the column index is translated. | ||||||
if (num_nonzero_rows == nonzero_rows_->size()) { | ||||||
child_nodes_[child_node]->slice_columns(*child_columns_ptr, | ||||||
[&](Column j, bitmap&& rows) { | ||||||
callback(assignments_.get(child_node, j), std::move(rows)); | ||||||
} | ||||||
); | ||||||
} else { | ||||||
// the fast path below applies only when the child is itself a BRWT leaf | ||||||
const BRWT *child_node_brwt = dynamic_cast<const BRWT*>( | ||||||
child_nodes_[child_node].get() | ||||||
); | ||||||
if (child_node_brwt | ||||||
&& child_columns_ptr->size() > 1 | ||||||
&& !child_node_brwt->child_nodes_.size()) { | ||||||
// if there are multiple column ids corresponding to the same leaf | ||||||
// node, then this branch avoids doing redundant select1 calls | ||||||
// NOTE(review): the two locals below shadow the outer num_nonzero_rows | ||||||
// and differ from the member nonzero_rows_ only by the trailing | ||||||
// underscore. The select1() inside call_ones() is applied to the child's | ||||||
// vector, whereas the generic branch further down maps indexes through | ||||||
// this node's nonzero_rows_ - confirm which one is intended. | ||||||
// When the child has no set bits, no callback is issued for its columns. | ||||||
const auto *nonzero_rows = child_node_brwt->nonzero_rows_.get(); | ||||||
size_t num_nonzero_rows = nonzero_rows->num_set_bits(); | ||||||
if (num_nonzero_rows) { | ||||||
std::vector<uint64_t> set_bits; | ||||||
set_bits.reserve(num_nonzero_rows); | ||||||
nonzero_rows->call_ones([&](auto i) { | ||||||
set_bits.push_back(nonzero_rows->select1(i + 1)); | ||||||
}); | ||||||
|
||||||
// NOTE(review): set_bits is cast to an rvalue on every iteration and | ||||||
// once more after the loop. If bitmap_generator takes ownership of the | ||||||
// vector, every call after the first sees a moved-from vector - | ||||||
// presumably the loop should pass a copy and only the last call should | ||||||
// move; confirm against the bitmap_generator constructor. | ||||||
for (size_t k = 0; k < child_columns_ptr->size() - 1; ++k) { | ||||||
callback(assignments_.get(child_node, (*child_columns_ptr)[k]), | ||||||
bitmap_generator(std::move(set_bits), num_rows())); | ||||||
} | ||||||
|
||||||
callback(assignments_.get(child_node, child_columns_ptr->back()), | ||||||
bitmap_generator(std::move(set_bits), num_rows())); | ||||||
} | ||||||
} else { | ||||||
Comment on lines
+217
to
+263
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add some comments to explain why this is going to make things faster than the basic call? |
||||||
// Generic path: recurse into the child and wrap its bitmap so that each | ||||||
// index i is translated with nonzero_rows_->select1(i + 1); the wrapper | ||||||
// is given num_rows() as its size and the child's set-bit count. | ||||||
child_nodes_[child_node]->slice_columns(*child_columns_ptr, | ||||||
[&](Column j, bitmap&& rows) { | ||||||
size_t num_set_bits = rows.num_set_bits(); | ||||||
callback(assignments_.get(child_node, j), | ||||||
bitmap_generator(std::move(rows), [&](uint64_t i) { | ||||||
return nonzero_rows_->select1(i + 1); | ||||||
}, num_rows(), num_set_bits)); | ||||||
} | ||||||
); | ||||||
} | ||||||
} | ||||||
}; | ||||||
|
||||||
// The map holds at least one group here (column_ids is non-empty). Every | ||||||
// group except the first is handed to an OpenMP task; the first group is | ||||||
// processed inline below. | ||||||
for (auto it = ++child_columns_map.begin(); it != child_columns_map.end(); ++it) { | ||||||
auto child_node = it->first; | ||||||
auto *child_columns_ptr = &it->second; | ||||||
#pragma omp task firstprivate(child_node, child_columns_ptr) | ||||||
process(child_node, child_columns_ptr); | ||||||
} | ||||||
|
||||||
process(child_columns_map.begin()->first, &child_columns_map.begin()->second); | ||||||
|
||||||
// wait for the tasks spawned above; they refer to locals of this function | ||||||
#pragma omp taskwait | ||||||
} | ||||||
|
||||||
std::vector<BRWT::Row> BRWT::get_column(Column column) const { | ||||||
assert(column < num_columns()); | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -131,6 +131,25 @@ Rainbow<MatrixType>::get_column(Column column) const { | |||||||
return row_indices; | ||||||||
} | ||||||||
|
||||||||
// Invokes `callback` once for every column that `reduced_matrix_.slice_columns()` | ||||||||
// reports for `columns`. Each call receives the column index `j` and a | ||||||||
// bitmap_generator of size num_rows() that emits every row `i` in | ||||||||
// [0, num_rows()) for which `code_column[get_code(i)]` is set. | ||||||||
template <class MatrixType> | ||||||||
void | ||||||||
Rainbow<MatrixType>::slice_columns(const std::vector<Column> &columns, | ||||||||
const ColumnCallback &callback) const { | ||||||||
uint64_t nrows = num_rows(); | ||||||||
// scratch bit vector indexed by row of the reduced matrix; reused for every column | ||||||||
sdsl::bit_vector code_column(reduced_matrix_.num_rows()); | ||||||||
reduced_matrix_.slice_columns(columns, [&](Column j, bitmap&& rows) { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
// clear the scratch vector, then let `rows` write itself into it | ||||||||
// (presumably add_to() sets the bits of `rows` - confirm) | ||||||||
sdsl::util::set_to_value(code_column, false); | ||||||||
rows.add_to(&code_column); | ||||||||
|
||||||||
// NOTE(review): the generator captures code_column by reference. If the | ||||||||
// callback keeps the bitmap_generator beyond this call, or if the reduced | ||||||||
// matrix invokes this lambda concurrently, the generator would read a | ||||||||
// vector that has been reset for another column - confirm. | ||||||||
callback(j, bitmap_generator([&](const auto &index_callback) { | ||||||||
for (uint64_t i = 0; i < nrows; ++i) { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will take forever. Make it parallel
Suggested change
|
||||||||
if (code_column[get_code(i)]) | ||||||||
index_callback(i); | ||||||||
} | ||||||||
}, nrows)); | ||||||||
}); | ||||||||
} | ||||||||
|
||||||||
template <class MatrixType> | ||||||||
bool Rainbow<MatrixType>::load(std::istream &in) { | ||||||||
try { | ||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,6 +40,9 @@ class Rainbow : public RainbowMatrix { | |
size_t num_threads = 1) const override; | ||
std::vector<Row> get_column(Column column) const override; | ||
|
||
// Passes a bitmap for each sliced column, along with its column index, to `callback`. | ||
void slice_columns(const std::vector<Column> &columns, | ||
const ColumnCallback &callback) const override; | ||
Comment on lines
+43
to
+44
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Rename to |
||
|
||
bool load(std::istream &in) override; | ||
void serialize(std::ostream &out) const override; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
call_columns