add a small example #302

Open
wants to merge 2 commits into
base: master
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -24,5 +24,6 @@ endif()
add_executable(test_updates examples/updates_test.cpp)

add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp)
add_executable(mnist mnist.cpp)

target_link_libraries(main sift_test)
2 changes: 1 addition & 1 deletion hnswlib/bruteforce.h
@@ -43,7 +43,7 @@ namespace hnswlib {

std::unordered_map<labeltype,size_t > dict_external_to_internal;

void addPoint(const void *datapoint, labeltype label) {
void addPoint(const void *datapoint, labeltype label, bool bUpdate = false) {

int idx;
{
92 changes: 35 additions & 57 deletions hnswlib/hnswalg.h
@@ -73,7 +73,7 @@ namespace hnswlib {
struct CompareByFirst {
constexpr bool operator()(std::pair<dist_t, tableint> const &a,
std::pair<dist_t, tableint> const &b) const noexcept {
return a.first < b.first;
return a.first > b.first; //let the smaller at the top
Member:

This change seems to alter the logic significantly.
As far as I understand, the main goal is to drop CompareByFirst from the template arguments, but AFAIK that severely affects the performance of the queue operations (I think the default comparator is a bit more complicated, and therefore slower).

Author:

See the std::priority_queue reference:
"
template <class T, class Container = vector<T>, class Compare = less<typename Container::value_type> > class priority_queue;

Template parameters

...

Compare
A binary predicate that takes two elements (of type T) as arguments and returns a bool.
The expression comp(a,b), where comp is an object of this type and a and b are elements in the container, shall return true if a is considered to go before b in the strict weak ordering the function defines.
The priority_queue uses this function to maintain the elements sorted in a way that preserves heap properties (i.e., that the element popped is the last according to this strict weak ordering).
This can be a function pointer or a function object, and defaults to less<T>, which returns the same as applying the less-than operator (a<b).
"

I changed "return a.first < b.first;" (which matches the default ordering) to "return a.first > b.first;" so that the smallest distance sits at the top of the queue and we no longer need to negate distances. This will not "severely affect the performance".
By the way, hnswalg.h declares many variables of type "std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>". If that type affects performance, then using fewer such variables can only improve performance.

Member:

Changing the sign does not affect performance; using the default comparator does (I can mark the code that does that). The slowdown is an empirical fact that I've observed (I probably lost more than a few weeks before I realized that).
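
To make the two orderings in this thread concrete, the following is a minimal sketch rather than code from the PR. `Entry`, `FartherFirst`, `CloserFirst`, `drainNegated`, and `drainCloserFirst` are hypothetical names, with `Entry` standing in for `std::pair<dist_t, tableint>`. It shows that a first-element comparator written with `>` pops the smallest distance first, which is the same order the existing code obtains by pushing negated distances into a queue whose comparator uses `<`. It says nothing about the relative speed of either form.

```cpp
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// (distance, id); assumed stand-in for std::pair<dist_t, tableint>.
using Entry = std::pair<float, unsigned>;

// Comparator using `<` on the distance: the largest distance ends up on top.
struct FartherFirst {
    constexpr bool operator()(Entry const &a, Entry const &b) const noexcept {
        return a.first < b.first;
    }
};

// Comparator using `>` on the distance: the smallest distance ends up on top.
struct CloserFirst {
    constexpr bool operator()(Entry const &a, Entry const &b) const noexcept {
        return a.first > b.first;
    }
};

// Existing approach: store negated distances in a largest-on-top queue.
std::vector<unsigned> drainNegated(const std::vector<Entry> &items) {
    std::priority_queue<Entry, std::vector<Entry>, FartherFirst> q;
    for (const Entry &e : items) {
        assert(e.first >= 0.0f);  // distances are assumed non-negative here
        q.emplace(-e.first, e.second);
    }
    std::vector<unsigned> order;
    while (!q.empty()) {
        order.push_back(q.top().second);
        q.pop();
    }
    return order;
}

// Proposed approach: store plain distances in a smallest-on-top queue.
std::vector<unsigned> drainCloserFirst(const std::vector<Entry> &items) {
    std::priority_queue<Entry, std::vector<Entry>, CloserFirst> q;
    for (const Entry &e : items) {
        q.emplace(e.first, e.second);
    }
    std::vector<unsigned> order;
    while (!q.empty()) {
        order.push_back(q.top().second);
        q.pop();
    }
    return order;
}
```

Both helpers keep an explicit comparator in the queue type, so the sketch isolates the sign question from the default-comparator question raised in the reply above.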

}
};

@@ -159,30 +159,30 @@ namespace hnswlib {
}


std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
std::priority_queue<std::pair<dist_t, tableint>>
searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
std::priority_queue<std::pair<dist_t, tableint>> top_candidates;
Member (@yurymalkov, Apr 6, 2021):

E.g. here removing CompareByFirst made the code significantly slower.
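
One observable difference between the two comparators can be sketched under a few assumptions. A `std::priority_queue` over pairs with no third template argument uses `std::less` on the pair, which orders lexicographically and therefore also inspects `second` when the `first` values are equal, while `CompareByFirst` looks at `first` only. `Entry`, `equivalent`, and `topWithDefaultComparator` are hypothetical names; whether this difference accounts for the slowdown reported here is not something the sketch measures.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// (distance, id); assumed stand-in for std::pair<dist_t, tableint>.
using Entry = std::pair<float, unsigned>;

// Same shape as the comparator in hnswalg.h before this PR: first element only.
struct CompareByFirst {
    constexpr bool operator()(Entry const &a, Entry const &b) const noexcept {
        return a.first < b.first;
    }
};

// True when the comparator orders neither element before the other.
template <typename Compare>
bool equivalent(const Entry &a, const Entry &b, Compare comp = Compare()) {
    return !comp(a, b) && !comp(b, a);
}

// Top of a queue that uses the default comparator, std::less<Entry>.
Entry topWithDefaultComparator(const std::vector<Entry> &items) {
    assert(!items.empty());
    std::priority_queue<Entry> q(items.begin(), items.end());
    return q.top();
}
```

With equal distances, `std::less<Entry>` still distinguishes two entries by id, so the default queue has a fully determined top, whereas `CompareByFirst` treats them as equivalent.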

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;

dist_t lowerBound;
if (!isMarkedDeleted(ep_id)) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
top_candidates.emplace(dist, ep_id);
lowerBound = dist;
candidateSet.emplace(-dist, ep_id);
candidateSet.emplace(dist, ep_id);
} else {
lowerBound = std::numeric_limits<dist_t>::max();
candidateSet.emplace(-lowerBound, ep_id);
candidateSet.emplace(lowerBound, ep_id);
}
visited_array[ep_id] = visited_array_tag;

while (!candidateSet.empty()) {
std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
if ((-curr_el_pair.first) > lowerBound) {
if (curr_el_pair.first > lowerBound) {
break;
}
candidateSet.pop();
@@ -191,13 +191,7 @@ namespace hnswlib {

std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);

int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
if (layer == 0) {
data = (int*)get_linklist0(curNodeNum);
} else {
data = (int*)get_linklist(curNodeNum, layer);
// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
}
int *data = (int*)get_linklist_at_level(curNodeNum, layer);
size_t size = getListCount((linklistsizeint*)data);
tableint *datal = (tableint *) (data + 1);
#ifdef USE_SSE
@@ -220,7 +214,7 @@ namespace hnswlib {

dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
candidateSet.emplace(-dist1, candidate_id);
candidateSet.emplace(dist1, candidate_id);
#ifdef USE_SSE
_mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
#endif
@@ -245,24 +239,24 @@ namespace hnswlib {
mutable std::atomic<long> metric_hops;

template <bool has_deletions, bool collect_metrics=false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
std::priority_queue<std::pair<dist_t, tableint>>
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
std::priority_queue<std::pair<dist_t, tableint>> top_candidates;
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
if (!has_deletions || !isMarkedDeleted(ep_id)) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
candidate_set.emplace(-dist, ep_id);
candidate_set.emplace(dist, ep_id);
} else {
lowerBound = std::numeric_limits<dist_t>::max();
candidate_set.emplace(-lowerBound, ep_id);
candidate_set.emplace(lowerBound, ep_id);
}

visited_array[ep_id] = visited_array_tag;
@@ -271,7 +265,7 @@ namespace hnswlib {

std::pair<dist_t, tableint> current_node_pair = candidate_set.top();

if ((-current_node_pair.first) > lowerBound) {
if (current_node_pair.first > lowerBound) {
break;
}
candidate_set.pop();
@@ -308,7 +302,7 @@ namespace hnswlib {
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);

if (top_candidates.size() < ef || lowerBound > dist) {
candidate_set.emplace(-dist, candidate_id);
candidate_set.emplace(dist, candidate_id);
#ifdef USE_SSE
_mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
offsetLevel0_,///////////
@@ -332,33 +326,31 @@ namespace hnswlib {
return top_candidates;
}

void getNeighborsByHeuristic2(
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
const size_t M) {
void getNeighborsByHeuristic2(std::priority_queue<std::pair<dist_t, tableint>> &top_candidates, const size_t M) {
if (top_candidates.size() < M) {
return;
}

std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> queue_closest;
std::vector<std::pair<dist_t, tableint>> return_list;
while (top_candidates.size() > 0) {
queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
queue_closest.emplace(top_candidates.top().first, top_candidates.top().second);
top_candidates.pop();
}

while (queue_closest.size()) {
if (return_list.size() >= M)
break;
std::pair<dist_t, tableint> curent_pair = queue_closest.top();
dist_t dist_to_query = -curent_pair.first;
dist_t dist_to_query = curent_pair.first;
queue_closest.pop();
bool good = true;

for (std::pair<dist_t, tableint> second_pair : return_list) {
dist_t curdist =
fstdistfunc_(getDataByInternalId(second_pair.second),
getDataByInternalId(curent_pair.second),
dist_func_param_);;
dist_func_param_);
if (curdist < dist_to_query) {
good = false;
break;
@@ -370,7 +362,7 @@ namespace hnswlib {
}

for (std::pair<dist_t, tableint> curent_pair : return_list) {
top_candidates.emplace(-curent_pair.first, curent_pair.second);
top_candidates.emplace(curent_pair.first, curent_pair.second);
}
}

@@ -379,10 +371,6 @@ namespace hnswlib {
return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
};

linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const {
return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_);
};

linklistsizeint *get_linklist(tableint internal_id, int level) const {
return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
};
@@ -392,7 +380,7 @@ namespace hnswlib {
};

tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c,
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
std::priority_queue<std::pair<dist_t, tableint>> &top_candidates,
int level, bool isUpdate) {
size_t Mcurmax = level ? maxM_ : maxM0_;
getNeighborsByHeuristic2(top_candidates, M_);
@@ -409,11 +397,7 @@ namespace hnswlib {
tableint next_closest_entry_point = selectedNeighbors.back();

{
linklistsizeint *ll_cur;
if (level == 0)
ll_cur = get_linklist0(cur_c);
else
ll_cur = get_linklist(cur_c, level);
linklistsizeint *ll_cur = get_linklist_at_level(cur_c, level);

if (*ll_cur && !isUpdate) {
throw std::runtime_error("The newly inserted element should have blank link list");
@@ -435,11 +419,7 @@ namespace hnswlib {

std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);

linklistsizeint *ll_other;
if (level == 0)
ll_other = get_linklist0(selectedNeighbors[idx]);
else
ll_other = get_linklist(selectedNeighbors[idx], level);
linklistsizeint *ll_other = get_linklist_at_level(selectedNeighbors[idx], level);

size_t sz_link_list_other = getListCount(ll_other);

@@ -472,7 +452,7 @@ namespace hnswlib {
dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
dist_func_param_);
// Heuristic:
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
std::priority_queue<std::pair<dist_t, tableint>> candidates;
candidates.emplace(d_max, cur_c);

for (size_t j = 0; j < sz_link_list_other; j++) {
@@ -528,8 +508,7 @@ namespace hnswlib {
bool changed = true;
while (changed) {
changed = false;
int *data;
data = (int *) get_linklist(currObj,level);
int *data = (int *) get_linklist(currObj,level);
int size = getListCount(data);
tableint *datal = (tableint *) (data + 1);
for (int i = 0; i < size; i++) {
@@ -823,8 +802,8 @@ namespace hnswlib {
*((unsigned short int*)(ptr))=*((unsigned short int *)&size);
}

void addPoint(const void *data_point, labeltype label) {
addPoint(data_point, label,-1);
void addPoint(const void *data_point, labeltype label, bool bUpdate = false) {
addPoint(data_point, label, -1, bUpdate);
}

void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) {
@@ -866,7 +845,7 @@ namespace hnswlib {
// if (neigh == internalId)
// continue;

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
std::priority_queue<std::pair<dist_t, tableint>> candidates;
size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1
size_t elementsToKeep = std::min(ef_construction_, size);
for (auto&& cand : sCand) {
@@ -889,8 +868,7 @@ namespace hnswlib {

{
std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
linklistsizeint *ll_cur;
ll_cur = get_linklist_at_level(neigh, layer);
linklistsizeint *ll_cur = get_linklist_at_level(neigh, layer);
size_t candSize = candidates.size();
setListCount(ll_cur, candSize);
tableint *data = (tableint *) (ll_cur + 1);
@@ -941,10 +919,10 @@ namespace hnswlib {
throw std::runtime_error("Level of item to be updated cannot be bigger than max level");

for (int level = dataPointLevel; level >= 0; level--) {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
std::priority_queue<std::pair<dist_t, tableint>> topCandidates = searchBaseLayer(
currObj, dataPoint, level);

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
std::priority_queue<std::pair<dist_t, tableint>> filteredTopCandidates;
while (topCandidates.size() > 0) {
if (topCandidates.top().second != dataPointInternalId)
filteredTopCandidates.push(topCandidates.top());
@@ -977,15 +955,15 @@ namespace hnswlib {
return result;
};

tableint addPoint(const void *data_point, labeltype label, int level) {
tableint addPoint(const void *data_point, labeltype label, int level, bool bUpdate = false) {

tableint cur_c = 0;
{
// Checking if the element with the same label already exists
// if so, updating it *instead* of creating a new element.
std::unique_lock <std::mutex> templock_curr(cur_element_count_guard_);
auto search = label_lookup_.find(label);
if (search != label_lookup_.end()) {
if (bUpdate && search != label_lookup_.end()) {
Member:

I do not understand the bUpdate logic here. Does it add a new element with the same label? What happens to the first element? How should label_lookup_ operate?

Author:

From the sift_1b.cpp example, it seems that you sometimes use 'label' as the unique ID of an element (when building the index) and sometimes as the class of an element (see how the return value of searchKnn is used). I think 'label' should mean class. For example, in MNIST its value is one of (0,1,2,3,4,5,6,7,8,9). The reason for 'bUpdate' is then obvious: we don't want to update an existing class every time we add a new point. We only let this happen when we need it (e.g. during online training). In that case, the data of the closest node is replaced by the newly arriving data.
If this is not your original intention, please disregard this suggestion.

Author (@intstellar, Apr 6, 2021):

But if a label does not mean a class, then where and how do we find the classes of the elements?

Member:

I think the class labels should be stored by the user (in a dict or an array).
From a design point of view, the library user needs a mechanism to extract, delete, or update vectors by key. No other keys are defined in the library, so we ended up using labels as keys. Changing that would require a good alternative that supports those operations.

Author:

I don't understand. Is it hard to add an 'ID' to data_level0_memory_?

Author (@intstellar, Apr 6, 2021):

If the labels are stored in a dict or an array, retrieving them still takes time.

Member:

It is possible, but I think it would break things such as compatibility with old saved indices, since this label would also need to be stored somewhere.
It would also add complexity to the interface: knn results would contain both an id and a class.
Meanwhile, labels for classification can be handled on the user side, in both C++ and Python, via simple lookups.
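
The user-side lookup suggested in this thread could look like the sketch below. It assumes that the labels passed to addPoint are unique ids in the range [0, N) and that the caller keeps a parallel array of classes. `labeltype` is redefined locally as `std::size_t` as an assumption, and `KnnResult`, `majorityClass`, and `classOfLabel` are hypothetical names. The result type mirrors the `searchKnn` signature shown for hnswlib.h in this diff, with `dist_t` taken to be `float`.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <queue>
#include <utility>
#include <vector>

using labeltype = std::size_t;  // assumption: local stand-in for the library's label type
using KnnResult = std::priority_queue<std::pair<float, labeltype>>;

// Majority vote over the classes of the returned neighbours.
// classOfLabel[label] is the class the caller stored for that unique label.
// Ties are resolved in favour of the smallest class id.
int majorityClass(KnnResult neighbours, const std::vector<int> &classOfLabel) {
    assert(!neighbours.empty());
    std::map<int, std::size_t> votes;
    while (!neighbours.empty()) {
        labeltype label = neighbours.top().second;
        assert(label < classOfLabel.size());
        ++votes[classOfLabel[label]];  // one array lookup per neighbour
        neighbours.pop();
    }
    int best = votes.begin()->first;
    for (auto const &kv : votes) {
        if (kv.second > votes[best]) {
            best = kv.first;
        }
    }
    return best;
}
```

The lookup costs one array access per returned neighbour, so for k neighbours it adds O(k) work on top of the search itself.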

tableint existingInternalId = search->second;

templock_curr.unlock();
@@ -1073,7 +1051,7 @@ namespace hnswlib {
if (level > maxlevelcopy || level < 0) // possible?
throw std::runtime_error("Level error");

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
std::priority_queue<std::pair<dist_t, tableint>> top_candidates = searchBaseLayer(
currObj, data_point, level);
if (epDeleted) {
top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
@@ -1134,7 +1112,7 @@ namespace hnswlib {
}
}

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
std::priority_queue<std::pair<dist_t, tableint>> top_candidates;
if (has_deletions_) {
top_candidates=searchBaseLayerST<true,true>(
currObj, query_data, std::max(ef_, k));
2 changes: 1 addition & 1 deletion hnswlib/hnswlib.h
@@ -69,7 +69,7 @@ namespace hnswlib {
template<typename dist_t>
class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label)=0;
virtual void addPoint(const void *datapoint, labeltype label, bool bUpdate = false) = 0;
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;

// Return k nearest neighbor in the order of closer fist
16 changes: 9 additions & 7 deletions hnswlib/visited_list_pool.h
@@ -2,27 +2,30 @@

#include <mutex>
#include <string.h>
#include <limits>

namespace hnswlib {
typedef unsigned short int vl_type;

class VisitedList {
public:
vl_type curV;
vl_type curV, max_vl_type;
vl_type *mass;
unsigned int numelements;

VisitedList(int numelements1) {
curV = -1;
numelements = numelements1;
mass = new vl_type[numelements];
curV = 0;
max_vl_type = std::numeric_limits<vl_type>::max();
mass = new vl_type[numelements]();
}

void reset() {
curV++;
if (curV == 0) {
memset(mass, 0, sizeof(vl_type) * numelements);
if (curV < max_vl_type){
curV++;
} else {
memset(mass, 0, sizeof(vl_type) * numelements);
curV = 1;
}
};

@@ -75,4 +78,3 @@ namespace hnswlib {
};
};
}
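
The tag-based reset in visited_list_pool.h can be illustrated in isolation. The sketch below is a simplified stand-in, not the library class: `VisitedTags`, `visit`, `visited`, and `tag` are hypothetical names, and the tag type is a template parameter so that the wrap-around is easy to trigger with an 8-bit tag. It follows the new `reset()` in this diff: bump the tag while it is below the type's maximum, otherwise clear the array and restart at 1, so that 0 always means "never visited".

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical simplified visited list; TagT plays the role of vl_type.
template <typename TagT>
class VisitedTags {
public:
    // All marks start at 0 and the tag starts at 0; call reset() before each search.
    explicit VisitedTags(std::size_t n) : tag_(0), marks_(n, TagT{0}) {}

    // Start a new search. This is O(1) except when the tag would overflow,
    // in which case every mark is cleared and the tag restarts at 1.
    void reset() {
        if (tag_ < std::numeric_limits<TagT>::max()) {
            ++tag_;
        } else {
            std::fill(marks_.begin(), marks_.end(), TagT{0});
            tag_ = 1;
        }
    }

    void visit(std::size_t i) {
        assert(i < marks_.size());
        marks_[i] = tag_;
    }

    // An element counts as visited only if it carries the current tag.
    bool visited(std::size_t i) const {
        assert(i < marks_.size());
        return marks_[i] == tag_;
    }

    TagT tag() const { return tag_; }

private:
    TagT tag_;
    std::vector<TagT> marks_;
};
```

Clearing on overflow matters because the tag values repeat: without it, a mark written under tag 1 in an early search would look visited again once the tag wraps back to 1.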
