Skip to content

Commit

Permalink
expose chunk size, bump version (#82)
Browse files Browse the repository at this point in the history
* expose chunk size

* add rel note
  • Loading branch information
adamgayoso authored Feb 16, 2023
1 parent 84d3a77 commit 0fd9b70
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## 0.3.1 (2022-02-16)

- Expose chunk size for silhouette ([#82][])

[#82]: https://github.com/YosefLab/scib-metrics/pull/82

## 0.3.0 (2022-02-16)

- Rename `KmeansJax` to `Kmeans` and fix ++ initialization, use Kmeans as default in benchmarker instead of Leiden ([#81][])
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = ["hatchling"]

[project]
name = "scib-metrics"
version = "0.3.0"
version = "0.3.1"
description = "Accelerated and Python-only scIB metrics"
readme = "README.md"
requires-python = ">=3.8"
Expand Down
14 changes: 10 additions & 4 deletions src/scib_metrics/_silhouette.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from scib_metrics.utils import silhouette_samples


def silhouette_label(X: np.ndarray, labels: np.ndarray, rescale: bool = True) -> float:
def silhouette_label(X: np.ndarray, labels: np.ndarray, rescale: bool = True, chunk_size: int = 256) -> float:
"""Average silhouette width (ASW) :cite:p:`luecken2022benchmarking`.
Parameters
Expand All @@ -15,18 +15,22 @@ def silhouette_label(X: np.ndarray, labels: np.ndarray, rescale: bool = True) ->
Array of shape (n_cells,) representing label values
rescale
Scale asw into the range [0, 1].
chunk_size
Size of chunks to process at a time for distance computation.
Returns
-------
silhouette score
"""
asw = np.mean(silhouette_samples(X, labels))
asw = np.mean(silhouette_samples(X, labels, chunk_size=chunk_size))
if rescale:
asw = (asw + 1) / 2
return np.mean(asw)


def silhouette_batch(X: np.ndarray, labels: np.ndarray, batch: np.ndarray, rescale: bool = True) -> float:
def silhouette_batch(
X: np.ndarray, labels: np.ndarray, batch: np.ndarray, rescale: bool = True, chunk_size: int = 256
) -> float:
"""Average silhouette width (ASW) with respect to batch ids within each label :cite:p:`luecken2022benchmarking`.
Parameters
Expand All @@ -39,6 +43,8 @@ def silhouette_batch(X: np.ndarray, labels: np.ndarray, batch: np.ndarray, resca
Array of shape (n_cells,) representing batch values
rescale
Scale asw into the range [0, 1]. If True, higher values are better.
chunk_size
Size of chunks to process at a time for distance computation.
Returns
-------
Expand All @@ -55,7 +61,7 @@ def silhouette_batch(X: np.ndarray, labels: np.ndarray, batch: np.ndarray, resca
if (n_batches == 1) or (n_batches == X_subset.shape[0]):
continue

sil_per_group = silhouette_samples(X_subset, batch_subset)
sil_per_group = silhouette_samples(X_subset, batch_subset, chunk_size=chunk_size)

# take only absolute value
sil_per_group = np.abs(sil_per_group)
Expand Down

0 comments on commit 0fd9b70

Please sign in to comment.