diff --git a/nerfacc/losses.py b/nerfacc/losses.py index 6e96396..50c9eef 100644 --- a/nerfacc/losses.py +++ b/nerfacc/losses.py @@ -11,17 +11,17 @@ def distortion( ray_indices: Tensor, n_rays: int, ) -> Tensor: - """Distortion Regularization proposed in Mip-NeRF 360 (on a single GPU). + """Distortion Regularization proposed in Mip-NeRF 360. Args: - weights: [n_samples,] The weights of the samples. - t_starts: [n_samples,] The start points of the samples. - t_ends: [n_samples,] The end points of the samples. - ray_indices: [n_samples,] The ray indices of the samples. + weights: The flattened weights of the samples. Shape (n_samples,) + t_starts: The start points of the samples. Shape (n_samples,) + t_ends: The end points of the samples. Shape (n_samples,) + ray_indices: The ray indices of the samples. LongTensor with shape (n_samples,) n_rays: The total number of rays. Returns: - The per-ray distortion loss with the shape of [n_rays, 1]. + The per-ray distortion loss with the shape (n_rays, 1). """ assert ( weights.shape == t_starts.shape == t_ends.shape == ray_indices.shape