diff --git a/src/mrvi/_model.py b/src/mrvi/_model.py index 955e8d4..5488a7b 100644 --- a/src/mrvi/_model.py +++ b/src/mrvi/_model.py @@ -430,8 +430,7 @@ def per_sample_inference_fn(pair): normalization_means = normalization_means.reshape(-1, 1, 1, 1) normalization_vars = normalization_vars.reshape(-1, 1, 1, 1) normalized_dists = ( - np.clip(sampled_dists - normalization_means, a_min=0, a_max=None) - / (normalization_vars**0.5) + (sampled_dists - normalization_means) / (normalization_vars**0.5) ).mean(dim="mc_sample") # (n_cells, n_samples, n_samples) # Compute each reduction