Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Permalink
fixes bug when obs index is set
Browse files Browse the repository at this point in the history
  • Loading branch information
justjhong committed Apr 15, 2024
1 parent ec84c24 commit e272f29
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def per_sample_inference_fn(pair):
mean_zs_,
dims=["cell_name", "sample", "latent_dim"],
coords={
"cell_name": self.adata.obs_names[indices],
"cell_name": self.adata.obs_names[indices].values,
"sample": self.sample_order,
},
name="sample_representations",
Expand All @@ -416,7 +416,7 @@ def per_sample_inference_fn(pair):
sampled_zs_,
dims=["cell_name", "mc_sample", "sample", "latent_dim"],
coords={
"cell_name": self.adata.obs_names[indices],
"cell_name": self.adata.obs_names[indices].values,
"sample": self.sample_order,
},
name="sample_representations",
Expand Down Expand Up @@ -562,7 +562,7 @@ def _compute_distance(rep):
dists,
dims=["cell_name", "sample_x", "sample_y"],
coords={
"cell_name": self.adata.obs_names[indices],
"cell_name": self.adata.obs_names[indices].values,
"sample_x": self.sample_order,
"sample_y": self.sample_order,
},
Expand All @@ -575,7 +575,7 @@ def _compute_distance(rep):
dists,
dims=["cell_name", "mc_sample", "sample_x", "sample_y"],
coords={
"cell_name": self.adata.obs_names[indices],
"cell_name": self.adata.obs_names[indices].values,
"mc_sample": np.arange(reps.shape[1]),
"sample_x": self.sample_order,
"sample_y": self.sample_order,
Expand Down
1 change: 1 addition & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@pytest.fixture
def adata():
adata = synthetic_iid()
adata.obs.index.name = "cell_id"
adata.obs["sample"] = np.random.choice(15, size=adata.shape[0])
adata.obs["sample_str"] = [chr(i + ord("a")) for i in adata.obs["sample"]]
meta1 = np.random.randint(0, 2, size=15)
Expand Down

0 comments on commit e272f29

Please sign in to comment.