-
Notifications
You must be signed in to change notification settings - Fork 673
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add a small example #302
base: master
Are you sure you want to change the base?
add a small example #302
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,7 +73,7 @@ namespace hnswlib { | |
struct CompareByFirst { | ||
constexpr bool operator()(std::pair<dist_t, tableint> const &a, | ||
std::pair<dist_t, tableint> const &b) const noexcept { | ||
return a.first < b.first; | ||
return a.first > b.first; //let the smaller at the top | ||
} | ||
}; | ||
|
||
|
@@ -159,30 +159,30 @@ namespace hnswlib { | |
} | ||
|
||
|
||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> | ||
std::priority_queue<std::pair<dist_t, tableint>> | ||
searchBaseLayer(tableint ep_id, const void *data_point, int layer) { | ||
VisitedList *vl = visited_list_pool_->getFreeVisitedList(); | ||
vl_type *visited_array = vl->mass; | ||
vl_type visited_array_tag = vl->curV; | ||
|
||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; | ||
std::priority_queue<std::pair<dist_t, tableint>> top_candidates; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. E.g. here removing |
||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet; | ||
|
||
dist_t lowerBound; | ||
if (!isMarkedDeleted(ep_id)) { | ||
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); | ||
top_candidates.emplace(dist, ep_id); | ||
lowerBound = dist; | ||
candidateSet.emplace(-dist, ep_id); | ||
candidateSet.emplace(dist, ep_id); | ||
} else { | ||
lowerBound = std::numeric_limits<dist_t>::max(); | ||
candidateSet.emplace(-lowerBound, ep_id); | ||
candidateSet.emplace(lowerBound, ep_id); | ||
} | ||
visited_array[ep_id] = visited_array_tag; | ||
|
||
while (!candidateSet.empty()) { | ||
std::pair<dist_t, tableint> curr_el_pair = candidateSet.top(); | ||
if ((-curr_el_pair.first) > lowerBound) { | ||
if (curr_el_pair.first > lowerBound) { | ||
break; | ||
} | ||
candidateSet.pop(); | ||
|
@@ -191,13 +191,7 @@ namespace hnswlib { | |
|
||
std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]); | ||
|
||
int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); | ||
if (layer == 0) { | ||
data = (int*)get_linklist0(curNodeNum); | ||
} else { | ||
data = (int*)get_linklist(curNodeNum, layer); | ||
// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); | ||
} | ||
int *data = (int*)get_linklist_at_level(curNodeNum, layer); | ||
size_t size = getListCount((linklistsizeint*)data); | ||
tableint *datal = (tableint *) (data + 1); | ||
#ifdef USE_SSE | ||
|
@@ -220,7 +214,7 @@ namespace hnswlib { | |
|
||
dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); | ||
if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { | ||
candidateSet.emplace(-dist1, candidate_id); | ||
candidateSet.emplace(dist1, candidate_id); | ||
#ifdef USE_SSE | ||
_mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); | ||
#endif | ||
|
@@ -245,24 +239,24 @@ namespace hnswlib { | |
mutable std::atomic<long> metric_hops; | ||
|
||
template <bool has_deletions, bool collect_metrics=false> | ||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> | ||
std::priority_queue<std::pair<dist_t, tableint>> | ||
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { | ||
VisitedList *vl = visited_list_pool_->getFreeVisitedList(); | ||
vl_type *visited_array = vl->mass; | ||
vl_type visited_array_tag = vl->curV; | ||
|
||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; | ||
std::priority_queue<std::pair<dist_t, tableint>> top_candidates; | ||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set; | ||
|
||
dist_t lowerBound; | ||
if (!has_deletions || !isMarkedDeleted(ep_id)) { | ||
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); | ||
lowerBound = dist; | ||
top_candidates.emplace(dist, ep_id); | ||
candidate_set.emplace(-dist, ep_id); | ||
candidate_set.emplace(dist, ep_id); | ||
} else { | ||
lowerBound = std::numeric_limits<dist_t>::max(); | ||
candidate_set.emplace(-lowerBound, ep_id); | ||
candidate_set.emplace(lowerBound, ep_id); | ||
} | ||
|
||
visited_array[ep_id] = visited_array_tag; | ||
|
@@ -271,7 +265,7 @@ namespace hnswlib { | |
|
||
std::pair<dist_t, tableint> current_node_pair = candidate_set.top(); | ||
|
||
if ((-current_node_pair.first) > lowerBound) { | ||
if (current_node_pair.first > lowerBound) { | ||
break; | ||
} | ||
candidate_set.pop(); | ||
|
@@ -308,7 +302,7 @@ namespace hnswlib { | |
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); | ||
|
||
if (top_candidates.size() < ef || lowerBound > dist) { | ||
candidate_set.emplace(-dist, candidate_id); | ||
candidate_set.emplace(dist, candidate_id); | ||
#ifdef USE_SSE | ||
_mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + | ||
offsetLevel0_,/////////// | ||
|
@@ -332,33 +326,31 @@ namespace hnswlib { | |
return top_candidates; | ||
} | ||
|
||
void getNeighborsByHeuristic2( | ||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates, | ||
const size_t M) { | ||
void getNeighborsByHeuristic2(std::priority_queue<std::pair<dist_t, tableint>> &top_candidates, const size_t M) { | ||
if (top_candidates.size() < M) { | ||
return; | ||
} | ||
|
||
std::priority_queue<std::pair<dist_t, tableint>> queue_closest; | ||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> queue_closest; | ||
std::vector<std::pair<dist_t, tableint>> return_list; | ||
while (top_candidates.size() > 0) { | ||
queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); | ||
queue_closest.emplace(top_candidates.top().first, top_candidates.top().second); | ||
top_candidates.pop(); | ||
} | ||
|
||
while (queue_closest.size()) { | ||
if (return_list.size() >= M) | ||
break; | ||
std::pair<dist_t, tableint> curent_pair = queue_closest.top(); | ||
dist_t dist_to_query = -curent_pair.first; | ||
dist_t dist_to_query = curent_pair.first; | ||
queue_closest.pop(); | ||
bool good = true; | ||
|
||
for (std::pair<dist_t, tableint> second_pair : return_list) { | ||
dist_t curdist = | ||
fstdistfunc_(getDataByInternalId(second_pair.second), | ||
getDataByInternalId(curent_pair.second), | ||
dist_func_param_);; | ||
dist_func_param_); | ||
if (curdist < dist_to_query) { | ||
good = false; | ||
break; | ||
|
@@ -370,7 +362,7 @@ namespace hnswlib { | |
} | ||
|
||
for (std::pair<dist_t, tableint> curent_pair : return_list) { | ||
top_candidates.emplace(-curent_pair.first, curent_pair.second); | ||
top_candidates.emplace(curent_pair.first, curent_pair.second); | ||
} | ||
} | ||
|
||
|
@@ -379,10 +371,6 @@ namespace hnswlib { | |
return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); | ||
}; | ||
|
||
linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { | ||
return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); | ||
}; | ||
|
||
linklistsizeint *get_linklist(tableint internal_id, int level) const { | ||
return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); | ||
}; | ||
|
@@ -392,7 +380,7 @@ namespace hnswlib { | |
}; | ||
|
||
tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c, | ||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates, | ||
std::priority_queue<std::pair<dist_t, tableint>> &top_candidates, | ||
int level, bool isUpdate) { | ||
size_t Mcurmax = level ? maxM_ : maxM0_; | ||
getNeighborsByHeuristic2(top_candidates, M_); | ||
|
@@ -409,11 +397,7 @@ namespace hnswlib { | |
tableint next_closest_entry_point = selectedNeighbors.back(); | ||
|
||
{ | ||
linklistsizeint *ll_cur; | ||
if (level == 0) | ||
ll_cur = get_linklist0(cur_c); | ||
else | ||
ll_cur = get_linklist(cur_c, level); | ||
linklistsizeint *ll_cur = get_linklist_at_level(cur_c, level); | ||
|
||
if (*ll_cur && !isUpdate) { | ||
throw std::runtime_error("The newly inserted element should have blank link list"); | ||
|
@@ -435,11 +419,7 @@ namespace hnswlib { | |
|
||
std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]); | ||
|
||
linklistsizeint *ll_other; | ||
if (level == 0) | ||
ll_other = get_linklist0(selectedNeighbors[idx]); | ||
else | ||
ll_other = get_linklist(selectedNeighbors[idx], level); | ||
linklistsizeint *ll_other = get_linklist_at_level(selectedNeighbors[idx], level); | ||
|
||
size_t sz_link_list_other = getListCount(ll_other); | ||
|
||
|
@@ -472,7 +452,7 @@ namespace hnswlib { | |
dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), | ||
dist_func_param_); | ||
// Heuristic: | ||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates; | ||
std::priority_queue<std::pair<dist_t, tableint>> candidates; | ||
candidates.emplace(d_max, cur_c); | ||
|
||
for (size_t j = 0; j < sz_link_list_other; j++) { | ||
|
@@ -528,8 +508,7 @@ namespace hnswlib { | |
bool changed = true; | ||
while (changed) { | ||
changed = false; | ||
int *data; | ||
data = (int *) get_linklist(currObj,level); | ||
int *data = (int *) get_linklist(currObj,level); | ||
int size = getListCount(data); | ||
tableint *datal = (tableint *) (data + 1); | ||
for (int i = 0; i < size; i++) { | ||
|
@@ -823,8 +802,8 @@ namespace hnswlib { | |
*((unsigned short int*)(ptr))=*((unsigned short int *)&size); | ||
} | ||
|
||
void addPoint(const void *data_point, labeltype label) { | ||
addPoint(data_point, label,-1); | ||
void addPoint(const void *data_point, labeltype label, bool bUpdate = false) { | ||
addPoint(data_point, label, -1, bUpdate); | ||
} | ||
|
||
void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { | ||
|
@@ -866,7 +845,7 @@ namespace hnswlib { | |
// if (neigh == internalId) | ||
// continue; | ||
|
||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates; | ||
std::priority_queue<std::pair<dist_t, tableint>> candidates; | ||
size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 | ||
size_t elementsToKeep = std::min(ef_construction_, size); | ||
for (auto&& cand : sCand) { | ||
|
@@ -889,8 +868,7 @@ namespace hnswlib { | |
|
||
{ | ||
std::unique_lock <std::mutex> lock(link_list_locks_[neigh]); | ||
linklistsizeint *ll_cur; | ||
ll_cur = get_linklist_at_level(neigh, layer); | ||
linklistsizeint *ll_cur = get_linklist_at_level(neigh, layer); | ||
size_t candSize = candidates.size(); | ||
setListCount(ll_cur, candSize); | ||
tableint *data = (tableint *) (ll_cur + 1); | ||
|
@@ -941,10 +919,10 @@ namespace hnswlib { | |
throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); | ||
|
||
for (int level = dataPointLevel; level >= 0; level--) { | ||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer( | ||
std::priority_queue<std::pair<dist_t, tableint>> topCandidates = searchBaseLayer( | ||
currObj, dataPoint, level); | ||
|
||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates; | ||
std::priority_queue<std::pair<dist_t, tableint>> filteredTopCandidates; | ||
while (topCandidates.size() > 0) { | ||
if (topCandidates.top().second != dataPointInternalId) | ||
filteredTopCandidates.push(topCandidates.top()); | ||
|
@@ -977,15 +955,15 @@ namespace hnswlib { | |
return result; | ||
}; | ||
|
||
tableint addPoint(const void *data_point, labeltype label, int level) { | ||
tableint addPoint(const void *data_point, labeltype label, int level, bool bUpdate = false) { | ||
|
||
tableint cur_c = 0; | ||
{ | ||
// Checking if the element with the same label already exists | ||
// if so, updating it *instead* of creating a new element. | ||
std::unique_lock <std::mutex> templock_curr(cur_element_count_guard_); | ||
auto search = label_lookup_.find(label); | ||
if (search != label_lookup_.end()) { | ||
if (bUpdate && search != label_lookup_.end()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not understand the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From sfit_1b.cpp example, it seems that you sometimes use 'label' as the unique ID of elements (when Building index), sometimes use 'label' as the class of elements (see the usage of the return value of searchKnn). I think 'lable' should mean class. For example, in MNIST its value is one of (0,1,2,3,4,5,6,7,8,9). Then, the reason of use 'bUpdate' is obviously. We don't want to update an existing class every time we add a new point. We only let this happen when we need (e.g at online training). At that time, the data of the closest node will be replaced by the new coming data. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But if the meaning of label is not class, then the question is where and how to find the classes of the elements? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the class labels should be stored by the user (in a dict or an array) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand, it is hard to add an 'ID' at data_level0_memory_? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the labels are stored in a dict or an array, it still need spend time to retrieve them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is possible, but I think it will break some stuff like compatibility with old saved indices as this label also needs to be stored somewhere. |
||
tableint existingInternalId = search->second; | ||
|
||
templock_curr.unlock(); | ||
|
@@ -1073,7 +1051,7 @@ namespace hnswlib { | |
if (level > maxlevelcopy || level < 0) // possible? | ||
throw std::runtime_error("Level error"); | ||
|
||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer( | ||
std::priority_queue<std::pair<dist_t, tableint>> top_candidates = searchBaseLayer( | ||
currObj, data_point, level); | ||
if (epDeleted) { | ||
top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); | ||
|
@@ -1134,7 +1112,7 @@ namespace hnswlib { | |
} | ||
} | ||
|
||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; | ||
std::priority_queue<std::pair<dist_t, tableint>> top_candidates; | ||
if (has_deletions_) { | ||
top_candidates=searchBaseLayerST<true,true>( | ||
currObj, query_data, std::max(ef_, k)); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That change seems to change the logic significantly...
As far as I understand the main goal is to omit the CompareByFirst or whatever in the template, but AFAIK that severely affect the performance of the queue operation (I think the default comparator is a bit more complicated, so slower).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see the "std::priority_queue" reference
"
template <class T, class Container = vector, class Compare = less > class priority_queue;
Template parameters
...
Compare
A binary predicate that takes two elements (of type T) as arguments and returns a bool.
The expression comp(a,b), where comp is an object of this type and a and b are elements in the container, shall return true if a is considered to go before b in the strict weak ordering the function defines.
The priority_queue uses this function to maintain the elements sorted in a way that preserves heap properties (i.e., that the element popped is the last according to this strict weak ordering).
This can be a function pointer or a function object, and defaults to less, which returns the same as applying the less-than operator (a<b).
"
I change "return a.first < b.first;" (this should be the default) to "return a.first > b.first;" to sort the smaller distance at the top, so we need not use minus distance. This will not "severely affect the performance".
By the way, at hnswalg.h you define many variables of "std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>". If this will affect the performance, reduce the usage of such variables will only improve the performance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing the sign does not affect the performance, usage of the default comparator does (I can mark the code that does that). Slowness is an empirical fact that I've observed (I probably more lost more than few weeks before I realized that).