diff --git a/cuda_bindings/setup.py b/cuda_bindings/setup.py index 7c8725ee..38882778 100644 --- a/cuda_bindings/setup.py +++ b/cuda_bindings/setup.py @@ -308,9 +308,9 @@ def build_extension(self, ext): # Allow extensions to discover libraries at runtime # relative their wheels installation. if ext.name == "cuda.bindings._bindings.cynvrtc": - ldflag = f"-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib" + ldflag = "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib" elif ext.name == "cuda.bindings._internal.nvjitlink": - ldflag = f"-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib" + ldflag = "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib" else: ldflag = None @@ -326,7 +326,7 @@ def build_extension(self, ext): cmdclass = { "bdist_wheel": WheelsBuildExtensions, "build_ext": ParallelBuildExtensions, - } +} # ---------------------------------------------------------------------- # Setup diff --git a/cuda_bindings/tests/test_nvjitlink.py b/cuda_bindings/tests/test_nvjitlink.py index 4a2c1a6b..000ef52e 100644 --- a/cuda_bindings/tests/test_nvjitlink.py +++ b/cuda_bindings/tests/test_nvjitlink.py @@ -55,9 +55,7 @@ def ptx_header(version, arch): def check_nvjitlink_usable(): from cuda.bindings._internal import nvjitlink as inner_nvjitlink - if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0: - return False - return True + return inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") != 0 pytestmark = pytest.mark.skipif( diff --git a/cuda_core/cuda/core/experimental/_launcher.py b/cuda_core/cuda/core/experimental/_launcher.py index 91379d57..91b6856d 100644 --- a/cuda_core/cuda/core/experimental/_launcher.py +++ b/cuda_core/cuda/core/experimental/_launcher.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -import importlib.metadata from dataclasses import dataclass from typing import Optional, Union @@ -11,7 +10,7 @@ from cuda.core.experimental._kernel_arg_handler import ParamHolder from cuda.core.experimental._module import Kernel from cuda.core.experimental._stream import Stream -from cuda.core.experimental._utils import CUDAError, check_or_create_options, handle_return +from cuda.core.experimental._utils import CUDAError, check_or_create_options, get_binding_version, handle_return # TODO: revisit this treatment for py313t builds _inited = False @@ -25,7 +24,7 @@ def _lazy_init(): global _use_ex # binding availability depends on cuda-python version - _py_major_minor = tuple(int(v) for v in (importlib.metadata.version("cuda-python").split(".")[:2])) + _py_major_minor = get_binding_version() _driver_ver = handle_return(cuda.cuDriverGetVersion()) _use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8)) _inited = True diff --git a/cuda_core/cuda/core/experimental/_module.py b/cuda_core/cuda/core/experimental/_module.py index 5dc2801b..89f31b9f 100644 --- a/cuda_core/cuda/core/experimental/_module.py +++ b/cuda_core/cuda/core/experimental/_module.py @@ -2,10 +2,9 @@ # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -import importlib.metadata from cuda import cuda -from cuda.core.experimental._utils import handle_return, precondition +from cuda.core.experimental._utils import get_binding_version, handle_return, precondition _backend = { "old": { @@ -30,7 +29,7 @@ def _lazy_init(): global _py_major_ver, _driver_ver, _kernel_ctypes # binding availability depends on cuda-python version - _py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0]) + _py_major_ver, _ = get_binding_version() if _py_major_ver >= 12: _backend["new"] = { "file": cuda.cuLibraryLoadFromFile, diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index 9cb47a33..b672b4ac 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE import functools +import importlib.metadata from collections import namedtuple from typing import Callable, Dict @@ -134,3 +135,11 @@ def get_device_from_ctx(ctx_handle) -> int: assert ctx_handle == handle_return(cuda.cuCtxPopCurrent()) handle_return(cuda.cuCtxPushCurrent(prev_ctx)) return device_id + + +def get_binding_version(): + try: + major_minor = importlib.metadata.version("cuda-bindings").split(".")[:2] + except importlib.metadata.PackageNotFoundError: + major_minor = importlib.metadata.version("cuda-python").split(".")[:2] + return tuple(int(v) for v in major_minor)