Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metatensor interface #158

Merged
merged 9 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ furo # sphinx theme
# jax[cpu], because python -m pip install jax, which would be triggered
# by the main package's dependencies, does not install jaxlib
jax[cpu] >= 0.4.18

# metatensor and metatensor-torch for the metatensor API
metatensor-torch
1 change: 1 addition & 0 deletions docs/src/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ different languages and frameworks it supports.
python-api
pytorch-api
jax-api
metatensor-api

Although the Julia API is not fully documented yet, basic usage examples are available
`here <https://github.com/lab-cosmo/sphericart/blob/main/julia/README.md>`_.
1 change: 1 addition & 0 deletions docs/src/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"python": ("https://docs.python.org/3", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"e3nn": ("https://docs.e3nn.org/en/latest/", None),
"metatensor": ("https://docs.metatensor.org/latest/", None),
}

html_theme = "furo"
Expand Down
1 change: 1 addition & 0 deletions docs/src/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ floating-point arithmetics, and they evaluate the mean relative error between th
pytorch-examples
jax-examples
spherical-complex
metatensor-examples

Although comprehensive Julia examples are not fully available yet, basic usage is illustrated
`here <https://github.com/lab-cosmo/sphericart/blob/main/julia/README.md>`_.
2 changes: 1 addition & 1 deletion docs/src/jax-api.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
JAX API
===========
=======

The `sphericart.jax` module aims to provide a functional-style and
`JAX`-friendly framework. As a result, it does not follow the same syntax as
Expand Down
28 changes: 28 additions & 0 deletions docs/src/metatensor-api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
Metatensor API
==============

``sphericart`` can be used in conjunction with
`metatensor <https://docs.metatensor.org/latest/index.html>`_ in order to attach
metadata to inputs and outputs, as well as to naturally obtain spherical harmonics,
gradients and Hessians in a single object.

Here is the API reference for the ``sphericart.metatensor`` and
``sphericart.torch.metatensor`` modules.

sphericart.metatensor
---------------------

.. autoclass:: sphericart.metatensor.SphericalHarmonics
:members:

.. autoclass:: sphericart.metatensor.SolidHarmonics
:members:

sphericart.torch.metatensor
---------------------------

.. autoclass:: sphericart.torch.metatensor.SphericalHarmonics
:members:

.. autoclass:: sphericart.torch.metatensor.SolidHarmonics
:members:
13 changes: 13 additions & 0 deletions docs/src/metatensor-examples.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Using sphericart with metatensor
--------------------------------

``sphericart`` can be used in conjunction with
`metatensor <https://docs.metatensor.org/latest/index.html>`_ in order to attach
metadata to inputs and outputs, as well as to naturally obtain spherical harmonics,
gradients and Hessians in a single object.

This example shows how to use the ``sphericart.metatensor`` module to compute
spherical harmonics, their gradients and their Hessians.

.. literalinclude:: ../../examples/metatensor/example.py
:language: python
54 changes: 54 additions & 0 deletions examples/metatensor/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
from metatensor import Labels, TensorBlock, TensorMap

import sphericart
import sphericart.metatensor


l_max = 15
n_samples = 100

xyz = TensorMap(
keys=Labels.single(),
blocks=[
TensorBlock(
values=np.random.rand(n_samples, 3, 1),
samples=Labels(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we indicate somewhere (maybe here and in the docstring?) that the output will have the same samples as the input?

names=["sample"],
values=np.arange(n_samples).reshape(-1, 1),
),
components=[
Labels(
names=["xyz"],
values=np.arange(3).reshape(-1, 1),
)
],
properties=Labels.single(),
)
],
)

calculator = sphericart.metatensor.SphericalHarmonics(l_max)

spherical_harmonics = calculator.compute(xyz)

for single_l in range(l_max + 1):
spherical_single_l = spherical_harmonics.block({"o3_lambda": single_l})

# check values against pure sphericart
assert np.allclose(
spherical_single_l.values.squeeze(-1),
sphericart.SphericalHarmonics(single_l).compute(
xyz.block().values.squeeze(-1)
)[:, single_l**2 : (single_l + 1) ** 2],
)

# further example: obtaining gradients of l = 2 spherical harmonics
spherical_harmonics = calculator.compute_with_gradients(xyz)
l_2_gradients = spherical_harmonics.block({"o3_lambda": 2}).gradient("positions")

# further example: obtaining Hessians of l = 2 spherical harmonics
spherical_harmonics = calculator.compute_with_hessians(xyz)
l_2_hessians = spherical_harmonics.block(
{"o3_lambda": 2}
).gradient("positions").gradient("positions")
Loading