-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
custom near neigh algorithms in benchmarker (#78)
* custom near neigh * fix * fix * make jax array * update tutorial * Add changelog * bump version
- Loading branch information
1 parent
913e102
commit 39c2a7e
Showing
15 changed files
with
340 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
[bumpversion] | ||
current_version = 0.1.1 | ||
current_version = 0.2.0 | ||
tag = True | ||
commit = True | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from ._dataclass import NeighborsOutput | ||
from ._jax import jax_approx_min_k | ||
from ._pynndescent import pynndescent | ||
|
||
__all__ = [ | ||
"pynndescent", | ||
"jax_approx_min_k", | ||
"NeighborsOutput", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from dataclasses import dataclass | ||
|
||
import numpy as np | ||
|
||
|
||
@dataclass | ||
class NeighborsOutput: | ||
"""Output of the nearest neighbors function. | ||
Attributes | ||
---------- | ||
distances : np.ndarray | ||
Array of distances to the nearest neighbors. | ||
indices : np.ndarray | ||
Array of indices of the nearest neighbors. | ||
""" | ||
|
||
indices: np.ndarray | ||
distances: np.ndarray |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import functools | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
from scib_metrics.utils import cdist, get_ndarray | ||
|
||
from ._dataclass import NeighborsOutput | ||
|
||
|
||
@functools.partial(jax.jit, static_argnames=["k", "recall_target"]) | ||
def _euclidean_ann(qy: jnp.ndarray, db: jnp.ndarray, k: int, recall_target: float = 0.95): | ||
"""Compute half squared L2 distance between query points and database points.""" | ||
dists = cdist(qy, db) | ||
return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target) | ||
|
||
|
||
def jax_approx_min_k( | ||
X: np.ndarray, n_neighbors: int, recall_target: float = 0.95, chunk_size: int = 2048 | ||
) -> NeighborsOutput: | ||
"""Run approximate nearest neighbor search using jax. | ||
On TPU backends, this is approximate nearest neighbor search. On other backends, this is exact nearest neighbor search. | ||
Parameters | ||
---------- | ||
X | ||
Data matrix. | ||
n_neighbors | ||
Number of neighbors to search for. | ||
recall_target | ||
Target recall for approximate nearest neighbor search. | ||
chunk_size | ||
Number of query points to search for at once. | ||
""" | ||
db = jnp.asarray(X) | ||
# Loop over query points in chunks | ||
neighbors = [] | ||
dists = [] | ||
for i in range(0, db.shape[0], chunk_size): | ||
start = i | ||
end = min(i + chunk_size, db.shape[0]) | ||
qy = db[start:end] | ||
dist, neighbor = _euclidean_ann(qy, db, k=n_neighbors, recall_target=recall_target) | ||
neighbors.append(neighbor) | ||
dists.append(dist) | ||
neighbors = jnp.concatenate(neighbors, axis=0) | ||
dists = jnp.concatenate(dists, axis=0) | ||
return NeighborsOutput(indices=get_ndarray(neighbors), distances=get_ndarray(dists)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import numpy as np | ||
from pynndescent import NNDescent | ||
|
||
from ._dataclass import NeighborsOutput | ||
|
||
|
||
def pynndescent(X: np.ndarray, n_neighbors: int, random_state: int = 0, n_jobs: int = 1) -> NeighborsOutput: | ||
"""Run pynndescent approximate nearest neighbor search. | ||
Parameters | ||
---------- | ||
X | ||
Data matrix. | ||
n_neighbors | ||
Number of neighbors to search for. | ||
random_state | ||
Random state. | ||
n_jobs | ||
Number of jobs to use. | ||
""" | ||
# Variables from umap (https://github.com/lmcinnes/umap/blob/3f19ce19584de4cf99e3d0ae779ba13a57472cd9/umap/umap_.py#LL326-L327) | ||
# which is used by scanpy under the hood | ||
n_trees = min(64, 5 + int(round((X.shape[0]) ** 0.5 / 20.0))) | ||
n_iters = max(5, int(round(np.log2(X.shape[0])))) | ||
max_candidates = 60 | ||
|
||
knn_search_index = NNDescent( | ||
X, | ||
n_neighbors=n_neighbors, | ||
random_state=random_state, | ||
low_memory=True, | ||
n_jobs=n_jobs, | ||
compressed=False, | ||
n_trees=n_trees, | ||
n_iters=n_iters, | ||
max_candidates=max_candidates, | ||
) | ||
indices, distances = knn_search_index.neighbor_graph | ||
|
||
return NeighborsOutput(indices=indices, distances=distances) |
Oops, something went wrong.