Skip to content

Commit

Permalink
configure_threadpool
Browse files Browse the repository at this point in the history
  • Loading branch information
clbarnes committed Apr 12, 2023
1 parent 893db8f commit 88e63fd
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 4 deletions.
1 change: 1 addition & 0 deletions ncollpyde/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .main import N_CPUS # noqa: F401
from .main import PRECISION # noqa: F401
from .main import Volume # noqa: F401
from .main import configure_threadpool # noqa: F401
from .ncollpyde import n_threads # noqa: F401
from .ncollpyde import _version

Expand Down
33 changes: 32 additions & 1 deletion ncollpyde/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
except ImportError:
trimesh = None

from .ncollpyde import TriMeshWrapper, _index, _precision
from .ncollpyde import (
TriMeshWrapper,
_index,
_precision,
configure_threadpool as _configure_threadpool,
)

if TYPE_CHECKING:
import meshio
Expand All @@ -29,6 +34,32 @@
INDEX = np.dtype(_index())


def configure_threadpool(n_threads: Optional[int], name_prefix: Optional[str]):
"""Configure the thread pool used for parallelisation.
Must be called a maximum of once,
and only before the first parallelised ncollpyde query.
This will be used for all parallelised ncollpyde queries.
Parameters
----------
n_threads : Optional[int]
Number of threads to use.
If None or 0, will use the default
(see https://docs.rs/rayon/latest/rayon/struct.ThreadPoolBuilder.html#method.num_threads).
name_prefix : Optional[str]
How to name threads created by this library.
Will be suffixed with the thread index.
If not given, will use the rayon default.
Raises
------
RuntimeError
If the pool could not be built for any reason.
"""
_configure_threadpool(n_threads, name_prefix)


def interpret_threads(threads: Optional[Union[int, bool]], default=DEFAULT_THREADS):
if isinstance(threads, bool):
return threads
Expand Down
3 changes: 2 additions & 1 deletion ncollpyde/ncollpyde.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Optional

import numpy as np
import numpy.typing as npt
Expand All @@ -7,6 +7,7 @@ def _precision() -> str: ...
def _index() -> str: ...
def _version() -> str: ...
def n_threads() -> int: ...
def configure_threadpool(n_threads: Optional[int], name_prefix: Optional[str]): ...

Points = npt.NDArray[np.float64]
Indices = npt.NDArray[np.uint32]
Expand Down
22 changes: 21 additions & 1 deletion src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ use numpy::ndarray::{Array, Zip};
use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2};
use parry3d_f64::math::{Point, Vector};
use parry3d_f64::shape::TriMesh;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use rand::SeedableRng;
use rand_pcg::Pcg64Mcg;
use rayon::prelude::*;
use rayon::{prelude::*, ThreadPoolBuilder};

use crate::utils::{dist_from_mesh, mesh_contains_point, points_cross_mesh, random_dir, Precision};

Expand Down Expand Up @@ -247,5 +248,24 @@ pub fn ncollpyde(_py: Python, m: &PyModule) -> PyResult<()> {
rayon::current_num_threads()
}

#[pyfn(m)]
#[pyo3(name = "configure_threadpool")]
pub fn configure_threadpool(
_py: Python,
n_threads: Option<usize>,
name_prefix: Option<String>,
) -> PyResult<()> {
let mut builder = ThreadPoolBuilder::new();
if let Some(n) = n_threads {
builder = builder.num_threads(n);
}
if let Some(p) = name_prefix {
builder = builder.thread_name(move |idx| format!("{p}{idx}"));
}
builder
.build_global()
.map_err(|e| PyRuntimeError::new_err(format!("Error building threadpool: {e}")))
}

Ok(())
}
25 changes: 24 additions & 1 deletion tests/test_ncollpyde.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

"""Tests for `ncollpyde` package."""
from itertools import product
import sys
import subprocess as sp

import numpy as np
import pytest
Expand All @@ -12,7 +14,7 @@
except ImportError:
trimesh = None

from ncollpyde import PRECISION, Volume
from ncollpyde import PRECISION, Volume, configure_threadpool

points_expected = [
([-2.3051376, -4.1556454, 1.9047838], True), # internal
Expand Down Expand Up @@ -284,3 +286,24 @@ def test_distance_unsigned(simple_volume, point, expected, signed):
assert np.allclose(
simple_volume.distance([point], signed=signed), np.asarray([expected])
)


@pytest.mark.xfail(reason="Other tests already configure the pool")
def test_configure_threadpool():
configure_threadpool(2, "prefix")


def test_configure_threadpool_subprocess():
# must be run in its own interpreter so that pool is not already configured
cmd = (
"from ncollpyde import configure_threadpool; configure_threadpool(2, 'prefix')"
)
args = [sys.executable, "-c", cmd]
assert sp.call(args) == 0


def test_configure_threadpool_twice():
# configure_threadpool(2, "prefix")
with pytest.raises(RuntimeError):
configure_threadpool(3, "prefix")
configure_threadpool(3, "prefix")

0 comments on commit 88e63fd

Please sign in to comment.