From 80c749e03865024562c1dffb5eec511842307c1e Mon Sep 17 00:00:00 2001 From: Domagoj Fijan Date: Sat, 16 Nov 2024 14:40:41 -0500 Subject: [PATCH] replace numba with numpy --- .github/workflows/requirements-test.txt | 1 - dupin/data/spatial.py | 19 +++---------------- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/.github/workflows/requirements-test.txt b/.github/workflows/requirements-test.txt index c50ca99..efa5242 100644 --- a/.github/workflows/requirements-test.txt +++ b/.github/workflows/requirements-test.txt @@ -7,5 +7,4 @@ numpy == 2.0 ruptures == 1.1.9 scikit-learn == 1.5 pandas == 2.2.2 -numba == 0.60 xarray == 2024.6 diff --git a/dupin/data/spatial.py b/dupin/data/spatial.py index 25c7226..16b4317 100644 --- a/dupin/data/spatial.py +++ b/dupin/data/spatial.py @@ -6,24 +6,11 @@ from . import base -def _njit(*args, **kwargs): - """Allow for JIT when numba is found.""" - try: - import numba - except ImportError: - return lambda x: x - return numba.njit(*args, **kwargs) - - -@_njit() def _freud_neighbor_summing( - arr: np.ndarray, - particle_index: np.ndarray, - neighbor_index: np.ndarray, - base: np.ndarray, + arr, particle_index, neighbor_index, base ) -> np.ndarray: - for i, j in zip(particle_index, neighbor_index): - base[i] += arr[j] + np.add.at(base, particle_index, arr[neighbor_index]) + return base class NeighborAveraging(base.DataMap):