Skip to content

Commit

Permalink
fix binding version check
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Jan 12, 2025
1 parent b647dfc commit f47ac40
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 12 deletions.
6 changes: 3 additions & 3 deletions cuda_bindings/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ def build_extension(self, ext):
# Allow extensions to discover libraries at runtime
# relative their wheels installation.
if ext.name == "cuda.bindings._bindings.cynvrtc":
ldflag = f"-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
ldflag = "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/cuda_nvrtc/lib"
elif ext.name == "cuda.bindings._internal.nvjitlink":
ldflag = f"-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib"
ldflag = "-Wl,--disable-new-dtags,-rpath,$ORIGIN/../../../nvidia/nvjitlink/lib"
else:
ldflag = None

Expand All @@ -326,7 +326,7 @@ def build_extension(self, ext):
cmdclass = {
"bdist_wheel": WheelsBuildExtensions,
"build_ext": ParallelBuildExtensions,
}
}

# ----------------------------------------------------------------------
# Setup
Expand Down
4 changes: 1 addition & 3 deletions cuda_bindings/tests/test_nvjitlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ def ptx_header(version, arch):
def check_nvjitlink_usable():
from cuda.bindings._internal import nvjitlink as inner_nvjitlink

if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0:
return False
return True
return inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") != 0


pytestmark = pytest.mark.skipif(
Expand Down
5 changes: 2 additions & 3 deletions cuda_core/cuda/core/experimental/_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import importlib.metadata
from dataclasses import dataclass
from typing import Optional, Union

Expand All @@ -11,7 +10,7 @@
from cuda.core.experimental._kernel_arg_handler import ParamHolder
from cuda.core.experimental._module import Kernel
from cuda.core.experimental._stream import Stream
from cuda.core.experimental._utils import CUDAError, check_or_create_options, handle_return
from cuda.core.experimental._utils import CUDAError, check_or_create_options, get_binding_version, handle_return

# TODO: revisit this treatment for py313t builds
_inited = False
Expand All @@ -25,7 +24,7 @@ def _lazy_init():

global _use_ex
# binding availability depends on cuda-python version
_py_major_minor = tuple(int(v) for v in (importlib.metadata.version("cuda-python").split(".")[:2]))
_py_major_minor = get_binding_version()
_driver_ver = handle_return(cuda.cuDriverGetVersion())
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
_inited = True
Expand Down
5 changes: 2 additions & 3 deletions cuda_core/cuda/core/experimental/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import importlib.metadata

from cuda import cuda
from cuda.core.experimental._utils import handle_return, precondition
from cuda.core.experimental._utils import get_binding_version, handle_return, precondition

_backend = {
"old": {
Expand All @@ -30,7 +29,7 @@ def _lazy_init():

global _py_major_ver, _driver_ver, _kernel_ctypes
# binding availability depends on cuda-python version
_py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0])
_py_major_ver, _ = get_binding_version()
if _py_major_ver >= 12:
_backend["new"] = {
"file": cuda.cuLibraryLoadFromFile,
Expand Down
9 changes: 9 additions & 0 deletions cuda_core/cuda/core/experimental/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import functools
import importlib.metadata
from collections import namedtuple
from typing import Callable, Dict

Expand Down Expand Up @@ -134,3 +135,11 @@ def get_device_from_ctx(ctx_handle) -> int:
assert ctx_handle == handle_return(cuda.cuCtxPopCurrent())
handle_return(cuda.cuCtxPushCurrent(prev_ctx))
return device_id


def get_binding_version():
try:
major_minor = importlib.metadata.version("cuda-bindings").split(".")[:2]
except importlib.metadata.PackageNotFoundError:
major_minor = importlib.metadata.version("cuda-python").split(".")[:2]
return tuple(int(v) for v in major_minor)

0 comments on commit f47ac40

Please sign in to comment.