Skip to content

Commit

Permalink
Revise docstring of ApproximateNearestNeighbors to be more accurate (#…
Browse files Browse the repository at this point in the history
…793)

* add cosine to ann docstring about metric.

* rephrase the docstring of ANN to be more accurate

* explain default value euclidean of metric

* improved phrasing

* add helpful information on metric restriction for cagra, with a test case

---------

Signed-off-by: Jinfeng <[email protected]>
  • Loading branch information
lijinf2 authored Dec 3, 2024
1 parent fe8c355 commit fdd5494
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
12 changes: 9 additions & 3 deletions python/src/spark_rapids_ml/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,8 @@ class ApproximateNearestNeighbors(
"""
ApproximateNearestNeighbors retrieves k approximate nearest neighbors (ANNs) in item vectors for each query.
The key APIs are similar to the NearestNeighbor class which returns the exact k nearest neighbors.
The ApproximateNearestNeighbors is currently built on the CAGRA (graph-based) algorithm of cuVS, and the IVFFLAT and IVFPQ algorithms of cuML.
The ApproximateNearestNeighbors is currently implemented using cuvs. It supports the IVFFLAT, IVFPQ, and
CAGRA (graph-based) algorithms and follows the API conventions of cuML.
The current implementation build index independently on each data partition of item_df. Queries will be broadcast to all GPUs,
then every query probes closest centers on individual index. Local topk results will be aggregated to obtain global topk ANNs.
Expand Down Expand Up @@ -966,7 +967,10 @@ class ApproximateNearestNeighbors(
Note cuml requires M * n_bits to be multiple of 8 for the best efficiency.
metric: str (default = "euclidean")
the distance metric to use. 'ivfflat' algorithm supports ['euclidean', 'sqeuclidean', 'l2', 'inner_product'].
the distance metric to use with the default set to "euclidean" (following cuml conventions, though cuvs defaults to "sqeuclidean").
The 'ivfflat' and 'ivfpq' algorithms support ['euclidean', 'sqeuclidean', 'l2', 'inner_product', 'cosine'].
The 'cagra' algorithm supports ['sqeuclidean'], and when using 'cagra' as an algorithm,
the metric must be explicitly set to 'sqeuclidean'.
inputCol: str or List[str]
The feature column names, spark-rapids-ml supports vector, array and columnar as the input.\n
Expand Down Expand Up @@ -1283,7 +1287,9 @@ def _agg_topk(
def _cal_cagra_params_and_check(
cls, algoParams: Optional[Dict[str, Any]], metric: str, topk: int
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
assert metric == "sqeuclidean"
assert (
metric == "sqeuclidean"
), "when using 'cagra' algorithm, the metric must be explicitly set to 'sqeuclidean'."

cagra_index_params: Dict[str, Any] = {"metric": metric}
cagra_search_params: Dict[str, Any] = {}
Expand Down
25 changes: 18 additions & 7 deletions python/tests/test_approximate_nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,11 +782,7 @@ def test_cagra(
TODO: support compression index param
"""

VALID_METRIC = {"sqeuclidean"}
VALID_BUILD_ALGO = {"ivf_pq", "nn_descent"}
assert (
metric in VALID_METRIC
), f"cagra currently supports metric only in {VALID_METRIC}."
assert algo_params["build_algo"] in {
"ivf_pq",
"nn_descent",
Expand Down Expand Up @@ -865,7 +861,7 @@ def test_cagra_dtype(


@pytest.mark.parametrize(
"algorithm,feature_type,max_records_per_batch,algo_params,metric",
"algorithm,feature_type,max_records_per_batch,algo_params",
[
(
"cagra",
Expand All @@ -875,7 +871,6 @@ def test_cagra_dtype(
"build_algo": "ivf_pq",
"itopk_size": 32,
},
"sqeuclidean",
),
],
)
Expand All @@ -885,12 +880,12 @@ def test_cagra_params(
feature_type: str,
max_records_per_batch: int,
algo_params: Dict[str, Any],
metric: str,
data_type: np.dtype,
caplog: LogCaptureFixture,
) -> None:

data_shape = (1000, 20)
metric = "sqeuclidean"
itopk_size = 64 if "itopk_size" not in algo_params else algo_params["itopk_size"]

internal_topk_size = math.ceil(itopk_size / 32) * 32
Expand Down Expand Up @@ -928,6 +923,22 @@ def test_cagra_params(
)
assert error_msg in caplog.text

# test metric restriction
algo_params["intermediate_graph_degree"] = 255
metric = "euclidean"
error_msg = f"when using 'cagra' algorithm, the metric must be explicitly set to 'sqeuclidean'."
with pytest.raises(AssertionError, match=error_msg):
test_cagra(
algorithm,
feature_type,
max_records_per_batch,
algo_params,
metric,
data_shape,
data_type,
n_neighbors=n_neighbors,
)


@pytest.mark.parametrize(
"combo",
Expand Down

0 comments on commit fdd5494

Please sign in to comment.