Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Permalink
fixes bug when obs index is set
Browse files Browse the repository at this point in the history
  • Loading branch information
justjhong committed Apr 15, 2024
1 parent ee602c6 commit 1470788
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def per_sample_inference_fn(pair):
mean_zs_,
dims=["cell_name", "sample", "latent_dim"],
coords={
"cell_name": self.adata.obs_names[indices],
"cell_name": self.adata.obs_names[indices].values,
"sample": self.sample_order,
},
name="sample_representations",
Expand All @@ -419,7 +419,7 @@ def per_sample_inference_fn(pair):
sampled_zs_,
dims=["cell_name", "mc_sample", "sample", "latent_dim"],
coords={
"cell_name": self.adata.obs_names[indices],
"cell_name": self.adata.obs_names[indices].values,
"sample": self.sample_order,
},
name="sample_representations",
Expand Down Expand Up @@ -565,7 +565,7 @@ def _compute_distance(rep):
dists,
dims=["cell_name", "sample_x", "sample_y"],
coords={
"cell_name": self.adata.obs_names[indices],
"cell_name": self.adata.obs_names[indices].values,
"sample_x": self.sample_order,
"sample_y": self.sample_order,
},
Expand All @@ -578,7 +578,7 @@ def _compute_distance(rep):
dists,
dims=["cell_name", "mc_sample", "sample_x", "sample_y"],
coords={
"cell_name": self.adata.obs_names[indices],
"cell_name": self.adata.obs_names[indices].values,
"mc_sample": np.arange(reps.shape[1]),
"sample_x": self.sample_order,
"sample_y": self.sample_order,
Expand Down
1 change: 1 addition & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@pytest.fixture
def adata():
adata = synthetic_iid()
adata.obs.index.name = "cell_id"
adata.obs["sample"] = np.random.choice(15, size=adata.shape[0])
meta1 = np.random.randint(0, 2, size=15)
adata.obs["meta1"] = meta1[adata.obs["sample"].values]
Expand Down

0 comments on commit 1470788

Please sign in to comment.