Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-439: infrastructure for testing array API compatibility #459

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ jobs:
env:
FORCE_COLOR: 1

- name: Run tests and generate coverage report
- name: Run tests wih every array backend and generate coverage report
run: nox -s coverage-${{ matrix.python-version }} --verbose
env:
FORCE_COLOR: 1
GLASS_ARRAY_BACKEND: all

- name: Coveralls requires XML report
run: coverage xml
Expand Down
57 changes: 55 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,42 @@ following way -
python -m pytest --cov --doctest-plus
```

### Array API tests

One can specify a particular array backend for testing by setting the
`GLASS_ARRAY_BACKEND` environment variable. The default array backend is NumPy.
_GLASS_ can be tested with every supported array library available in the
environment by setting `GLASS_ARRAY_BACKEND` to `all`. The testing framework
only installs NumPy automatically; hence, remaining array libraries should
either be installed manually or developers should use `Nox`.

```bash
# run tests using numpy
python -m pytest
GLASS_ARRAY_BACKEND=numpy python -m pytest
# run tests using array_api_strict (should be installed manually)
GLASS_ARRAY_BACKEND=array_api_strict python -m pytest
# run tests using jax (should be installed manually)
GLASS_ARRAY_BACKEND=jax python -m pytest
# run tests using every supported array library available in the environment
GLASS_ARRAY_BACKEND=all python -m pytest
```

Moreover, one can mark a test to be compatible with the array API standard by
decorating it with `@array_api_compatible`. This will `parameterize` the test to
run on every array library specified through `GLASS_ARRAY_BACKEND` -

```python
import types
from tests.conftest import array_api_compatible


@array_api_compatible
def test_something(xp: types.ModuleType):
# use `xp.` to access the array library functionality
...
```

## Documenting

_GLASS_'s documentation is mainly written in the form of
Expand Down Expand Up @@ -166,11 +202,28 @@ nox -s tests
Only `tests`, `coverage`, and the `doctests` session run on all supported Python
versions by default.

To specify a particular Python version (for example `3.11`), use the following
To specify a particular Python version (for example `3.13`), use the following
syntax -

```bash
nox -s tests-3.11
nox -s tests-3.13
```

One can specify a particular array backend for testing by setting the
`GLASS_ARRAY_BACKEND` environment variable. The default array backend is NumPy.
_GLASS_ can be tested with every supported array library by setting
`GLASS_ARRAY_BACKEND` to `all`.

```bash
# run tests using numpy
nox -s tests-3.13
GLASS_ARRAY_BACKEND=numpy nox -s tests-3.13
# run tests using array_api_strict
GLASS_ARRAY_BACKEND=array_api_strict nox -s tests-3.13
# run tests using jax
GLASS_ARRAY_BACKEND=jax nox -s tests-3.13
# run tests using every supported array library
GLASS_ARRAY_BACKEND=all nox -s tests-3.13
```

The following command can be used to deploy the docs on `localhost` -
Expand Down
14 changes: 14 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import os
from pathlib import Path

import nox
Expand All @@ -20,6 +21,10 @@
"3.12",
"3.13",
]
ARRAY_BACKENDS = {
"array_api_strict": "array_api_strict>=2",
"jax": "jax>=0.4.32",
}


@nox.session
Expand All @@ -33,6 +38,15 @@ def lint(session: nox.Session) -> None:
def tests(session: nox.Session) -> None:
"""Run the unit tests."""
session.install("-c", ".github/test-constraints.txt", "-e", ".[test]")

array_backend = os.environ.get("GLASS_ARRAY_BACKEND")
if array_backend == "array_api_strict":
session.install(ARRAY_BACKENDS["array_api_strict"])
elif array_backend == "jax":
session.install(ARRAY_BACKENDS["jax"])
elif array_backend == "all":
session.install(*ARRAY_BACKENDS.values())

session.run(
"pytest",
*session.posargs,
Expand Down
93 changes: 93 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,105 @@
import contextlib
import importlib.metadata
import os
import types

import numpy as np
import numpy.typing as npt
import packaging.version
import pytest

from cosmology import Cosmology

from glass import RadialWindow

# Handling of array backends, inspired by-
# https://github.com/scipy/scipy/blob/36e349b6afbea057cb713fc314296f10d55194cc/scipy/conftest.py#L139

# environment variable to specify array backends for testing
# can be:
# a particular array library (numpy, jax, array_api_strict, ...)
# all (try finding every supported array library available in the environment)
GLASS_ARRAY_BACKEND: str = os.environ.get("GLASS_ARRAY_BACKEND", "")


def _check_version(lib: str, array_api_compliant_version: str) -> None:
"""
Check if installed library's version is compliant with the array API standard.

Parameters
----------
lib
name of the library.
array_api_compliant_version
version of the library compliant with the array API standard.

Raises
------
ImportError
If the installed version is not compliant with the array API standard.
"""
lib_version = packaging.version.Version(importlib.metadata.version(lib))
if lib_version < packaging.version.Version(array_api_compliant_version):
msg = f"{lib} must be >= {array_api_compliant_version}; found {lib_version}"
raise ImportError(msg)


def _import_and_add_numpy(xp_available_backends: dict[str, types.ModuleType]) -> None:
"""Add numpy to the backends dictionary."""
_check_version("numpy", "2.1.0")
Saransh-cpp marked this conversation as resolved.
Show resolved Hide resolved
xp_available_backends.update({"numpy": np})


def _import_and_add_array_api_strict(
xp_available_backends: dict[str, types.ModuleType],
) -> None:
"""Add array_api_strict to the backends dictionary."""
import array_api_strict

_check_version("array_api_strict", "2.0.0")
xp_available_backends.update({"array_api_strict": array_api_strict})
array_api_strict.set_array_api_strict_flags(api_version="2023.12")


def _import_and_add_jax(xp_available_backends: dict[str, types.ModuleType]) -> None:
"""Add jax to the backends dictionary."""
import jax

_check_version("jax", "0.4.32")
xp_available_backends.update({"jax.numpy": jax.numpy})
# enable 64 bit numbers
jax.config.update("jax_enable_x64", val=True)


# a dictionary with all array backends to test
xp_available_backends: dict[str, types.ModuleType] = {}

# if no backend passed, use numpy by default
if not GLASS_ARRAY_BACKEND or GLASS_ARRAY_BACKEND == "numpy":
_import_and_add_numpy(xp_available_backends)
elif GLASS_ARRAY_BACKEND == "array_api_strict":
_import_and_add_array_api_strict(xp_available_backends)
elif GLASS_ARRAY_BACKEND == "jax":
_import_and_add_jax(xp_available_backends)
# if all, try importing every backend
elif GLASS_ARRAY_BACKEND == "all":
with contextlib.suppress(ImportError):
_import_and_add_numpy(xp_available_backends)

with contextlib.suppress(ImportError):
_import_and_add_array_api_strict(xp_available_backends)

with contextlib.suppress(ImportError):
_import_and_add_jax(xp_available_backends)
else:
msg = f"unsupported array backend: {GLASS_ARRAY_BACKEND}"
raise ValueError(msg)

# use this as a decorator for tests involving array API compatible functions
array_api_compatible = pytest.mark.parametrize("xp", xp_available_backends.values())
Comment on lines +98 to +99
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if all of this stuff should go in its own file, so that only this line features here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion all the common testing utilities should go directly into conftest.py because that is the sole use of this file (even if it gets crowded).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. I'm just concerned that it can make it harder for others to contribute. Would be good for it to be as neat and granular as possible.



# Pytest fixtures
@pytest.fixture(scope="session")
def cosmo() -> Cosmology:
class MockCosmology:
Expand Down
Loading