From 865c283b67323fa283b20d7b015859ac5d1ead90 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 6 Jan 2025 15:58:27 +0100 Subject: [PATCH 1/8] `metatensor` interface --- python/src/sphericart/metatensor/__init__.py | 1 + .../metatensor/spherical_harmonics.py | 236 ++++++++++++++++++ python/tests/test_metatensor.py | 79 ++++++ 3 files changed, 316 insertions(+) create mode 100644 python/src/sphericart/metatensor/__init__.py create mode 100644 python/src/sphericart/metatensor/spherical_harmonics.py create mode 100644 python/tests/test_metatensor.py diff --git a/python/src/sphericart/metatensor/__init__.py b/python/src/sphericart/metatensor/__init__.py new file mode 100644 index 000000000..268dc55d7 --- /dev/null +++ b/python/src/sphericart/metatensor/__init__.py @@ -0,0 +1 @@ +from .spherical_harmonics import SphericalHarmonics, SolidHarmonics # noqa diff --git a/python/src/sphericart/metatensor/spherical_harmonics.py b/python/src/sphericart/metatensor/spherical_harmonics.py new file mode 100644 index 000000000..884fd71a4 --- /dev/null +++ b/python/src/sphericart/metatensor/spherical_harmonics.py @@ -0,0 +1,236 @@ +from typing import List, Optional + +import numpy as np +from metatensor import Labels, TensorBlock, TensorMap + +from ..spherical_harmonics import SolidHarmonics as RawSolidHarmonics +from ..spherical_harmonics import SphericalHarmonics as RawSphericalHarmonics + + +class SphericalHarmonics: + + def __init__(self, l_max: int): + self.l_max = l_max + self.raw_calculator = RawSphericalHarmonics(l_max) + + # precompute some labels + self.precomputed_keys = Labels( + names=["o3_lambda"], + values=np.arange(l_max + 1).reshape(-1, 1), + ) + self.precomputed_mu_components = [ + Labels( + names=["o3_mu"], + values=np.arange(-l, l + 1).reshape(-1, 1), + ) + for l in range(l_max + 1) # noqa E741 + ] + self.precomputed_xyz_components = Labels( + names=["xyz"], + values=np.arange(2).reshape(-1, 1), + ) + self.precomputed_xyz_1_components = Labels( + names=["xyz_1"], + values=np.arange(2).reshape(-1, 1), + ) + self.precomputed_xyz_2_components = Labels( + names=["xyz_2"], + values=np.arange(2).reshape(-1, 1), + ) + self.precomputed_properties = Labels.single() + + def compute(self, xyz: TensorMap) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values = self.raw_calculator.compute(xyz.block().values.squeeze(-1)) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + self.precomputed_properties, + ) + + def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values, sh_gradients = self.raw_calculator.compute_with_gradients( + xyz.block().values.squeeze(-1) + ) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + self.precomputed_properties, + sh_gradients, + ) + + def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values, sh_gradients, sh_hessians = ( + self.raw_calculator.compute_with_hessians(xyz.block().values.squeeze(-1)) + ) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_properties, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + sh_gradients, + sh_hessians, + ) + + +class SolidHarmonics: + + def __init__(self, l_max: int): + self.l_max = l_max + self.raw_calculator = RawSolidHarmonics(l_max) + + # precompute some labels + self.precomputed_keys = Labels( + names=["o3_lambda"], + values=np.arange(l_max + 1).reshape(-1, 1), + ) + self.precomputed_mu_components = [ + Labels( + names=["o3_mu"], + values=np.arange(-l, l + 1).reshape(-1, 1), + ) + for l in range(l_max + 1) # noqa E741 + ] + self.precomputed_xyz_components = Labels( + names=["xyz"], + values=np.arange(2).reshape(-1, 1), + ) + self.precomputed_xyz_1_components = Labels( + names=["xyz_1"], + values=np.arange(2).reshape(-1, 1), + ) + self.precomputed_xyz_2_components = Labels( + names=["xyz_2"], + values=np.arange(2).reshape(-1, 1), + ) + self.precomputed_properties = Labels.single() + + def compute(self, xyz: np.ndarray) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values = self.raw_calculator.compute(xyz.block().values.squeeze(-1)) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + self.precomputed_properties, + ) + + def compute_with_gradients(self, xyz: np.ndarray) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values, sh_gradients = self.raw_calculator.compute_with_gradients( + xyz.block().values.squeeze(-1) + ) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + self.precomputed_properties, + sh_gradients, + ) + + def compute_with_hessians(self, xyz: np.ndarray) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values, sh_gradients, sh_hessians = ( + self.raw_calculator.compute_with_hessians(xyz.block().values.squeeze(-1)) + ) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + self.precomputed_properties, + sh_gradients, + sh_hessians, + ) + + +def _check_xyz_tensor_map(xyz: TensorMap): + if len(xyz.blocks()) != 1: + raise ValueError("`xyz` should have only one block") + if len(xyz.block().components) != 1: + raise ValueError("`xyz` should have only one component") + if xyz.block().components[0].names != ["xyz"]: + raise ValueError("`xyz` should have only one component named 'xyz'") + if xyz.block().components[0].values.shape[0] != 3: + raise ValueError("`xyz` should have 3 Cartesian coordinates") + if xyz.block().properties.values.shape[0] != 1: + raise ValueError("`xyz` should have only one property") + + +def _wrap_into_tensor_map( + sh_values: np.ndarray, + keys: Labels, + samples: Labels, + components: List[Labels], + xyz_components: Labels, + xyz_1_components: Labels, + xyz_2_components: Labels, + properties: Labels, + sh_gradients: Optional[np.ndarray] = None, + sh_hessians: Optional[np.ndarray] = None, +) -> TensorMap: + + # infer l_max + l_max = len(components) - 1 + + blocks = [] + for l in range(l_max + 1): # noqa E741 + l_start = l**2 + l_end = (l + 1) ** 2 + sh_values_block = TensorBlock( + values=sh_values[:, l_start:l_end, None], + samples=samples, + components=[components[l]], + properties=properties, + ) + if sh_gradients is not None: + sh_gradients_block = TensorBlock( + values=sh_gradients[:, :, l_start:l_end, None], + samples=samples, + components=[components[l], xyz_components], + properties=properties, + ) + if sh_hessians is not None: + sh_hessians_block = TensorBlock( + values=sh_hessians[:, :, :, l_start:l_end, None], + samples=samples, + components=[ + components[l], + xyz_1_components, + xyz_2_components, + ], + properties=properties, + ) + sh_gradients_block.add_gradient("positions", sh_hessians_block) + sh_values_block.add_gradient("positions", sh_gradients_block) + + blocks.append(sh_values_block) + + return TensorMap(keys=keys, blocks=blocks) diff --git a/python/tests/test_metatensor.py b/python/tests/test_metatensor.py new file mode 100644 index 000000000..796e116c0 --- /dev/null +++ b/python/tests/test_metatensor.py @@ -0,0 +1,79 @@ +import numpy as np +import pytest +from metatensor import Labels, TensorBlock, TensorMap + +import sphericart +import sphericart.metatensor + + +L_MAX = 15 +N_SAMPLES = 100 + + +@pytest.fixture +def xyz(): + np.random.seed(0) + return TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=np.random.rand(N_SAMPLES, 3, 1), + samples=Labels( + names=["sample"], + values=np.arange(N_SAMPLES).reshape(-1, 1), + ), + components=[ + Labels( + names=["xyz"], + values=np.arange(3).reshape(-1, 1), + ) + ], + properties=Labels.single(), + ) + ], + ) + + +def test_metatensor(xyz): + for l in range(L_MAX + 1): # noqa E741 + calculator_spherical = sphericart.metatensor.SphericalHarmonics(l) + calculator_solid = sphericart.metatensor.SolidHarmonics(l) + + spherical = calculator_spherical.compute(xyz) + solid = calculator_solid.compute(xyz) + + assert spherical.keys == Labels( + names=["o3_lambda"], + values=np.arange(l + 1).reshape(-1, 1), + ) + for single_l in range(l + 1): # noqa E741 + spherical_block = spherical.block({"o3_lambda": single_l}) + solid_block = solid.block({"o3_lambda": single_l}) + + # check samples + assert spherical_block.samples == xyz.block().samples + + # check components + assert spherical_block.components == [ + Labels( + names=["o3_mu"], + values=np.arange(-single_l, single_l + 1).reshape(-1, 1), + ) + ] + + # check properties + assert spherical_block.properties == Labels.single() + + # check values + assert np.allclose( + spherical_block.values.squeeze(-1), + sphericart.SphericalHarmonics(single_l).compute( + xyz.block().values.squeeze(-1) + )[:, single_l**2 : (single_l + 1) ** 2], + ) + assert np.allclose( + solid_block.values.squeeze(-1), + sphericart.SolidHarmonics(l).compute(xyz.block().values.squeeze(-1))[ + :, single_l**2 : (single_l + 1) ** 2 + ], + ) From ccd82ef48ecaf31985e07a58032dc945f484b79d Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 6 Jan 2025 16:16:27 +0100 Subject: [PATCH 2/8] `metatensor.torch` for `sphericart.torch` --- .../metatensor/spherical_harmonics.py | 19 +- python/tests/test_metatensor.py | 1 - .../sphericart/torch/metatensor/__init__.py | 188 ++++++++++++++++++ .../python/tests/test_metatensor.py | 78 ++++++++ tox.ini | 3 +- 5 files changed, 282 insertions(+), 7 deletions(-) create mode 100644 sphericart-torch/python/sphericart/torch/metatensor/__init__.py create mode 100644 sphericart-torch/python/tests/test_metatensor.py diff --git a/python/src/sphericart/metatensor/spherical_harmonics.py b/python/src/sphericart/metatensor/spherical_harmonics.py index 884fd71a4..c2af2341f 100644 --- a/python/src/sphericart/metatensor/spherical_harmonics.py +++ b/python/src/sphericart/metatensor/spherical_harmonics.py @@ -1,12 +1,20 @@ from typing import List, Optional import numpy as np -from metatensor import Labels, TensorBlock, TensorMap from ..spherical_harmonics import SolidHarmonics as RawSolidHarmonics from ..spherical_harmonics import SphericalHarmonics as RawSphericalHarmonics +try: + import metatensor + from metatensor import Labels, TensorMap +except ImportError as e: + raise ImportError( + "the `sphericart.metatensor` module requires `metatensor` to be installed" + ) from e + + class SphericalHarmonics: def __init__(self, l_max: int): @@ -195,6 +203,7 @@ def _wrap_into_tensor_map( properties: Labels, sh_gradients: Optional[np.ndarray] = None, sh_hessians: Optional[np.ndarray] = None, + metatensor_module=metatensor, # can be replaced with metatensor.torch ) -> TensorMap: # infer l_max @@ -204,21 +213,21 @@ def _wrap_into_tensor_map( for l in range(l_max + 1): # noqa E741 l_start = l**2 l_end = (l + 1) ** 2 - sh_values_block = TensorBlock( + sh_values_block = metatensor_module.TensorBlock( values=sh_values[:, l_start:l_end, None], samples=samples, components=[components[l]], properties=properties, ) if sh_gradients is not None: - sh_gradients_block = TensorBlock( + sh_gradients_block = metatensor_module.TensorBlock( values=sh_gradients[:, :, l_start:l_end, None], samples=samples, components=[components[l], xyz_components], properties=properties, ) if sh_hessians is not None: - sh_hessians_block = TensorBlock( + sh_hessians_block = metatensor_module.TensorBlock( values=sh_hessians[:, :, :, l_start:l_end, None], samples=samples, components=[ @@ -233,4 +242,4 @@ def _wrap_into_tensor_map( blocks.append(sh_values_block) - return TensorMap(keys=keys, blocks=blocks) + return metatensor_module.TensorMap(keys=keys, blocks=blocks) diff --git a/python/tests/test_metatensor.py b/python/tests/test_metatensor.py index 796e116c0..09939c47e 100644 --- a/python/tests/test_metatensor.py +++ b/python/tests/test_metatensor.py @@ -12,7 +12,6 @@ @pytest.fixture def xyz(): - np.random.seed(0) return TensorMap( keys=Labels.single(), blocks=[ diff --git a/sphericart-torch/python/sphericart/torch/metatensor/__init__.py b/sphericart-torch/python/sphericart/torch/metatensor/__init__.py new file mode 100644 index 000000000..b5cfa61c2 --- /dev/null +++ b/sphericart-torch/python/sphericart/torch/metatensor/__init__.py @@ -0,0 +1,188 @@ +import torch + +from .. import SolidHarmonics as RawSolidHarmonics +from .. import SphericalHarmonics as RawSphericalHarmonics +from sphericart.metatensor.spherical_harmonics import ( + _check_xyz_tensor_map, + _wrap_into_tensor_map, +) + + +try: + import metatensor.torch + from metatensor.torch import Labels, TensorMap +except ImportError as e: + raise ImportError( + "the `sphericart.torch.metatensor` module requires " + "`metatensor-torch` to be installed" + ) from e + + +class SphericalHarmonics: + + def __init__(self, l_max: int): + self.l_max = l_max + self.raw_calculator = RawSphericalHarmonics(l_max) + + # precompute some labels + self.precomputed_keys = Labels( + names=["o3_lambda"], + values=torch.arange(l_max + 1).reshape(-1, 1), + ) + self.precomputed_mu_components = [ + Labels( + names=["o3_mu"], + values=torch.arange(-l, l + 1).reshape(-1, 1), + ) + for l in range(l_max + 1) # noqa E741 + ] + self.precomputed_xyz_components = Labels( + names=["xyz"], + values=torch.arange(2).reshape(-1, 1), + ) + self.precomputed_xyz_1_components = Labels( + names=["xyz_1"], + values=torch.arange(2).reshape(-1, 1), + ) + self.precomputed_xyz_2_components = Labels( + names=["xyz_2"], + values=torch.arange(2).reshape(-1, 1), + ) + self.precomputed_properties = Labels.single() + + def compute(self, xyz: TensorMap) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values = self.raw_calculator.compute(xyz.block().values.squeeze(-1)) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + self.precomputed_properties, + metatensor_module=metatensor.torch, + ) + + def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values, sh_gradients = self.raw_calculator.compute_with_gradients( + xyz.block().values.squeeze(-1) + ) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + self.precomputed_properties, + sh_gradients, + metatensor_module=metatensor.torch, + ) + + def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values, sh_gradients, sh_hessians = ( + self.raw_calculator.compute_with_hessians(xyz.block().values.squeeze(-1)) + ) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_properties, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + sh_gradients, + sh_hessians, + metatensor_module=metatensor.torch, + ) + + +class SolidHarmonics: + + def __init__(self, l_max: int): + self.l_max = l_max + self.raw_calculator = RawSolidHarmonics(l_max) + + # precompute some labels + self.precomputed_keys = Labels( + names=["o3_lambda"], + values=torch.arange(l_max + 1).reshape(-1, 1), + ) + self.precomputed_mu_components = [ + Labels( + names=["o3_mu"], + values=torch.arange(-l, l + 1).reshape(-1, 1), + ) + for l in range(l_max + 1) # noqa E741 + ] + self.precomputed_xyz_components = Labels( + names=["xyz"], + values=torch.arange(2).reshape(-1, 1), + ) + self.precomputed_xyz_1_components = Labels( + names=["xyz_1"], + values=torch.arange(2).reshape(-1, 1), + ) + self.precomputed_xyz_2_components = Labels( + names=["xyz_2"], + values=torch.arange(2).reshape(-1, 1), + ) + self.precomputed_properties = Labels.single() + + def compute(self, xyz: torch.Tensor) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values = self.raw_calculator.compute(xyz.block().values.squeeze(-1)) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + self.precomputed_properties, + metatensor_module=metatensor.torch, + ) + + def compute_with_gradients(self, xyz: torch.Tensor) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values, sh_gradients = self.raw_calculator.compute_with_gradients( + xyz.block().values.squeeze(-1) + ) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + self.precomputed_properties, + sh_gradients, + metatensor_module=metatensor.torch, + ) + + def compute_with_hessians(self, xyz: torch.Tensor) -> TensorMap: + _check_xyz_tensor_map(xyz) + sh_values, sh_gradients, sh_hessians = ( + self.raw_calculator.compute_with_hessians(xyz.block().values.squeeze(-1)) + ) + return _wrap_into_tensor_map( + sh_values, + self.precomputed_keys, + xyz.block().samples, + self.precomputed_mu_components, + self.precomputed_xyz_components, + self.precomputed_xyz_1_components, + self.precomputed_xyz_2_components, + self.precomputed_properties, + sh_gradients, + sh_hessians, + metatensor_module=metatensor.torch, + ) diff --git a/sphericart-torch/python/tests/test_metatensor.py b/sphericart-torch/python/tests/test_metatensor.py new file mode 100644 index 000000000..8ec6b06e3 --- /dev/null +++ b/sphericart-torch/python/tests/test_metatensor.py @@ -0,0 +1,78 @@ +import pytest +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + +import sphericart.torch +import sphericart.torch.metatensor + + +L_MAX = 15 +N_SAMPLES = 100 + + +@pytest.fixture +def xyz(): + return TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=torch.rand(N_SAMPLES, 3, 1), + samples=Labels( + names=["sample"], + values=torch.arange(N_SAMPLES).reshape(-1, 1), + ), + components=[ + Labels( + names=["xyz"], + values=torch.arange(3).reshape(-1, 1), + ) + ], + properties=Labels.single(), + ) + ], + ) + + +def test_metatensor(xyz): + for l in range(L_MAX + 1): # noqa E741 + calculator_spherical = sphericart.torch.metatensor.SphericalHarmonics(l) + calculator_solid = sphericart.torch.metatensor.SolidHarmonics(l) + + spherical = calculator_spherical.compute(xyz) + solid = calculator_solid.compute(xyz) + + assert spherical.keys == Labels( + names=["o3_lambda"], + values=torch.arange(l + 1).reshape(-1, 1), + ) + for single_l in range(l + 1): # noqa E741 + spherical_block = spherical.block({"o3_lambda": single_l}) + solid_block = solid.block({"o3_lambda": single_l}) + + # check samples + assert spherical_block.samples == xyz.block().samples + + # check components + assert spherical_block.components == [ + Labels( + names=["o3_mu"], + values=torch.arange(-single_l, single_l + 1).reshape(-1, 1), + ) + ] + + # check properties + assert spherical_block.properties == Labels.single() + + # check values + assert torch.allclose( + spherical_block.values.squeeze(-1), + sphericart.torch.SphericalHarmonics(single_l).compute( + xyz.block().values.squeeze(-1) + )[:, single_l**2 : (single_l + 1) ** 2], + ) + assert torch.allclose( + solid_block.values.squeeze(-1), + sphericart.torch.SolidHarmonics(l).compute( + xyz.block().values.squeeze(-1) + )[:, single_l**2 : (single_l + 1) ** 2], + ) diff --git a/tox.ini b/tox.ini index aa83f3f9e..481790fc9 100644 --- a/tox.ini +++ b/tox.ini @@ -33,6 +33,7 @@ deps = numpy<2.0.0 scipy pytest + metatensor commands = pip install {[testenv]pip_install_flags} . @@ -49,8 +50,8 @@ deps = numpy<2.0.0 torch pytest - e3nn + metatensor-torch changedir = sphericart-torch passenv= From d9d6f0373a128aaf609408150ca266f99d327340 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 6 Jan 2025 17:02:50 +0100 Subject: [PATCH 3/8] Add example --- docs/requirements.txt | 3 ++ docs/src/api.rst | 1 + docs/src/conf.py | 1 + docs/src/examples.rst | 1 + docs/src/jax-api.rst | 2 +- docs/src/metatensor-api.rst | 28 ++++++++++ docs/src/metatensor-examples.rst | 13 +++++ examples/metatensor/example.py | 54 +++++++++++++++++++ python/src/sphericart/__init__.py | 1 + .../metatensor/spherical_harmonics.py | 31 +++-------- .../python/sphericart/torch/__init__.py | 2 + .../sphericart/torch/metatensor/__init__.py | 22 ++------ tox.ini | 2 + 13 files changed, 119 insertions(+), 42 deletions(-) create mode 100644 docs/src/metatensor-api.rst create mode 100644 docs/src/metatensor-examples.rst create mode 100644 examples/metatensor/example.py diff --git a/docs/requirements.txt b/docs/requirements.txt index ed2606ae1..e4ec3cd2f 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,3 +7,6 @@ furo # sphinx theme # jax[cpu], because python -m pip install jax, which would be triggered # by the main package's dependencies, does not install jaxlib jax[cpu] >= 0.4.18 + +# metatensor and metatensor-torch for the metatensor API +metatensor-torch diff --git a/docs/src/api.rst b/docs/src/api.rst index cd1c40a05..42673b126 100644 --- a/docs/src/api.rst +++ b/docs/src/api.rst @@ -15,6 +15,7 @@ different languages and frameworks it supports. python-api pytorch-api jax-api + metatensor-api Although the Julia API is not fully documented yet, basic usage examples are available `here `_. diff --git a/docs/src/conf.py b/docs/src/conf.py index ed4ce7c69..d0c2bb356 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -57,6 +57,7 @@ "python": ("https://docs.python.org/3", None), "torch": ("https://pytorch.org/docs/stable/", None), "e3nn": ("https://docs.e3nn.org/en/latest/", None), + "metatensor": ("https://docs.metatensor.org/latest/index.html", None), } html_theme = "furo" diff --git a/docs/src/examples.rst b/docs/src/examples.rst index cf30f858c..34d4192e2 100644 --- a/docs/src/examples.rst +++ b/docs/src/examples.rst @@ -22,6 +22,7 @@ floating-point arithmetics, and they evaluate the mean relative error between th pytorch-examples jax-examples spherical-complex + metatensor-examples Although comprehensive Julia examples are not fully available yet, basic usage is illustrated `here `_. diff --git a/docs/src/jax-api.rst b/docs/src/jax-api.rst index c7876e86a..45cd1d173 100644 --- a/docs/src/jax-api.rst +++ b/docs/src/jax-api.rst @@ -1,5 +1,5 @@ JAX API -=========== +======= The `sphericart.jax` module aims to provide a functional-style and `JAX`-friendly framework. As a result, it does not follow the same syntax as diff --git a/docs/src/metatensor-api.rst b/docs/src/metatensor-api.rst new file mode 100644 index 000000000..39b3fd00d --- /dev/null +++ b/docs/src/metatensor-api.rst @@ -0,0 +1,28 @@ +Metatensor API +============== + +``sphericart`` can be used in conjunction with +`metatensor `_ in order to attach +metadata to inputs and outputs, as well as to naturally obtain spherical harmonics, +gradients and Hessians in a single object. + +Here is the API reference for the ``sphericart.metatensor`` and +``sphericart.metatensor.torch`` modules. + +sphericart.metatensor +--------------------- + +.. autoclass:: sphericart.metatensor.SphericalHarmonics + :members: + +.. autoclass:: sphericart.metatensor.SolidHarmonics + :members: + +sphericart.metatensor.torch +--------------------------- + +.. autoclass:: sphericart.metatensor.torch.SphericalHarmonics + :members: + +.. autoclass:: sphericart.metatensor.torch.SolidHarmonics + :members: diff --git a/docs/src/metatensor-examples.rst b/docs/src/metatensor-examples.rst new file mode 100644 index 000000000..0e448a55a --- /dev/null +++ b/docs/src/metatensor-examples.rst @@ -0,0 +1,13 @@ +Using sphericart with metatensor +-------------------------------- + +``sphericart`` can be used in conjunction with +`metatensor `_ in order to attach +metadata to inputs and outputs, as well as to naturally obtain spherical harmonics, +gradients and Hessians in a single object. + +This example shows how to use the ``sphericart.metatensor`` module to compute +spherical harmonics, their gradients and their Hessians. + +.. literalinclude:: ../../examples/metatensor/example.py + :language: python diff --git a/examples/metatensor/example.py b/examples/metatensor/example.py new file mode 100644 index 000000000..37d5f65bc --- /dev/null +++ b/examples/metatensor/example.py @@ -0,0 +1,54 @@ +import numpy as np +from metatensor import Labels, TensorBlock, TensorMap + +import sphericart +import sphericart.metatensor + + +l_max = 15 +n_samples = 100 + +xyz = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=np.random.rand(n_samples, 3, 1), + samples=Labels( + names=["sample"], + values=np.arange(n_samples).reshape(-1, 1), + ), + components=[ + Labels( + names=["xyz"], + values=np.arange(3).reshape(-1, 1), + ) + ], + properties=Labels.single(), + ) + ], + ) + +calculator = sphericart.metatensor.SphericalHarmonics(l_max) + +spherical_harmonics = calculator.compute(xyz) + +for single_l in range(l_max + 1): + spherical_single_l = spherical_harmonics.block({"o3_lambda": single_l}) + + # check values against pure sphericart + assert np.allclose( + spherical_single_l.values.squeeze(-1), + sphericart.SphericalHarmonics(single_l).compute( + xyz.block().values.squeeze(-1) + )[:, single_l**2 : (single_l + 1) ** 2], + ) + +# further example: obtaining gradients of l = 2 spherical harmonics +spherical_harmonics = calculator.compute_with_gradients(xyz) +l_2_gradients = spherical_harmonics.block({"o3_lambda": 2}).gradient("positions") + +# further example: obtaining Hessians of l = 2 spherical harmonics +spherical_harmonics = calculator.compute_with_hessians(xyz) +l_2_hessians = spherical_harmonics.block( + {"o3_lambda": 2} +).gradient("positions").gradient("positions") diff --git a/python/src/sphericart/__init__.py b/python/src/sphericart/__init__.py index 268dc55d7..555405c44 100644 --- a/python/src/sphericart/__init__.py +++ b/python/src/sphericart/__init__.py @@ -1 +1,2 @@ from .spherical_harmonics import SphericalHarmonics, SolidHarmonics # noqa +from . import metatensor # noqa diff --git a/python/src/sphericart/metatensor/spherical_harmonics.py b/python/src/sphericart/metatensor/spherical_harmonics.py index c2af2341f..264d34770 100644 --- a/python/src/sphericart/metatensor/spherical_harmonics.py +++ b/python/src/sphericart/metatensor/spherical_harmonics.py @@ -35,15 +35,11 @@ def __init__(self, l_max: int): ] self.precomputed_xyz_components = Labels( names=["xyz"], - values=np.arange(2).reshape(-1, 1), - ) - self.precomputed_xyz_1_components = Labels( - names=["xyz_1"], - values=np.arange(2).reshape(-1, 1), + values=np.arange(3).reshape(-1, 1), ) self.precomputed_xyz_2_components = Labels( names=["xyz_2"], - values=np.arange(2).reshape(-1, 1), + values=np.arange(3).reshape(-1, 1), ) self.precomputed_properties = Labels.single() @@ -56,7 +52,6 @@ def compute(self, xyz: TensorMap) -> TensorMap: xyz.block().samples, self.precomputed_mu_components, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, self.precomputed_properties, ) @@ -72,7 +67,6 @@ def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: xyz.block().samples, self.precomputed_mu_components, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, self.precomputed_properties, sh_gradients, @@ -88,10 +82,9 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: self.precomputed_keys, xyz.block().samples, self.precomputed_mu_components, - self.precomputed_properties, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, + self.precomputed_properties, sh_gradients, sh_hessians, ) @@ -117,15 +110,11 @@ def __init__(self, l_max: int): ] self.precomputed_xyz_components = Labels( names=["xyz"], - values=np.arange(2).reshape(-1, 1), - ) - self.precomputed_xyz_1_components = Labels( - names=["xyz_1"], - values=np.arange(2).reshape(-1, 1), + values=np.arange(3).reshape(-1, 1), ) self.precomputed_xyz_2_components = Labels( names=["xyz_2"], - values=np.arange(2).reshape(-1, 1), + values=np.arange(3).reshape(-1, 1), ) self.precomputed_properties = Labels.single() @@ -138,7 +127,6 @@ def compute(self, xyz: np.ndarray) -> TensorMap: xyz.block().samples, self.precomputed_mu_components, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, self.precomputed_properties, ) @@ -154,7 +142,6 @@ def compute_with_gradients(self, xyz: np.ndarray) -> TensorMap: xyz.block().samples, self.precomputed_mu_components, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, self.precomputed_properties, sh_gradients, @@ -171,7 +158,6 @@ def compute_with_hessians(self, xyz: np.ndarray) -> TensorMap: xyz.block().samples, self.precomputed_mu_components, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, self.precomputed_properties, sh_gradients, @@ -198,7 +184,6 @@ def _wrap_into_tensor_map( samples: Labels, components: List[Labels], xyz_components: Labels, - xyz_1_components: Labels, xyz_2_components: Labels, properties: Labels, sh_gradients: Optional[np.ndarray] = None, @@ -223,7 +208,7 @@ def _wrap_into_tensor_map( sh_gradients_block = metatensor_module.TensorBlock( values=sh_gradients[:, :, l_start:l_end, None], samples=samples, - components=[components[l], xyz_components], + components=[xyz_components, components[l]], properties=properties, ) if sh_hessians is not None: @@ -231,9 +216,9 @@ def _wrap_into_tensor_map( values=sh_hessians[:, :, :, l_start:l_end, None], samples=samples, components=[ - components[l], - xyz_1_components, xyz_2_components, + xyz_components, + components[l], ], properties=properties, ) diff --git a/sphericart-torch/python/sphericart/torch/__init__.py b/sphericart-torch/python/sphericart/torch/__init__.py index f72070a9b..dc6f37c3a 100644 --- a/sphericart-torch/python/sphericart/torch/__init__.py +++ b/sphericart-torch/python/sphericart/torch/__init__.py @@ -10,6 +10,8 @@ from ._build_torch_version import BUILD_TORCH_VERSION import re +from . import metatensor # noqa + def parse_version_string(version_string): match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string) diff --git a/sphericart-torch/python/sphericart/torch/metatensor/__init__.py b/sphericart-torch/python/sphericart/torch/metatensor/__init__.py index b5cfa61c2..d2cb6338c 100644 --- a/sphericart-torch/python/sphericart/torch/metatensor/__init__.py +++ b/sphericart-torch/python/sphericart/torch/metatensor/__init__.py @@ -38,15 +38,11 @@ def __init__(self, l_max: int): ] self.precomputed_xyz_components = Labels( names=["xyz"], - values=torch.arange(2).reshape(-1, 1), - ) - self.precomputed_xyz_1_components = Labels( - names=["xyz_1"], - values=torch.arange(2).reshape(-1, 1), + values=torch.arange(3).reshape(-1, 1), ) self.precomputed_xyz_2_components = Labels( names=["xyz_2"], - values=torch.arange(2).reshape(-1, 1), + values=torch.arange(3).reshape(-1, 1), ) self.precomputed_properties = Labels.single() @@ -59,7 +55,6 @@ def compute(self, xyz: TensorMap) -> TensorMap: xyz.block().samples, self.precomputed_mu_components, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, self.precomputed_properties, metatensor_module=metatensor.torch, @@ -76,7 +71,6 @@ def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: xyz.block().samples, self.precomputed_mu_components, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, self.precomputed_properties, sh_gradients, @@ -95,7 +89,6 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: self.precomputed_mu_components, self.precomputed_properties, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, sh_gradients, sh_hessians, @@ -123,15 +116,11 @@ def __init__(self, l_max: int): ] self.precomputed_xyz_components = Labels( names=["xyz"], - values=torch.arange(2).reshape(-1, 1), - ) - self.precomputed_xyz_1_components = Labels( - names=["xyz_1"], - values=torch.arange(2).reshape(-1, 1), + values=torch.arange(3).reshape(-1, 1), ) self.precomputed_xyz_2_components = Labels( names=["xyz_2"], - values=torch.arange(2).reshape(-1, 1), + values=torch.arange(3).reshape(-1, 1), ) self.precomputed_properties = Labels.single() @@ -144,7 +133,6 @@ def compute(self, xyz: torch.Tensor) -> TensorMap: xyz.block().samples, self.precomputed_mu_components, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, self.precomputed_properties, metatensor_module=metatensor.torch, @@ -161,7 +149,6 @@ def compute_with_gradients(self, xyz: torch.Tensor) -> TensorMap: xyz.block().samples, self.precomputed_mu_components, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, self.precomputed_properties, sh_gradients, @@ -179,7 +166,6 @@ def compute_with_hessians(self, xyz: torch.Tensor) -> TensorMap: xyz.block().samples, self.precomputed_mu_components, self.precomputed_xyz_components, - self.precomputed_xyz_1_components, self.precomputed_xyz_2_components, self.precomputed_properties, sh_gradients, diff --git a/tox.ini b/tox.ini index 481790fc9..bdff46256 100644 --- a/tox.ini +++ b/tox.ini @@ -101,6 +101,7 @@ deps = numpy<2.0.0 torch pytest + metatensor passenv= PIP_EXTRA_INDEX_URL @@ -116,6 +117,7 @@ commands = python examples/python/example.py python examples/pytorch/example.py python examples/jax/example.py + python examples/metatensor/example.py python examples/python/spherical.py python examples/python/complex.py From 6d017c1070783da7b45cdf4ac8446dd7b827a349 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 6 Jan 2025 22:26:22 +0100 Subject: [PATCH 4/8] Fix circular imports --- docs/src/conf.py | 2 +- docs/src/metatensor-api.rst | 8 +- python/src/sphericart/__init__.py | 1 - .../spherical_harmonics.py => metatensor.py} | 13 +- python/src/sphericart/metatensor/__init__.py | 1 - .../python/sphericart/torch/__init__.py | 368 +----------------- .../python/sphericart/torch/e3nn.py | 122 ++++++ .../{metatensor/__init__.py => metatensor.py} | 96 ++++- .../sphericart/torch/spherical_hamonics.py | 245 ++++++++++++ 9 files changed, 465 insertions(+), 391 deletions(-) rename python/src/sphericart/{metatensor/spherical_harmonics.py => metatensor.py} (94%) delete mode 100644 python/src/sphericart/metatensor/__init__.py create mode 100644 sphericart-torch/python/sphericart/torch/e3nn.py rename sphericart-torch/python/sphericart/torch/{metatensor/__init__.py => metatensor.py} (64%) create mode 100644 sphericart-torch/python/sphericart/torch/spherical_hamonics.py diff --git a/docs/src/conf.py b/docs/src/conf.py index d0c2bb356..7b01c9e31 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -57,7 +57,7 @@ "python": ("https://docs.python.org/3", None), "torch": ("https://pytorch.org/docs/stable/", None), "e3nn": ("https://docs.e3nn.org/en/latest/", None), - "metatensor": ("https://docs.metatensor.org/latest/index.html", None), + "metatensor": ("https://docs.metatensor.org/latest/", None), } html_theme = "furo" diff --git a/docs/src/metatensor-api.rst b/docs/src/metatensor-api.rst index 39b3fd00d..9229ad664 100644 --- a/docs/src/metatensor-api.rst +++ b/docs/src/metatensor-api.rst @@ -7,7 +7,7 @@ metadata to inputs and outputs, as well as to naturally obtain spherical harmoni gradients and Hessians in a single object. Here is the API reference for the ``sphericart.metatensor`` and -``sphericart.metatensor.torch`` modules. +``sphericart.torch.metatensor`` modules. sphericart.metatensor --------------------- @@ -18,11 +18,11 @@ sphericart.metatensor .. autoclass:: sphericart.metatensor.SolidHarmonics :members: -sphericart.metatensor.torch +sphericart.torch.metatensor --------------------------- -.. autoclass:: sphericart.metatensor.torch.SphericalHarmonics +.. autoclass:: sphericart.torch.metatensor.SphericalHarmonics :members: -.. autoclass:: sphericart.metatensor.torch.SolidHarmonics +.. autoclass:: sphericart.torch.metatensor.SolidHarmonics :members: diff --git a/python/src/sphericart/__init__.py b/python/src/sphericart/__init__.py index 555405c44..268dc55d7 100644 --- a/python/src/sphericart/__init__.py +++ b/python/src/sphericart/__init__.py @@ -1,2 +1 @@ from .spherical_harmonics import SphericalHarmonics, SolidHarmonics # noqa -from . import metatensor # noqa diff --git a/python/src/sphericart/metatensor/spherical_harmonics.py b/python/src/sphericart/metatensor.py similarity index 94% rename from python/src/sphericart/metatensor/spherical_harmonics.py rename to python/src/sphericart/metatensor.py index 264d34770..6665e83cb 100644 --- a/python/src/sphericart/metatensor/spherical_harmonics.py +++ b/python/src/sphericart/metatensor.py @@ -2,8 +2,8 @@ import numpy as np -from ..spherical_harmonics import SolidHarmonics as RawSolidHarmonics -from ..spherical_harmonics import SphericalHarmonics as RawSphericalHarmonics +from .spherical_harmonics import SolidHarmonics as RawSolidHarmonics +from .spherical_harmonics import SphericalHarmonics as RawSphericalHarmonics try: @@ -188,7 +188,6 @@ def _wrap_into_tensor_map( properties: Labels, sh_gradients: Optional[np.ndarray] = None, sh_hessians: Optional[np.ndarray] = None, - metatensor_module=metatensor, # can be replaced with metatensor.torch ) -> TensorMap: # infer l_max @@ -198,21 +197,21 @@ def _wrap_into_tensor_map( for l in range(l_max + 1): # noqa E741 l_start = l**2 l_end = (l + 1) ** 2 - sh_values_block = metatensor_module.TensorBlock( + sh_values_block = metatensor.TensorBlock( values=sh_values[:, l_start:l_end, None], samples=samples, components=[components[l]], properties=properties, ) if sh_gradients is not None: - sh_gradients_block = metatensor_module.TensorBlock( + sh_gradients_block = metatensor.TensorBlock( values=sh_gradients[:, :, l_start:l_end, None], samples=samples, components=[xyz_components, components[l]], properties=properties, ) if sh_hessians is not None: - sh_hessians_block = metatensor_module.TensorBlock( + sh_hessians_block = metatensor.TensorBlock( values=sh_hessians[:, :, :, l_start:l_end, None], samples=samples, components=[ @@ -227,4 +226,4 @@ def _wrap_into_tensor_map( blocks.append(sh_values_block) - return metatensor_module.TensorMap(keys=keys, blocks=blocks) + return metatensor.TensorMap(keys=keys, blocks=blocks) diff --git a/python/src/sphericart/metatensor/__init__.py b/python/src/sphericart/metatensor/__init__.py deleted file mode 100644 index 268dc55d7..000000000 --- a/python/src/sphericart/metatensor/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .spherical_harmonics import SphericalHarmonics, SolidHarmonics # noqa diff --git a/sphericart-torch/python/sphericart/torch/__init__.py b/sphericart-torch/python/sphericart/torch/__init__.py index dc6f37c3a..498c0f31c 100644 --- a/sphericart-torch/python/sphericart/torch/__init__.py +++ b/sphericart-torch/python/sphericart/torch/__init__.py @@ -1,16 +1,13 @@ -import math import os import sys -from types import ModuleType -from typing import List, Optional, Union, Tuple import torch -from torch import Tensor from ._build_torch_version import BUILD_TORCH_VERSION import re -from . import metatensor # noqa +from .spherical_hamonics import SphericalHarmonics, SolidHarmonics # noqa: F401 +from .e3nn import patch_e3nn, unpatch_e3nn, e3nn_spherical_harmonics # noqa: F401 def parse_version_string(version_string): @@ -64,364 +61,3 @@ def _lib_path(): # load the C++ operators and custom classes torch.classes.load_library(_lib_path()) - - -# This is a workaround to provide docstrings for the SphericalHarmonics class, -# even though it is defined as a C++ TorchScript object (and we can not figure -# out a way to extract docstrings for either classes or methods from the C++ -# code). The class reproduces the API of the TorchScript class, but has empty -# functions. Instead, when __new__ is called, an instance of the TorchScript -# class is directly returned. -class SphericalHarmonics(torch.nn.Module): - """ - Spherical harmonics calculator, which computes the real spherical harmonics - :math:`Y^m_l` up to degree ``l_max``. The calculated spherical harmonics - are consistent with the definition of real spherical harmonics from Wikipedia. - - This class can be used similarly to :py:class:`sphericart.SphericalHarmonics` - (its Python/NumPy counterpart). If the class is called directly, the outputs - support single and double backpropagation. - - >>> xyz = xyz.detach().clone().requires_grad_() - >>> sh = sphericart.torch.SphericalHarmonics(l_max=8) - >>> sh_values = sh(xyz) # or sh.compute(xyz) - >>> sh_values.sum().backward() - >>> torch.allclose(xyz.grad, sh_grads.sum(axis=-1)) - True - - By default, only single backpropagation with respect to ``xyz`` is - enabled (this includes mixed second derivatives where ``xyz`` appears - as only one of the differentiation steps). To activate support - for double backpropagation with respect to ``xyz``, please set - ``backward_second_derivatives=True`` at class creation. Warning: if - ``backward_second_derivatives`` is not set to ``True`` and double - differentiation with respect to ``xyz`` is requested, the results may - be incorrect, but a warning will be displayed. This is necessary to - provide optimal performance for both use cases. In particular, the - following will happen: - - - when using ``torch.autograd.grad`` as the second backpropagation - step, a warning will be displayed and torch will raise an error. - - when using ``torch.autograd.grad`` with ``allow_unused=True`` as - the second backpropagation step, the results will be incorrect - and only a warning will be displayed. - - when using ``backward`` as the second backpropagation step, the - results will be incorrect and only a warning will be displayed. - - when using ``torch.autograd.functional.hessian``, the results will - be incorrect and only a warning will be displayed. - - Alternatively, the class allows to return explicit forward gradients and/or - Hessians of the spherical harmonics. For example: - - >>> import torch - >>> import sphericart.torch - >>> sh = sphericart.torch.SphericalHarmonics(l_max=8) - >>> xyz = torch.rand(size=(10,3)) - >>> sh_values, sh_grads = sh.compute_with_gradients(xyz) - >>> sh_grads.shape - torch.Size([10, 3, 81]) - - This class supports TorchScript. - - :param l_max: - the maximum degree of the spherical harmonics to be calculated - :param backward_second_derivatives: - if this parameter is set to ``True``, second derivatives of the spherical - harmonics are calculated and stored during forward calls to ``compute`` - (provided that ``xyz.requires_grad`` is ``True``), making it possible to perform - double reverse-mode differentiation with respect to ``xyz``. If ``False``, only - the first derivatives will be computed and only a single reverse-mode - differentiation step will be possible with respect to ``xyz``. - - :return: a calculator, in the form of a SphericalHarmonics object - """ - - def __init__( - self, - l_max: int, - backward_second_derivatives: bool = False, - ): - super().__init__() - self.calculator = torch.classes.sphericart_torch.SphericalHarmonics( - l_max, backward_second_derivatives - ) - - def forward(self, xyz: Tensor) -> Tensor: - """ - Calculates the spherical harmonics for a set of 3D points. - - The coordinates should be stored in the ``xyz`` array. If ``xyz`` - has ``requires_grad = True`` it stores the forward derivatives which - are then used in the backward pass. - The type of the entries of ``xyz`` determines the precision used, - and the device the tensor is stored on determines whether the - CPU or CUDA implementation is used for the calculation backend. - It always supports single reverse-mode differentiation, as well as - double reverse-mode differentiation if ``backward_second_derivatives`` - was set to ``True`` during class creation. - - :param xyz: - The Cartesian coordinates of the 3D points, as a `torch.Tensor` with - shape ``(n_samples, 3)``. - - :return: - A tensor of shape ``(n_samples, (l_max+1)**2)`` containing all the - spherical harmonics up to degree `l_max` in lexicographic order. - For example, if ``l_max = 2``, The last axis will correspond to - spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1, - 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order. - """ - return self.calculator.compute(xyz) - - def compute(self, xyz: Tensor) -> Tensor: - """Equivalent to ``forward``""" - return self.calculator.compute(xyz) - - def compute_with_gradients(self, xyz: Tensor) -> Tuple[Tensor, Tensor]: - """ - Calculates the spherical harmonics for a set of 3D points, - and also returns the forward-mode derivatives. - - The coordinates should be stored in the ``xyz`` array. - The type of the entries of ``xyz`` determines the precision used, - and the device the tensor is stored on determines whether the - CPU or CUDA implementation is used for the calculation backend. - Reverse-mode differentiation is not supported for this function. - - :param xyz: - The Cartesian coordinates of the 3D points, as a `torch.Tensor` with - shape ``(n_samples, 3)``. - - :return: - A tuple that contains: - - * A ``(n_samples, (l_max+1)**2)`` tensor containing all the - spherical harmonics up to degree ``l_max`` in lexicographic order. - For example, if ``l_max = 2``, The last axis will correspond to - spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1, - 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order. - * A tensor of shape ``(n_samples, 3, (l_max+1)**2)`` containing all - the spherical harmonics' derivatives up to degree ``l_max``. The - last axis is organized in the same way as in the spherical - harmonics return array, while the second-to-last axis refers to - derivatives in the the x, y, and z directions, respectively. - - """ - return self.calculator.compute_with_gradients(xyz) - - def compute_with_hessians(self, xyz: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - """ - Calculates the spherical harmonics for a set of 3D points, - and also returns the forward derivatives and second derivatives. - - The coordinates should be stored in the ``xyz`` array. - The type of the entries of ``xyz`` determines the precision used, - and the device the tensor is stored on determines whether the - CPU or CUDA implementation is used for the calculation backend. - Reverse-mode differentiation is not supported for this function. - - :param xyz: - The Cartesian coordinates of the 3D points, as a ``torch.Tensor`` with - shape ``(n_samples, 3)``. - - :return: - A tuple that contains: - - * A ``(n_samples, (l_max+1)**2)`` tensor containing all the - spherical harmonics up to degree ``l_max`` in lexicographic order. - For example, if ``l_max = 2``, The last axis will correspond to - spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1, - 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order. - * A tensor of shape ``(n_samples, 3, (l_max+1)**2)`` containing all - the spherical harmonics' derivatives up to degree ``l_max``. The - last axis is organized in the same way as in the spherical - harmonics return array, while the second-to-last axis refers to - derivatives in the the x, y, and z directions, respectively. - * A tensor of shape ``(n_samples, 3, 3, (l_max+1)**2)`` containing all - the spherical harmonics' second derivatives up to degree ``l_max``. The - last axis is organized in the same way as in the spherical - harmonics return array, while the two intermediate axes represent the - hessian dimensions. - - """ - return self.calculator.compute_with_hessians(xyz) - - def omp_num_threads(self): - """Returns the number of threads available for calculations on the CPU.""" - return self.calculator.omp_num_threads() - - def l_max(self): - """Returns the maximum angular momentum setting for this calculator.""" - return self.calculator.l_max() - - -class SolidHarmonics(torch.nn.Module): - """ - Solid harmonics calculator, up to degree ``l_max``. - - This class computes the solid harmonics, a non-normalized form of the real - spherical harmonics, i.e. :math:`r^lY^m_l`. These scaled spherical harmonics - are polynomials in the Cartesian coordinates of the input points. - - The usage of this class is identical to :py:class:`sphericart.SphericalHarmonics`. - - :param l_max: - the maximum degree of the spherical harmonics to be calculated - :param backward_second_derivatives: - if this parameter is set to ``True``, second derivatives of the spherical - harmonics are calculated and stored during forward calls to ``compute`` - (provided that ``xyz.requires_grad`` is ``True``), making it possible to perform - double reverse-mode differentiation with respect to ``xyz``. If ``False``, only - the first derivatives will be computed and only a single reverse-mode - differentiation step will be possible with respect to ``xyz``. - - :return: a calculator, in the form of a SolidHarmonics object - """ - - def __init__( - self, - l_max: int, - backward_second_derivatives: bool = False, - ): - super().__init__() - self.calculator = torch.classes.sphericart_torch.SolidHarmonics( - l_max, backward_second_derivatives - ) - - def forward(self, xyz: Tensor) -> Tensor: - """See :py:meth:`SphericalHarmonics.forward`""" - return self.calculator.compute(xyz) - - def compute(self, xyz: Tensor) -> Tensor: - """Equivalent to ``forward``""" - return self.calculator.compute(xyz) - - def compute_with_gradients(self, xyz: Tensor) -> Tuple[Tensor, Tensor]: - """See :py:meth:`SphericalHarmonics.compute_with_gradients`""" - return self.calculator.compute_with_gradients(xyz) - - def compute_with_hessians(self, xyz: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - """See :py:meth:`SphericalHarmonics.compute_with_hessians`""" - return self.calculator.compute_with_hessians(xyz) - - def omp_num_threads(self): - """Returns the number of threads available for calculations on the CPU.""" - return self.calculator.omp_num_threads() - - def l_max(self): - """Returns the maximum angular momentum setting for this calculator.""" - return self.calculator.l_max() - - -def e3nn_spherical_harmonics( - l_list: Union[List[int], int], - x: Tensor, - normalize: Optional[bool] = False, - normalization: Optional[str] = "integral", -) -> torch.Tensor: - """ - Computes spherical harmonics with an interface similar to the e3nn package. - - Provides an interface that is similar to :py:func:`e3nn.o3.spherical_harmonics` - but uses :py:class:`SphericalHarmonics` for the actual calculation. - Uses the same ordering of the [x,y,z] axes, and supports the same options for - input and harmonics normalization as :py:mod:`e3nn`. However, it does not support - defining the irreps through a :py:class:`e3nn.o3._irreps.Irreps` or a string - specification, but just as a single integer or a list of integers. - - :param l_list: - Either a single integer or a list of integers specifying which - :math:`Y^m_l` should be computed. All values up to the maximum - l value are computed, so this may be inefficient for use cases - requiring a single, or few, angular momentum channels. - :param x: - A ``torch.Tensor`` containing the coordinates, in the same format - expected by the ``e3nn`` function. - :param normalize: - Flag specifying whether the input positions should be normalized - (resulting in the computation of the spherical harmonics :math:`Y^m_l`), - or whether the function should compute the solid harmonics - :math:`r^lY^m_l`. - :param normalization: - String that can be "integral", "norm", "component", that controls - a further scaling of the :math:`Y^m_l`. See the - documentation of :py:func:`e3nn.o3.spherical_harmonics()` - for a detailed explanation of the different conventions. - """ - - if not hasattr(l_list, "__len__"): - l_list = [l_list] - l_max = max(l_list) - is_range_lmax = list(l_list) == list(range(l_max + 1)) - - if normalize: - sh = SphericalHarmonics(l_max)( - torch.index_select( - x, 1, torch.tensor([2, 0, 1], dtype=torch.long, device=x.device) - ) - ) - else: - sh = SolidHarmonics(l_max)( - torch.index_select( - x, 1, torch.tensor([2, 0, 1], dtype=torch.long, device=x.device) - ) - ) - assert normalization in ["integral", "norm", "component"] - if normalization != "integral": - sh *= math.sqrt(4 * math.pi) - - if not is_range_lmax: - sh_list = [] - for l in l_list: # noqa E741 - shl = sh[:, l * l : (l + 1) * (l + 1)] - if normalization == "norm": - shl *= math.sqrt(1 / (2 * l + 1)) - sh_list.append(shl) - sh = torch.cat(sh_list, dim=-1) - elif normalization == "norm": - for l in l_list: # noqa E741 - sh[:, l * l : (l + 1) * (l + 1)] *= math.sqrt(1 / (2 * l + 1)) - - return sh - - -_E3NN_SPH = None - - -def patch_e3nn(e3nn_module: ModuleType) -> None: - """Patches the :py:mod:`e3nn` module so that - :py:func:`sphericart_torch.e3nn_spherical_harmonics` - is called in lieu of the built-in function. - - :param e3nn_module: - The alias that has been chosen for the e3nn module, - usually just ``e3nn``. - """ - - global _E3NN_SPH - if _E3NN_SPH is not None: - raise RuntimeError("It appears that e3nn has already been patched") - - _E3NN_SPH = e3nn_module.o3.spherical_harmonics - e3nn_module.o3.spherical_harmonics = e3nn_spherical_harmonics - - -def unpatch_e3nn(e3nn_module: ModuleType) -> None: - """Restore the original ``spherical_harmonics`` function - in the :py:mod:`e3nn` module.""" - - global _E3NN_SPH - if _E3NN_SPH is None: - raise RuntimeError("It appears that e3nn has not been patched") - - e3nn_module.o3.spherical_harmonics = _E3NN_SPH - _E3NN_SPH = None - - -__all__ = [ - "SphericalHarmonics", - "SolidHarmonics", - "e3nn_spherical_harmonics", - "patch_e3nn", - "unpatch_e3nn", -] diff --git a/sphericart-torch/python/sphericart/torch/e3nn.py b/sphericart-torch/python/sphericart/torch/e3nn.py new file mode 100644 index 000000000..4811d1425 --- /dev/null +++ b/sphericart-torch/python/sphericart/torch/e3nn.py @@ -0,0 +1,122 @@ +import math +from types import ModuleType +from typing import List, Optional, Union + +import torch +from torch import Tensor + +from .spherical_hamonics import SolidHarmonics, SphericalHarmonics + + +def e3nn_spherical_harmonics( + l_list: Union[List[int], int], + x: Tensor, + normalize: Optional[bool] = False, + normalization: Optional[str] = "integral", +) -> torch.Tensor: + """ + Computes spherical harmonics with an interface similar to the e3nn package. + + Provides an interface that is similar to :py:func:`e3nn.o3.spherical_harmonics` + but uses :py:class:`SphericalHarmonics` for the actual calculation. + Uses the same ordering of the [x,y,z] axes, and supports the same options for + input and harmonics normalization as :py:mod:`e3nn`. However, it does not support + defining the irreps through a :py:class:`e3nn.o3._irreps.Irreps` or a string + specification, but just as a single integer or a list of integers. + + :param l_list: + Either a single integer or a list of integers specifying which + :math:`Y^m_l` should be computed. All values up to the maximum + l value are computed, so this may be inefficient for use cases + requiring a single, or few, angular momentum channels. + :param x: + A ``torch.Tensor`` containing the coordinates, in the same format + expected by the ``e3nn`` function. + :param normalize: + Flag specifying whether the input positions should be normalized + (resulting in the computation of the spherical harmonics :math:`Y^m_l`), + or whether the function should compute the solid harmonics + :math:`r^lY^m_l`. + :param normalization: + String that can be "integral", "norm", "component", that controls + a further scaling of the :math:`Y^m_l`. See the + documentation of :py:func:`e3nn.o3.spherical_harmonics()` + for a detailed explanation of the different conventions. + """ + + if not hasattr(l_list, "__len__"): + l_list = [l_list] + l_max = max(l_list) + is_range_lmax = list(l_list) == list(range(l_max + 1)) + + if normalize: + sh = SphericalHarmonics(l_max)( + torch.index_select( + x, 1, torch.tensor([2, 0, 1], dtype=torch.long, device=x.device) + ) + ) + else: + sh = SolidHarmonics(l_max)( + torch.index_select( + x, 1, torch.tensor([2, 0, 1], dtype=torch.long, device=x.device) + ) + ) + assert normalization in ["integral", "norm", "component"] + if normalization != "integral": + sh *= math.sqrt(4 * math.pi) + + if not is_range_lmax: + sh_list = [] + for l in l_list: # noqa E741 + shl = sh[:, l * l : (l + 1) * (l + 1)] + if normalization == "norm": + shl *= math.sqrt(1 / (2 * l + 1)) + sh_list.append(shl) + sh = torch.cat(sh_list, dim=-1) + elif normalization == "norm": + for l in l_list: # noqa E741 + sh[:, l * l : (l + 1) * (l + 1)] *= math.sqrt(1 / (2 * l + 1)) + + return sh + + +_E3NN_SPH = None + + +def patch_e3nn(e3nn_module: ModuleType) -> None: + """Patches the :py:mod:`e3nn` module so that + :py:func:`sphericart_torch.e3nn_spherical_harmonics` + is called in lieu of the built-in function. + + :param e3nn_module: + The alias that has been chosen for the e3nn module, + usually just ``e3nn``. + """ + + global _E3NN_SPH + if _E3NN_SPH is not None: + raise RuntimeError("It appears that e3nn has already been patched") + + _E3NN_SPH = e3nn_module.o3.spherical_harmonics + e3nn_module.o3.spherical_harmonics = e3nn_spherical_harmonics + + +def unpatch_e3nn(e3nn_module: ModuleType) -> None: + """Restore the original ``spherical_harmonics`` function + in the :py:mod:`e3nn` module.""" + + global _E3NN_SPH + if _E3NN_SPH is None: + raise RuntimeError("It appears that e3nn has not been patched") + + e3nn_module.o3.spherical_harmonics = _E3NN_SPH + _E3NN_SPH = None + + +__all__ = [ + "SphericalHarmonics", + "SolidHarmonics", + "e3nn_spherical_harmonics", + "patch_e3nn", + "unpatch_e3nn", +] diff --git a/sphericart-torch/python/sphericart/torch/metatensor/__init__.py b/sphericart-torch/python/sphericart/torch/metatensor.py similarity index 64% rename from sphericart-torch/python/sphericart/torch/metatensor/__init__.py rename to sphericart-torch/python/sphericart/torch/metatensor.py index d2cb6338c..552d17c49 100644 --- a/sphericart-torch/python/sphericart/torch/metatensor/__init__.py +++ b/sphericart-torch/python/sphericart/torch/metatensor.py @@ -1,11 +1,9 @@ +from typing import List, Optional + import torch -from .. import SolidHarmonics as RawSolidHarmonics -from .. import SphericalHarmonics as RawSphericalHarmonics -from sphericart.metatensor.spherical_harmonics import ( - _check_xyz_tensor_map, - _wrap_into_tensor_map, -) +from . import SolidHarmonics as RawSolidHarmonics +from . import SphericalHarmonics as RawSphericalHarmonics try: @@ -19,10 +17,18 @@ class SphericalHarmonics: - - def __init__(self, l_max: int): + """ + ``metatensor``-based wrapper around the + :py:meth:`sphericart.torch.SphericalHarmonics` class. + """ + + def __init__( + self, + l_max: int, + backward_second_derivatives: bool = False, + ): self.l_max = l_max - self.raw_calculator = RawSphericalHarmonics(l_max) + self.raw_calculator = RawSphericalHarmonics(l_max, backward_second_derivatives) # precompute some labels self.precomputed_keys = Labels( @@ -98,9 +104,13 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: class SolidHarmonics: - def __init__(self, l_max: int): + def __init__( + self, + l_max: int, + backward_second_derivatives: bool = False, + ): self.l_max = l_max - self.raw_calculator = RawSolidHarmonics(l_max) + self.raw_calculator = RawSolidHarmonics(l_max, backward_second_derivatives) # precompute some labels self.precomputed_keys = Labels( @@ -172,3 +182,67 @@ def compute_with_hessians(self, xyz: torch.Tensor) -> TensorMap: sh_hessians, metatensor_module=metatensor.torch, ) + + +def _check_xyz_tensor_map(xyz: TensorMap): + if len(xyz.blocks()) != 1: + raise ValueError("`xyz` should have only one block") + if len(xyz.block().components) != 1: + raise ValueError("`xyz` should have only one component") + if xyz.block().components[0].names != ["xyz"]: + raise ValueError("`xyz` should have only one component named 'xyz'") + if xyz.block().components[0].values.shape[0] != 3: + raise ValueError("`xyz` should have 3 Cartesian coordinates") + if xyz.block().properties.values.shape[0] != 1: + raise ValueError("`xyz` should have only one property") + + +def _wrap_into_tensor_map( + sh_values: torch.Tensor, + keys: Labels, + samples: Labels, + components: List[Labels], + xyz_components: Labels, + xyz_2_components: Labels, + properties: Labels, + sh_gradients: Optional[torch.Tensor] = None, + sh_hessians: Optional[torch.Tensor] = None, +) -> TensorMap: + + # infer l_max + l_max = len(components) - 1 + + blocks = [] + for l in range(l_max + 1): # noqa E741 + l_start = l**2 + l_end = (l + 1) ** 2 + sh_values_block = metatensor.TensorBlock( + values=sh_values[:, l_start:l_end, None], + samples=samples, + components=[components[l]], + properties=properties, + ) + if sh_gradients is not None: + sh_gradients_block = metatensor.TensorBlock( + values=sh_gradients[:, :, l_start:l_end, None], + samples=samples, + components=[xyz_components, components[l]], + properties=properties, + ) + if sh_hessians is not None: + sh_hessians_block = metatensor.TensorBlock( + values=sh_hessians[:, :, :, l_start:l_end, None], + samples=samples, + components=[ + xyz_2_components, + xyz_components, + components[l], + ], + properties=properties, + ) + sh_gradients_block.add_gradient("positions", sh_hessians_block) + sh_values_block.add_gradient("positions", sh_gradients_block) + + blocks.append(sh_values_block) + + return metatensor.TensorMap(keys=keys, blocks=blocks) diff --git a/sphericart-torch/python/sphericart/torch/spherical_hamonics.py b/sphericart-torch/python/sphericart/torch/spherical_hamonics.py new file mode 100644 index 000000000..e72997d3c --- /dev/null +++ b/sphericart-torch/python/sphericart/torch/spherical_hamonics.py @@ -0,0 +1,245 @@ +from typing import Tuple + +import torch +from torch import Tensor + + +class SphericalHarmonics(torch.nn.Module): + """ + Spherical harmonics calculator, which computes the real spherical harmonics + :math:`Y^m_l` up to degree ``l_max``. The calculated spherical harmonics + are consistent with the definition of real spherical harmonics from Wikipedia. + + This class can be used similarly to :py:class:`sphericart.SphericalHarmonics` + (its Python/NumPy counterpart). If the class is called directly, the outputs + support single and double backpropagation. + + >>> xyz = xyz.detach().clone().requires_grad_() + >>> sh = sphericart.torch.SphericalHarmonics(l_max=8) + >>> sh_values = sh(xyz) # or sh.compute(xyz) + >>> sh_values.sum().backward() + >>> torch.allclose(xyz.grad, sh_grads.sum(axis=-1)) + True + + By default, only single backpropagation with respect to ``xyz`` is + enabled (this includes mixed second derivatives where ``xyz`` appears + as only one of the differentiation steps). To activate support + for double backpropagation with respect to ``xyz``, please set + ``backward_second_derivatives=True`` at class creation. Warning: if + ``backward_second_derivatives`` is not set to ``True`` and double + differentiation with respect to ``xyz`` is requested, the results may + be incorrect, but a warning will be displayed. This is necessary to + provide optimal performance for both use cases. In particular, the + following will happen: + + - when using ``torch.autograd.grad`` as the second backpropagation + step, a warning will be displayed and torch will raise an error. + - when using ``torch.autograd.grad`` with ``allow_unused=True`` as + the second backpropagation step, the results will be incorrect + and only a warning will be displayed. + - when using ``backward`` as the second backpropagation step, the + results will be incorrect and only a warning will be displayed. + - when using ``torch.autograd.functional.hessian``, the results will + be incorrect and only a warning will be displayed. + + Alternatively, the class allows to return explicit forward gradients and/or + Hessians of the spherical harmonics. For example: + + >>> import torch + >>> import sphericart.torch + >>> sh = sphericart.torch.SphericalHarmonics(l_max=8) + >>> xyz = torch.rand(size=(10,3)) + >>> sh_values, sh_grads = sh.compute_with_gradients(xyz) + >>> sh_grads.shape + torch.Size([10, 3, 81]) + + This class supports TorchScript. + + :param l_max: + the maximum degree of the spherical harmonics to be calculated + :param backward_second_derivatives: + if this parameter is set to ``True``, second derivatives of the spherical + harmonics are calculated and stored during forward calls to ``compute`` + (provided that ``xyz.requires_grad`` is ``True``), making it possible to perform + double reverse-mode differentiation with respect to ``xyz``. If ``False``, only + the first derivatives will be computed and only a single reverse-mode + differentiation step will be possible with respect to ``xyz``. + + :return: a calculator, in the form of a SphericalHarmonics object + """ + + def __init__( + self, + l_max: int, + backward_second_derivatives: bool = False, + ): + super().__init__() + self.calculator = torch.classes.sphericart_torch.SphericalHarmonics( + l_max, backward_second_derivatives + ) + + def forward(self, xyz: Tensor) -> Tensor: + """ + Calculates the spherical harmonics for a set of 3D points. + + The coordinates should be stored in the ``xyz`` array. If ``xyz`` + has ``requires_grad = True`` it stores the forward derivatives which + are then used in the backward pass. + The type of the entries of ``xyz`` determines the precision used, + and the device the tensor is stored on determines whether the + CPU or CUDA implementation is used for the calculation backend. + It always supports single reverse-mode differentiation, as well as + double reverse-mode differentiation if ``backward_second_derivatives`` + was set to ``True`` during class creation. + + :param xyz: + The Cartesian coordinates of the 3D points, as a `torch.Tensor` with + shape ``(n_samples, 3)``. + + :return: + A tensor of shape ``(n_samples, (l_max+1)**2)`` containing all the + spherical harmonics up to degree `l_max` in lexicographic order. + For example, if ``l_max = 2``, The last axis will correspond to + spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1, + 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order. + """ + return self.calculator.compute(xyz) + + def compute(self, xyz: Tensor) -> Tensor: + """Equivalent to ``forward``""" + return self.calculator.compute(xyz) + + def compute_with_gradients(self, xyz: Tensor) -> Tuple[Tensor, Tensor]: + """ + Calculates the spherical harmonics for a set of 3D points, + and also returns the forward-mode derivatives. + + The coordinates should be stored in the ``xyz`` array. + The type of the entries of ``xyz`` determines the precision used, + and the device the tensor is stored on determines whether the + CPU or CUDA implementation is used for the calculation backend. + Reverse-mode differentiation is not supported for this function. + + :param xyz: + The Cartesian coordinates of the 3D points, as a `torch.Tensor` with + shape ``(n_samples, 3)``. + + :return: + A tuple that contains: + + * A ``(n_samples, (l_max+1)**2)`` tensor containing all the + spherical harmonics up to degree ``l_max`` in lexicographic order. + For example, if ``l_max = 2``, The last axis will correspond to + spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1, + 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order. + * A tensor of shape ``(n_samples, 3, (l_max+1)**2)`` containing all + the spherical harmonics' derivatives up to degree ``l_max``. The + last axis is organized in the same way as in the spherical + harmonics return array, while the second-to-last axis refers to + derivatives in the the x, y, and z directions, respectively. + + """ + return self.calculator.compute_with_gradients(xyz) + + def compute_with_hessians(self, xyz: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Calculates the spherical harmonics for a set of 3D points, + and also returns the forward derivatives and second derivatives. + + The coordinates should be stored in the ``xyz`` array. + The type of the entries of ``xyz`` determines the precision used, + and the device the tensor is stored on determines whether the + CPU or CUDA implementation is used for the calculation backend. + Reverse-mode differentiation is not supported for this function. + + :param xyz: + The Cartesian coordinates of the 3D points, as a ``torch.Tensor`` with + shape ``(n_samples, 3)``. + + :return: + A tuple that contains: + + * A ``(n_samples, (l_max+1)**2)`` tensor containing all the + spherical harmonics up to degree ``l_max`` in lexicographic order. + For example, if ``l_max = 2``, The last axis will correspond to + spherical harmonics with ``(l, m) = (0, 0), (1, -1), (1, 0), (1, + 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2)``, in this order. + * A tensor of shape ``(n_samples, 3, (l_max+1)**2)`` containing all + the spherical harmonics' derivatives up to degree ``l_max``. The + last axis is organized in the same way as in the spherical + harmonics return array, while the second-to-last axis refers to + derivatives in the the x, y, and z directions, respectively. + * A tensor of shape ``(n_samples, 3, 3, (l_max+1)**2)`` containing all + the spherical harmonics' second derivatives up to degree ``l_max``. The + last axis is organized in the same way as in the spherical + harmonics return array, while the two intermediate axes represent the + hessian dimensions. + + """ + return self.calculator.compute_with_hessians(xyz) + + def omp_num_threads(self): + """Returns the number of threads available for calculations on the CPU.""" + return self.calculator.omp_num_threads() + + def l_max(self): + """Returns the maximum angular momentum setting for this calculator.""" + return self.calculator.l_max() + + +class SolidHarmonics(torch.nn.Module): + """ + Solid harmonics calculator, up to degree ``l_max``. + + This class computes the solid harmonics, a non-normalized form of the real + spherical harmonics, i.e. :math:`r^lY^m_l`. These scaled spherical harmonics + are polynomials in the Cartesian coordinates of the input points. + + The usage of this class is identical to :py:class:`sphericart.SphericalHarmonics`. + + :param l_max: + the maximum degree of the spherical harmonics to be calculated + :param backward_second_derivatives: + if this parameter is set to ``True``, second derivatives of the spherical + harmonics are calculated and stored during forward calls to ``compute`` + (provided that ``xyz.requires_grad`` is ``True``), making it possible to perform + double reverse-mode differentiation with respect to ``xyz``. If ``False``, only + the first derivatives will be computed and only a single reverse-mode + differentiation step will be possible with respect to ``xyz``. + + :return: a calculator, in the form of a SolidHarmonics object + """ + + def __init__( + self, + l_max: int, + backward_second_derivatives: bool = False, + ): + super().__init__() + self.calculator = torch.classes.sphericart_torch.SolidHarmonics( + l_max, backward_second_derivatives + ) + + def forward(self, xyz: Tensor) -> Tensor: + """See :py:meth:`SphericalHarmonics.forward`""" + return self.calculator.compute(xyz) + + def compute(self, xyz: Tensor) -> Tensor: + """Equivalent to ``forward``""" + return self.calculator.compute(xyz) + + def compute_with_gradients(self, xyz: Tensor) -> Tuple[Tensor, Tensor]: + """See :py:meth:`SphericalHarmonics.compute_with_gradients`""" + return self.calculator.compute_with_gradients(xyz) + + def compute_with_hessians(self, xyz: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """See :py:meth:`SphericalHarmonics.compute_with_hessians`""" + return self.calculator.compute_with_hessians(xyz) + + def omp_num_threads(self): + """Returns the number of threads available for calculations on the CPU.""" + return self.calculator.omp_num_threads() + + def l_max(self): + """Returns the maximum angular momentum setting for this calculator.""" + return self.calculator.l_max() From df64c885b0cc8f2fa423ed95a43a1c016bbf61e5 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Tue, 7 Jan 2025 08:46:26 +0100 Subject: [PATCH 5/8] Documentation --- python/src/sphericart/metatensor.py | 75 ++++++++++++++++--- .../python/sphericart/torch/metatensor.py | 53 +++++++++---- 2 files changed, 104 insertions(+), 24 deletions(-) diff --git a/python/src/sphericart/metatensor.py b/python/src/sphericart/metatensor.py index 6665e83cb..29ed1be32 100644 --- a/python/src/sphericart/metatensor.py +++ b/python/src/sphericart/metatensor.py @@ -7,8 +7,7 @@ try: - import metatensor - from metatensor import Labels, TensorMap + from metatensor import Labels, TensorBlock, TensorMap except ImportError as e: raise ImportError( "the `sphericart.metatensor` module requires `metatensor` to be installed" @@ -16,6 +15,14 @@ class SphericalHarmonics: + """ + ``metatensor``-based wrapper around the + :py:meth:`sphericart.SphericalHarmonics` class. + + :param l_max: the maximum degree of the spherical harmonics to be calculated + + :return: a spherical harmonics calculator object + """ def __init__(self, l_max: int): self.l_max = l_max @@ -44,6 +51,18 @@ def __init__(self, l_max: int): self.precomputed_properties = Labels.single() def compute(self, xyz: TensorMap) -> TensorMap: + """ + Computes the spherical harmonics for the given Cartesian coordinates, up to + the maximum degree ``l_max`` specified during initialization. + + :param xyz: a :py:class:`metatensor.TensorMap` containing the Cartesian + coordinates of the 3D points. This ``TensorMap`` should have only one + ``TensorBlock``. In this ``TensorBlock``, the samples are arbitrary, + there must be one component named ``"xyz"`` with 3 values, and one property. + + :return: The spherical harmonics and their metadata as a + :py:class:`metatensor.TensorMap` + """ _check_xyz_tensor_map(xyz) sh_values = self.raw_calculator.compute(xyz.block().values.squeeze(-1)) return _wrap_into_tensor_map( @@ -57,6 +76,17 @@ def compute(self, xyz: TensorMap) -> TensorMap: ) def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: + """ + Computes the spherical harmonics for the given Cartesian coordinates, up to + the maximum degree ``l_max`` specified during initialization, + together with their gradients with respect to the Cartesian coordinates. + + :param xyz: see :py:meth:`compute` + + :return: The spherical harmonics and their metadata as a + :py:class:`metatensor.TensorMap`. Each ``TensorBlock`` in the output + ``TensorMap`` will have a gradient with respect to the Cartesian positions. + """ _check_xyz_tensor_map(xyz) sh_values, sh_gradients = self.raw_calculator.compute_with_gradients( xyz.block().values.squeeze(-1) @@ -73,6 +103,19 @@ def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: ) def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: + """ + Computes the spherical harmonics for the given Cartesian coordinates, up to + the maximum degree ``l_max`` specified during initialization, + together with their gradients and Hessians with respect to the Cartesian + coordinates. + + :param xyz: see :py:meth:`compute` + + :return: The spherical harmonics and their metadata as a + :py:class:`metatensor.TensorMap`. Each ``TensorBlock`` in the output + ``TensorMap`` will have a gradient with respect to the Cartesian positions, + which will itself have a gradient with respect to the Cartesian positions. + """ _check_xyz_tensor_map(xyz) sh_values, sh_gradients, sh_hessians = ( self.raw_calculator.compute_with_hessians(xyz.block().values.squeeze(-1)) @@ -91,6 +134,11 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: class SolidHarmonics: + """ + ``metatensor``-based wrapper around the :py:meth:`sphericart.SolidHarmonics` class. + + See :py:class:`SphericalHarmonics` for more details. + """ def __init__(self, l_max: int): self.l_max = l_max @@ -118,7 +166,10 @@ def __init__(self, l_max: int): ) self.precomputed_properties = Labels.single() - def compute(self, xyz: np.ndarray) -> TensorMap: + def compute(self, xyz: TensorMap) -> TensorMap: + """ + See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute`. + """ _check_xyz_tensor_map(xyz) sh_values = self.raw_calculator.compute(xyz.block().values.squeeze(-1)) return _wrap_into_tensor_map( @@ -131,7 +182,10 @@ def compute(self, xyz: np.ndarray) -> TensorMap: self.precomputed_properties, ) - def compute_with_gradients(self, xyz: np.ndarray) -> TensorMap: + def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: + """ + See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute_with_gradients`. + """ _check_xyz_tensor_map(xyz) sh_values, sh_gradients = self.raw_calculator.compute_with_gradients( xyz.block().values.squeeze(-1) @@ -147,7 +201,10 @@ def compute_with_gradients(self, xyz: np.ndarray) -> TensorMap: sh_gradients, ) - def compute_with_hessians(self, xyz: np.ndarray) -> TensorMap: + def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: + """ + See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute_with_hessians`. + """ _check_xyz_tensor_map(xyz) sh_values, sh_gradients, sh_hessians = ( self.raw_calculator.compute_with_hessians(xyz.block().values.squeeze(-1)) @@ -197,21 +254,21 @@ def _wrap_into_tensor_map( for l in range(l_max + 1): # noqa E741 l_start = l**2 l_end = (l + 1) ** 2 - sh_values_block = metatensor.TensorBlock( + sh_values_block = TensorBlock( values=sh_values[:, l_start:l_end, None], samples=samples, components=[components[l]], properties=properties, ) if sh_gradients is not None: - sh_gradients_block = metatensor.TensorBlock( + sh_gradients_block = TensorBlock( values=sh_gradients[:, :, l_start:l_end, None], samples=samples, components=[xyz_components, components[l]], properties=properties, ) if sh_hessians is not None: - sh_hessians_block = metatensor.TensorBlock( + sh_hessians_block = TensorBlock( values=sh_hessians[:, :, :, l_start:l_end, None], samples=samples, components=[ @@ -226,4 +283,4 @@ def _wrap_into_tensor_map( blocks.append(sh_values_block) - return metatensor.TensorMap(keys=keys, blocks=blocks) + return TensorMap(keys=keys, blocks=blocks) diff --git a/sphericart-torch/python/sphericart/torch/metatensor.py b/sphericart-torch/python/sphericart/torch/metatensor.py index 552d17c49..c9183fd98 100644 --- a/sphericart-torch/python/sphericart/torch/metatensor.py +++ b/sphericart-torch/python/sphericart/torch/metatensor.py @@ -7,8 +7,7 @@ try: - import metatensor.torch - from metatensor.torch import Labels, TensorMap + from metatensor.torch import Labels, TensorBlock, TensorMap except ImportError as e: raise ImportError( "the `sphericart.torch.metatensor` module requires " @@ -20,6 +19,10 @@ class SphericalHarmonics: """ ``metatensor``-based wrapper around the :py:meth:`sphericart.torch.SphericalHarmonics` class. + + See :py:class:`sphericart.metatensor.SphericalHarmonics` for more details. + ``backward_second_derivatives`` has the same meaning as in + :py:class:`sphericart.torch.SphericalHarmonics`. """ def __init__( @@ -53,6 +56,9 @@ def __init__( self.precomputed_properties = Labels.single() def compute(self, xyz: TensorMap) -> TensorMap: + """ + See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute`. + """ _check_xyz_tensor_map(xyz) sh_values = self.raw_calculator.compute(xyz.block().values.squeeze(-1)) return _wrap_into_tensor_map( @@ -63,10 +69,12 @@ def compute(self, xyz: TensorMap) -> TensorMap: self.precomputed_xyz_components, self.precomputed_xyz_2_components, self.precomputed_properties, - metatensor_module=metatensor.torch, ) def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: + """ + See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute_with_gradients`. + """ _check_xyz_tensor_map(xyz) sh_values, sh_gradients = self.raw_calculator.compute_with_gradients( xyz.block().values.squeeze(-1) @@ -80,10 +88,12 @@ def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: self.precomputed_xyz_2_components, self.precomputed_properties, sh_gradients, - metatensor_module=metatensor.torch, ) def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: + """ + See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute_with_hessians`. + """ _check_xyz_tensor_map(xyz) sh_values, sh_gradients, sh_hessians = ( self.raw_calculator.compute_with_hessians(xyz.block().values.squeeze(-1)) @@ -98,11 +108,18 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: self.precomputed_xyz_2_components, sh_gradients, sh_hessians, - metatensor_module=metatensor.torch, ) class SolidHarmonics: + """ + ``metatensor``-based wrapper around the + :py:meth:`sphericart.torch.SolidHarmonics` class. + + See :py:class:`sphericart.metatensor.SphericalHarmonics` for more details. + ``backward_second_derivatives`` has the same meaning as in + :py:class:`sphericart.torch.SphericalHarmonics`. + """ def __init__( self, @@ -134,7 +151,10 @@ def __init__( ) self.precomputed_properties = Labels.single() - def compute(self, xyz: torch.Tensor) -> TensorMap: + def compute(self, xyz: TensorMap) -> TensorMap: + """ + See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute`. + """ _check_xyz_tensor_map(xyz) sh_values = self.raw_calculator.compute(xyz.block().values.squeeze(-1)) return _wrap_into_tensor_map( @@ -145,10 +165,12 @@ def compute(self, xyz: torch.Tensor) -> TensorMap: self.precomputed_xyz_components, self.precomputed_xyz_2_components, self.precomputed_properties, - metatensor_module=metatensor.torch, ) - def compute_with_gradients(self, xyz: torch.Tensor) -> TensorMap: + def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: + """ + See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute_with_gradients`. + """ _check_xyz_tensor_map(xyz) sh_values, sh_gradients = self.raw_calculator.compute_with_gradients( xyz.block().values.squeeze(-1) @@ -162,10 +184,12 @@ def compute_with_gradients(self, xyz: torch.Tensor) -> TensorMap: self.precomputed_xyz_2_components, self.precomputed_properties, sh_gradients, - metatensor_module=metatensor.torch, ) - def compute_with_hessians(self, xyz: torch.Tensor) -> TensorMap: + def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: + """ + See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute_with_hessians`. + """ _check_xyz_tensor_map(xyz) sh_values, sh_gradients, sh_hessians = ( self.raw_calculator.compute_with_hessians(xyz.block().values.squeeze(-1)) @@ -180,7 +204,6 @@ def compute_with_hessians(self, xyz: torch.Tensor) -> TensorMap: self.precomputed_properties, sh_gradients, sh_hessians, - metatensor_module=metatensor.torch, ) @@ -216,21 +239,21 @@ def _wrap_into_tensor_map( for l in range(l_max + 1): # noqa E741 l_start = l**2 l_end = (l + 1) ** 2 - sh_values_block = metatensor.TensorBlock( + sh_values_block = TensorBlock( values=sh_values[:, l_start:l_end, None], samples=samples, components=[components[l]], properties=properties, ) if sh_gradients is not None: - sh_gradients_block = metatensor.TensorBlock( + sh_gradients_block = TensorBlock( values=sh_gradients[:, :, l_start:l_end, None], samples=samples, components=[xyz_components, components[l]], properties=properties, ) if sh_hessians is not None: - sh_hessians_block = metatensor.TensorBlock( + sh_hessians_block = TensorBlock( values=sh_hessians[:, :, :, l_start:l_end, None], samples=samples, components=[ @@ -245,4 +268,4 @@ def _wrap_into_tensor_map( blocks.append(sh_values_block) - return metatensor.TensorMap(keys=keys, blocks=blocks) + return TensorMap(keys=keys, blocks=blocks) From 5fdaca7dd5794f9199261a0b8f88ce4c16480b5e Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Tue, 7 Jan 2025 15:45:33 +0100 Subject: [PATCH 6/8] Documentation suggestion --- examples/metatensor/example.py | 1 + python/src/sphericart/metatensor.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/metatensor/example.py b/examples/metatensor/example.py index 37d5f65bc..ef6357eee 100644 --- a/examples/metatensor/example.py +++ b/examples/metatensor/example.py @@ -31,6 +31,7 @@ calculator = sphericart.metatensor.SphericalHarmonics(l_max) spherical_harmonics = calculator.compute(xyz) +# for each block, the samples are the same as those of the `xyz` input for single_l in range(l_max + 1): spherical_single_l = spherical_harmonics.block({"o3_lambda": single_l}) diff --git a/python/src/sphericart/metatensor.py b/python/src/sphericart/metatensor.py index 29ed1be32..4feebf937 100644 --- a/python/src/sphericart/metatensor.py +++ b/python/src/sphericart/metatensor.py @@ -61,7 +61,8 @@ def compute(self, xyz: TensorMap) -> TensorMap: there must be one component named ``"xyz"`` with 3 values, and one property. :return: The spherical harmonics and their metadata as a - :py:class:`metatensor.TensorMap` + :py:class:`metatensor.TensorMap`. All ``samples`` in the output + ``TensorMap`` will be the same as those of the ``xyz`` input. """ _check_xyz_tensor_map(xyz) sh_values = self.raw_calculator.compute(xyz.block().values.squeeze(-1)) @@ -85,7 +86,9 @@ def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: :return: The spherical harmonics and their metadata as a :py:class:`metatensor.TensorMap`. Each ``TensorBlock`` in the output - ``TensorMap`` will have a gradient with respect to the Cartesian positions. + ``TensorMap`` will have a gradient block with respect to the Cartesian + positions. All ``samples`` in the output ``TensorMap`` will be the same as + those of the ``xyz`` input. """ _check_xyz_tensor_map(xyz) sh_values, sh_gradients = self.raw_calculator.compute_with_gradients( @@ -113,8 +116,10 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: :return: The spherical harmonics and their metadata as a :py:class:`metatensor.TensorMap`. Each ``TensorBlock`` in the output - ``TensorMap`` will have a gradient with respect to the Cartesian positions, - which will itself have a gradient with respect to the Cartesian positions. + ``TensorMap`` will have a gradient block with respect to the Cartesian + positions, which will itself have a gradient with respect to the Cartesian + positions. All ``samples`` in the output ``TensorMap`` will be the same as + those of the ``xyz`` input. """ _check_xyz_tensor_map(xyz) sh_values, sh_gradients, sh_hessians = ( From a03f0858b135a40258a4f26385e7fd0794241767 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Tue, 7 Jan 2025 15:58:46 +0100 Subject: [PATCH 7/8] Make it work on GPU --- .../python/sphericart/torch/metatensor.py | 42 +++++++++++++++++++ .../python/tests/test_metatensor.py | 7 +++- 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/sphericart-torch/python/sphericart/torch/metatensor.py b/sphericart-torch/python/sphericart/torch/metatensor.py index c9183fd98..f93a837be 100644 --- a/sphericart-torch/python/sphericart/torch/metatensor.py +++ b/sphericart-torch/python/sphericart/torch/metatensor.py @@ -60,6 +60,10 @@ def compute(self, xyz: TensorMap) -> TensorMap: See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute`. """ _check_xyz_tensor_map(xyz) + device = xyz.device + if self.precomputed_keys.device != device: + self._send_precomputed_labels_to_device(device) + sh_values = self.raw_calculator.compute(xyz.block().values.squeeze(-1)) return _wrap_into_tensor_map( sh_values, @@ -76,6 +80,10 @@ def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute_with_gradients`. """ _check_xyz_tensor_map(xyz) + device = xyz.device + if self.precomputed_keys.device != device: + self._send_precomputed_labels_to_device(device) + sh_values, sh_gradients = self.raw_calculator.compute_with_gradients( xyz.block().values.squeeze(-1) ) @@ -95,6 +103,10 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute_with_hessians`. """ _check_xyz_tensor_map(xyz) + device = xyz.device + if self.precomputed_keys.device != device: + self._send_precomputed_labels_to_device(device) + sh_values, sh_gradients, sh_hessians = ( self.raw_calculator.compute_with_hessians(xyz.block().values.squeeze(-1)) ) @@ -110,6 +122,15 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: sh_hessians, ) + def _send_precomputed_labels_to_device(self, device): + self.precomputed_keys = self.precomputed_keys.to(device) + self.precomputed_mu_components = [ + comp.to(device) for comp in self.precomputed_mu_components + ] + self.precomputed_xyz_components = self.precomputed_xyz_components.to(device) + self.precomputed_xyz_2_components = self.precomputed_xyz_2_components.to(device) + self.precomputed_properties = self.precomputed_properties.to(device) + class SolidHarmonics: """ @@ -156,6 +177,10 @@ def compute(self, xyz: TensorMap) -> TensorMap: See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute`. """ _check_xyz_tensor_map(xyz) + device = xyz.device + if self.precomputed_keys.device != device: + self._send_precomputed_labels_to_device(device) + sh_values = self.raw_calculator.compute(xyz.block().values.squeeze(-1)) return _wrap_into_tensor_map( sh_values, @@ -172,6 +197,10 @@ def compute_with_gradients(self, xyz: TensorMap) -> TensorMap: See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute_with_gradients`. """ _check_xyz_tensor_map(xyz) + device = xyz.device + if self.precomputed_keys.device != device: + self._send_precomputed_labels_to_device(device) + sh_values, sh_gradients = self.raw_calculator.compute_with_gradients( xyz.block().values.squeeze(-1) ) @@ -191,6 +220,10 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: See :py:meth:`sphericart.metatensor.SphericalHarmonics.compute_with_hessians`. """ _check_xyz_tensor_map(xyz) + device = xyz.device + if self.precomputed_keys.device != device: + self._send_precomputed_labels_to_device(device) + sh_values, sh_gradients, sh_hessians = ( self.raw_calculator.compute_with_hessians(xyz.block().values.squeeze(-1)) ) @@ -206,6 +239,15 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap: sh_hessians, ) + def _send_precomputed_labels_to_device(self, device): + self.precomputed_keys = self.precomputed_keys.to(device) + self.precomputed_mu_components = [ + comp.to(device) for comp in self.precomputed_mu_components + ] + self.precomputed_xyz_components = self.precomputed_xyz_components.to(device) + self.precomputed_xyz_2_components = self.precomputed_xyz_2_components.to(device) + self.precomputed_properties = self.precomputed_properties.to(device) + def _check_xyz_tensor_map(xyz: TensorMap): if len(xyz.blocks()) != 1: diff --git a/sphericart-torch/python/tests/test_metatensor.py b/sphericart-torch/python/tests/test_metatensor.py index 8ec6b06e3..f57807232 100644 --- a/sphericart-torch/python/tests/test_metatensor.py +++ b/sphericart-torch/python/tests/test_metatensor.py @@ -33,7 +33,12 @@ def xyz(): ) -def test_metatensor(xyz): +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_metatensor(xyz, device): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + xyz = xyz.to(device) for l in range(L_MAX + 1): # noqa E741 calculator_spherical = sphericart.torch.metatensor.SphericalHarmonics(l) calculator_solid = sphericart.torch.metatensor.SolidHarmonics(l) From dd41b3699d7a17a1053277f093d4b61075beca35 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sat, 11 Jan 2025 07:37:35 +0100 Subject: [PATCH 8/8] Fix `metatensor-torch` documentation --- docs/src/conf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/src/conf.py b/docs/src/conf.py index 7b01c9e31..a9bd2496c 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -7,6 +7,9 @@ sys.path.insert(0, os.path.abspath(".")) sys.path.insert(0, ROOT) +# When importing metatensor-torch, this will change the definition of the classes +# to include the documentation +os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1" # -- Project information -----------------------------------------------------