Skip to content

Commit

Permalink
update names
Browse files Browse the repository at this point in the history
  • Loading branch information
Intron7 committed Dec 19, 2024
1 parent f71b304 commit 780a3ad
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
33 changes: 19 additions & 14 deletions src/rapids_singlecell/preprocessing/_harmony/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,20 +195,20 @@ def harmonize(

R, E, O, objectives_harmony = _initialize_centroids(
Z_norm,
n_clusters,
sigma,
Pr_b,
Phi,
n_clusters=n_clusters,
sigma=sigma,
Pr_b=Pr_b,
Phi=Phi,
theta=theta,
random_state=random_state,
)
is_converged = False
for _ in range(max_iter_harmony):
_clustering(
Z_norm,
Pr_b,
Phi,
R,
Pr_b=Pr_b,
Phi=Phi,
R=R,
E=E,
O=O,
theta=theta,
Expand All @@ -220,7 +220,12 @@ def harmonize(
)

Z_hat = _correction(
Z, R, Phi, O, ridge_lambda, correction_method=correction_method
Z,
R=R,
Phi=Phi,
O=O,
ridge_lambda=ridge_lambda,
correction_method=correction_method,
)
Z_norm = _normalize_cp(Z_hat, p=2)
if _is_convergent_harmony(objectives_harmony, tol=tol_harmony):
Expand All @@ -236,11 +241,11 @@ def harmonize(

def _initialize_centroids(
Z_norm: cp.ndarray,
*,
n_clusters: int,
sigma: float,
Pr_b: cp.ndarray,
Phi: cp.ndarray,
*,
theta: cp.ndarray,
random_state: int = 0,
) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, list]:
Expand All @@ -262,7 +267,7 @@ def _initialize_centroids(
_compute_objective(
Y_norm,
Z_norm,
R,
R=R,
theta=theta,
sigma=sigma,
O=O,
Expand All @@ -275,10 +280,10 @@ def _initialize_centroids(

def _clustering(
Z_norm: cp.ndarray,
*,
Pr_b: cp.ndarray,
Phi: cp.ndarray,
R: cp.ndarray,
*,
E: cp.ndarray,
O: cp.ndarray,
theta: cp.ndarray,
Expand Down Expand Up @@ -355,7 +360,7 @@ def _clustering(
_compute_objective(
Y_norm,
Z_norm,
R,
R=R,
theta=theta,
sigma=sigma,
O=O,
Expand All @@ -370,11 +375,11 @@ def _clustering(

def _correction(
X: cp.ndarray,
*,
R: cp.ndarray,
Phi: cp.ndarray,
O: cp.ndarray,
ridge_lambda: float,
*,
correction_method: str = "fast",
) -> cp.ndarray:
if correction_method == "fast":
Expand Down Expand Up @@ -446,8 +451,8 @@ def _correction_fast(
def _compute_objective(
Y_norm: cp.ndarray,
Z_norm: cp.ndarray,
R: cp.ndarray,
*,
R: cp.ndarray,
theta: cp.ndarray,
sigma: float,
O: cp.ndarray,
Expand Down
6 changes: 4 additions & 2 deletions tests/test_harmony.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import pytest
import scanpy as sc

import rapids_singlecell as rsc


def test_harmony_integrate():
@pytest.mark.parametrize("correction_method", ["fast", "original"])
def test_harmony_integrate(correction_method):
"""
Test that Harmony integrate works.
Expand All @@ -17,5 +19,5 @@ def test_harmony_integrate():
sc.pp.recipe_zheng17(adata)
sc.tl.pca(adata)
adata.obs["batch"] = 1350 * ["a"] + 1350 * ["b"]
rsc.pp.harmony_integrate(adata, "batch")
rsc.pp.harmony_integrate(adata, "batch", correction_method=correction_method)
assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape

0 comments on commit 780a3ad

Please sign in to comment.