custom near neigh algorithms in benchmarker (#78)
* custom near neigh

* fix

* fix

* make jax array

* update tutorial

* Add changelog

* bump version
adamgayoso authored Feb 3, 2023
1 parent 913e102 commit 39c2a7e
Showing 15 changed files with 340 additions and 77 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.1
current_version = 0.2.0
tag = True
commit = True

28 changes: 28 additions & 0 deletions .pre-commit-config.yaml
@@ -77,3 +77,31 @@ repos:
Fix the merge conflicts manually and remove the .rej files.
language: fail
files: '.*\.rej$'
  - repo: https://github.com/nbQA-dev/nbQA
    rev: 1.6.1
    hooks:
      - id: nbqa-pyupgrade
        args: [--py38-plus]
      - id: nbqa-black
      - id: nbqa-isort
      - id: nbqa-ruff
        args: [--fix]
      - id: nbqa
        entry: nbqa blacken-docs
        name: nbqa-blacken-docs
        alias: nbqa-blacken-docs
        additional_dependencies: [blacken-docs]
        args: [--nbqa-md]
      - id: nbqa
        entry: nbqa mdformat
        name: nbqa-mdformat
        alias: nbqa-mdformat
        additional_dependencies:
          [
            mdformat,
            mdformat-black,
            mdformat-frontmatter,
            mdformat-web,
            mdformat-myst,
          ]
        args: [--nbqa-md]
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## 0.2.0 (2023-02-02)

- Allow custom nearest neighbors methods in Benchmarker ([#78][])

## 0.1.1 (2023-01-04)

- Add new tutorial and fix scalability of lisi ([#71][])
14 changes: 14 additions & 0 deletions docs/api.md
@@ -71,6 +71,20 @@ scib_metrics.ilisi_knn(...)
utils.diffusion_nn
```

### Nearest neighbors

```{eval-rst}
.. module:: scib_metrics.nearest_neighbors
.. currentmodule:: scib_metrics
.. autosummary::
    :toctree: generated

    nearest_neighbors.pynndescent
    nearest_neighbors.jax_approx_min_k
    nearest_neighbors.NeighborsOutput
```

## Settings

An instance of the {class}`~scib_metrics._settings.ScibConfig` is available as `scib_metrics.settings` and allows configuring scib_metrics.
163 changes: 118 additions & 45 deletions docs/notebooks/large_scale.ipynb

Large diffs are not rendered by default.

17 changes: 10 additions & 7 deletions docs/notebooks/lung_example.ipynb
@@ -23,12 +23,10 @@
"outputs": [],
"source": [
"import numpy as np\n",
"from anndata import AnnData\n",
"import matplotlib.pyplot as plt\n",
"import scanpy as sc\n",
"from plottable import Table\n",
"\n",
"from scib_metrics.benchmark import Benchmarker\n",
"\n",
"%matplotlib inline"
]
},
@@ -120,7 +118,7 @@
],
"source": [
"sc.pp.highly_variable_genes(adata, n_top_genes=2000, flavor=\"cell_ranger\", batch_key=\"batch\")\n",
"sc.tl.pca(adata, n_comps=30, use_highly_variable=True)\n"
"sc.tl.pca(adata, n_comps=30, use_highly_variable=True)"
]
},
{
@@ -177,6 +175,7 @@
"source": [
"%%capture\n",
"import scanorama\n",
"\n",
"# List of adata per batch\n",
"batch_cats = adata.obs.batch.cat.categories\n",
"adata_list = [adata[adata.obs.batch == b].copy() for b in batch_cats]\n",
@@ -211,6 +210,7 @@
],
"source": [
"import pyliger\n",
"\n",
"bdata = adata.copy()\n",
"# Pyliger normalizes by library size with a size factor of 1\n",
"# So here we give it the count data\n",
@@ -220,7 +220,7 @@
"for i, ad in enumerate(adata_list):\n",
" ad.uns[\"sample_name\"] = batch_cats[i]\n",
" # Hack to make sure each method uses the same genes\n",
" ad.uns['var_gene_idx'] = np.arange(bdata.n_vars)\n",
" ad.uns[\"var_gene_idx\"] = np.arange(bdata.n_vars)\n",
"\n",
"\n",
"liger_data = pyliger.create_liger(adata_list, remove_missing=False, make_sparse=False)\n",
@@ -270,7 +270,8 @@
],
"source": [
"from harmony import harmonize\n",
"adata.obsm[\"Harmony\"] = harmonize(adata.obsm[\"X_pca\"], adata.obs, batch_key = \"batch\")"
"\n",
"adata.obsm[\"Harmony\"] = harmonize(adata.obsm[\"X_pca\"], adata.obs, batch_key=\"batch\")"
]
},
{
@@ -303,6 +304,7 @@
"source": [
"%%capture\n",
"import scvi\n",
"\n",
"scvi.model.SCVI.setup_anndata(adata, layer=\"counts\", batch_key=\"batch\")\n",
"vae = scvi.model.SCVI(adata, gene_likelihood=\"nb\", n_layers=2, n_latent=30)\n",
"vae.train()\n",
@@ -389,7 +391,7 @@
" embedding_obsm_keys=[\"Unintegrated\", \"Scanorama\", \"LIGER\", \"Harmony\", \"scVI\", \"scANVI\"],\n",
" n_jobs=6,\n",
")\n",
"bm.benchmark()\n"
"bm.benchmark()"
]
},
{
@@ -585,6 +587,7 @@
],
"source": [
"from rich import print\n",
"\n",
"df = bm.get_results(min_max_scale=False)\n",
"print(df)"
]
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ requires = ["hatchling"]

[project]
name = "scib-metrics"
version = "0.1.1"
version = "0.2.0"
description = "Accelerated and Python-only scIB metrics"
readme = "README.md"
requires-python = ">=3.8"
@@ -99,6 +99,11 @@ multi_line_output = 3
profile = "black"
skip_glob = ["docs/*"]

[tool.ruff]
line-length = 88
exclude = [".git","__pycache__","build","docs/","_build","dist"]
ignore = ["E402","E501", "F821", "E741"]

[tool.black]
line-length = 120
target-version = ['py38']
3 changes: 2 additions & 1 deletion src/scib_metrics/__init__.py
@@ -1,7 +1,7 @@
import logging
from importlib.metadata import version

from . import utils
from . import nearest_neighbors, utils
from ._graph_connectivity import graph_connectivity
from ._isolated_labels import isolated_labels
from ._kbet import kbet, kbet_per_label
@@ -13,6 +13,7 @@

__all__ = [
    "utils",
    "nearest_neighbors",
    "isolated_labels",
    "pcr_comparison",
    "silhouette_label",
42 changes: 20 additions & 22 deletions src/scib_metrics/benchmark/_core.py
@@ -3,7 +3,7 @@
from dataclasses import asdict, dataclass
from enum import Enum
from functools import partial
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import matplotlib
import matplotlib.pyplot as plt
@@ -14,11 +14,11 @@
from plottable import ColumnDefinition, Table
from plottable.cmap import normed_cmap
from plottable.plots import bar
from pynndescent import NNDescent
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm

import scib_metrics
from scib_metrics.nearest_neighbors import NeighborsOutput, pynndescent

Kwargs = Dict[str, Any]
MetricType = Union[bool, Kwargs]
@@ -156,8 +156,17 @@ def __init__(
"Batch correction": self._batch_correction_metrics,
}

def prepare(self) -> None:
    """Prepare the data for benchmarking."""
def prepare(self, neighbor_computer: Optional[Callable[[np.ndarray, int], NeighborsOutput]] = None) -> None:
"""Prepare the data for benchmarking.
Parameters
----------
neighbor_computer
Function that computes the neighbors of the data. If `None`, the neighbors will be computed
with :func:`~scib_metrics.utils.nearest_neighbors.pynndescent`. The function should take as input
the data and the number of neighbors to compute and return a :class:`~scib_metrics.utils.nearest_neighbors.NeighborsOutput`
object.
"""
    # Compute PCA
    if self._pre_integrated_embedding_obsm_key is None:
        # This is how scib does it
@@ -173,24 +182,13 @@ def prepare(self) -> None:

    # Compute neighbors
    for ad in tqdm(self._emb_adatas.values(), desc="Computing neighbors"):
        # Variables from umap (https://github.com/lmcinnes/umap/blob/3f19ce19584de4cf99e3d0ae779ba13a57472cd9/umap/umap_.py#LL326-L327)
        # which is used by scanpy under the hood
        n_trees = min(64, 5 + int(round((ad.X.shape[0]) ** 0.5 / 20.0)))
        n_iters = max(5, int(round(np.log2(ad.X.shape[0]))))
        max_candidates = 60

        knn_search_index = NNDescent(
            ad.X,
            n_neighbors=max(self._neighbor_values),
            random_state=0,
            low_memory=True,
            n_jobs=self._n_jobs,
            compressed=False,
            n_trees=n_trees,
            n_iters=n_iters,
            max_candidates=max_candidates,
        )
        indices, distances = knn_search_index.neighbor_graph
        if neighbor_computer is not None:
            neigh_output = neighbor_computer(ad.X, max(self._neighbor_values))
        else:
            neigh_output = pynndescent(
                ad.X, n_neighbors=max(self._neighbor_values), random_state=0, n_jobs=self._n_jobs
            )
        indices, distances = neigh_output.indices, neigh_output.distances
        for n in self._neighbor_values:
            sp_distances, sp_conns = sc.neighbors._compute_connectivities_umap(
                indices[:, :n], distances[:, :n], ad.n_obs, n_neighbors=n
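With this change, `prepare` accepts any `Callable[[np.ndarray, int], NeighborsOutput]`. A minimal sketch of plugging in a custom backend (the toy `AnnData` and the `batch_key`/`label_key` constructor arguments are assumptions based on the package's tutorials, not part of this diff):

```python
import numpy as np
from anndata import AnnData

from scib_metrics.benchmark import Benchmarker
from scib_metrics.nearest_neighbors import NeighborsOutput, jax_approx_min_k


def my_neighbors(X: np.ndarray, n_neighbors: int) -> NeighborsOutput:
    # Delegate to the built-in JAX backend; any search that returns a
    # NeighborsOutput could be substituted here (e.g. a faiss wrapper).
    return jax_approx_min_k(X, n_neighbors=n_neighbors, recall_target=0.95)


# Toy data standing in for a real integration result (hypothetical keys)
rng = np.random.default_rng(0)
adata = AnnData(rng.normal(size=(500, 50)).astype(np.float32))
adata.obs["batch"] = rng.choice(["b1", "b2"], size=adata.n_obs)
adata.obs["cell_type"] = rng.choice(["t1", "t2", "t3"], size=adata.n_obs)
adata.obsm["emb"] = rng.normal(size=(adata.n_obs, 10)).astype(np.float32)

bm = Benchmarker(adata, batch_key="batch", label_key="cell_type", embedding_obsm_keys=["emb"])
bm.prepare(neighbor_computer=my_neighbors)  # None falls back to pynndescent
bm.benchmark()
```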
9 changes: 9 additions & 0 deletions src/scib_metrics/nearest_neighbors/__init__.py
@@ -0,0 +1,9 @@
from ._dataclass import NeighborsOutput
from ._jax import jax_approx_min_k
from ._pynndescent import pynndescent

__all__ = [
    "pynndescent",
    "jax_approx_min_k",
    "NeighborsOutput",
]
19 changes: 19 additions & 0 deletions src/scib_metrics/nearest_neighbors/_dataclass.py
@@ -0,0 +1,19 @@
from dataclasses import dataclass

import numpy as np


@dataclass
class NeighborsOutput:
    """Output of the nearest neighbors function.

    Attributes
    ----------
    indices : np.ndarray
        Array of indices of the nearest neighbors.
    distances : np.ndarray
        Array of distances to the nearest neighbors.
    """

    indices: np.ndarray
    distances: np.ndarray
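Any custom backend only needs to populate this dataclass with aligned arrays of shape `(n_obs, n_neighbors)`. As an illustration, a scikit-learn-based adapter could look like the following sketch (scikit-learn is my own choice here, not part of this commit):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

from scib_metrics.nearest_neighbors import NeighborsOutput


def sklearn_neighbors(X: np.ndarray, n_neighbors: int) -> NeighborsOutput:
    # Exact search; each point comes back as its own first neighbor,
    # matching the self-inclusive convention of NNDescent's neighbor_graph.
    nn = NearestNeighbors(n_neighbors=n_neighbors).fit(X)
    distances, indices = nn.kneighbors(X)  # both (n_obs, n_neighbors)
    return NeighborsOutput(indices=indices, distances=distances)
```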
50 changes: 50 additions & 0 deletions src/scib_metrics/nearest_neighbors/_jax.py
@@ -0,0 +1,50 @@
import functools

import jax
import jax.numpy as jnp
import numpy as np

from scib_metrics.utils import cdist, get_ndarray

from ._dataclass import NeighborsOutput


@functools.partial(jax.jit, static_argnames=["k", "recall_target"])
def _euclidean_ann(qy: jnp.ndarray, db: jnp.ndarray, k: int, recall_target: float = 0.95):
    """Return the approximate k nearest database points to each query point under Euclidean distance."""
    dists = cdist(qy, db)
    return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)


def jax_approx_min_k(
    X: np.ndarray, n_neighbors: int, recall_target: float = 0.95, chunk_size: int = 2048
) -> NeighborsOutput:
    """Run nearest neighbor search using :func:`jax.lax.approx_min_k`.

    On TPU backends the search is approximate; on other backends it is exact.

    Parameters
    ----------
    X
        Data matrix.
    n_neighbors
        Number of neighbors to search for.
    recall_target
        Target recall for approximate nearest neighbor search.
    chunk_size
        Number of query points to search for at once.
    """
    db = jnp.asarray(X)
    # Loop over query points in chunks to bound memory use
    neighbors = []
    dists = []
    for i in range(0, db.shape[0], chunk_size):
        start = i
        end = min(i + chunk_size, db.shape[0])
        qy = db[start:end]
        dist, neighbor = _euclidean_ann(qy, db, k=n_neighbors, recall_target=recall_target)
        neighbors.append(neighbor)
        dists.append(dist)
    neighbors = jnp.concatenate(neighbors, axis=0)
    dists = jnp.concatenate(dists, axis=0)
    return NeighborsOutput(indices=get_ndarray(neighbors), distances=get_ndarray(dists))
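A quick usage sketch of the function above (the random data is illustrative; per the code, both outputs come back as NumPy arrays of shape `(n_obs, n_neighbors)`):

```python
import numpy as np

from scib_metrics.nearest_neighbors import jax_approx_min_k

X = np.random.default_rng(0).normal(size=(10_000, 30)).astype(np.float32)
out = jax_approx_min_k(X, n_neighbors=90, recall_target=0.95, chunk_size=2048)
print(out.indices.shape, out.distances.shape)  # (10000, 90) (10000, 90)
```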
40 changes: 40 additions & 0 deletions src/scib_metrics/nearest_neighbors/_pynndescent.py
@@ -0,0 +1,40 @@
import numpy as np
from pynndescent import NNDescent

from ._dataclass import NeighborsOutput


def pynndescent(X: np.ndarray, n_neighbors: int, random_state: int = 0, n_jobs: int = 1) -> NeighborsOutput:
    """Run pynndescent approximate nearest neighbor search.

    Parameters
    ----------
    X
        Data matrix.
    n_neighbors
        Number of neighbors to search for.
    random_state
        Random state.
    n_jobs
        Number of jobs to use.
    """
    # Variables from umap (https://github.com/lmcinnes/umap/blob/3f19ce19584de4cf99e3d0ae779ba13a57472cd9/umap/umap_.py#LL326-L327)
    # which is used by scanpy under the hood
    n_trees = min(64, 5 + int(round((X.shape[0]) ** 0.5 / 20.0)))
    n_iters = max(5, int(round(np.log2(X.shape[0]))))
    max_candidates = 60

    knn_search_index = NNDescent(
        X,
        n_neighbors=n_neighbors,
        random_state=random_state,
        low_memory=True,
        n_jobs=n_jobs,
        compressed=False,
        n_trees=n_trees,
        n_iters=n_iters,
        max_candidates=max_candidates,
    )
    indices, distances = knn_search_index.neighbor_graph

    return NeighborsOutput(indices=indices, distances=distances)
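The equivalent call through this wrapper; as in `Benchmarker.prepare`, a single search at the largest `k` can then be sliced down for smaller neighborhood sizes:

```python
import numpy as np

from scib_metrics.nearest_neighbors import pynndescent

X = np.random.default_rng(0).normal(size=(5_000, 30)).astype(np.float32)
out = pynndescent(X, n_neighbors=90, random_state=0, n_jobs=4)
# Smaller-k neighbor graphs are prefixes of the largest one, which is
# exactly how Benchmarker.prepare reuses one search across all k values.
indices_15, distances_15 = out.indices[:, :15], out.distances[:, :15]
```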