diff --git a/slangpy/tests/test_pytorch.py b/slangpy/tests/test_pytorch.py index bf60b05..ef0ba16 100644 --- a/slangpy/tests/test_pytorch.py +++ b/slangpy/tests/test_pytorch.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest from slangpy.backend import DeviceType, Device -from slangpy.builtin.torch import TorchModule +from slangpy.torchintegration import TorchModule import slangpy.tests.helpers as helpers import hashlib import os @@ -35,7 +35,7 @@ def get_module(device: Device): module = device.load_module_from_source( hashlib.sha256(module_source.encode()).hexdigest()[0:16], module_source ) - return TorchModule(module) + return TorchModule.load_from_module(device, module) def compare_tensors(a: torch.Tensor, b: torch.Tensor): @@ -61,7 +61,7 @@ def test_missing_torch_context(device_type: DeviceType): module = helpers.create_module(device, TEST_CODE) a = torch.randn((8, 5), dtype=torch.float32, device=torch.device('cuda'), requires_grad=True) - with pytest.raises(RuntimeError, match=r"Failed to access current torch context.*"): + with pytest.raises(ValueError, match=r"Tensor types can not be directly passed to SlangPy"): b = module.square(a) diff --git a/slangpy/torchintegration/torchfunction.py b/slangpy/torchintegration/torchfunction.py index 4a1a634..6e4c190 100644 --- a/slangpy/torchintegration/torchfunction.py +++ b/slangpy/torchintegration/torchfunction.py @@ -7,7 +7,7 @@ from slangpy.torchintegration.wrappedtensor import WrappedTensor from slangpy.core.function import Function, IThis import slangpy.reflection as kfr -from slangpy.backend import (FunctionReflection, TypeConformance) +from slangpy.backend import (FunctionReflection, TypeConformance, Device) if TYPE_CHECKING: from slangpy.core.module import Module @@ -15,6 +15,13 @@ from slangpy.torchintegration.torchstruct import TorchStruct +def check_cuda_enabled(device: Device): + """Raise RuntimeError unless `device` reports CUDA interop support.""" + if not device.supports_cuda_interop: + raise RuntimeError("Cuda interop must be enabled for torch support. " + 
"Create SGL device with Device(..., enable_cuda_interop=True)") + + def unpack_arg(arg: Any, tensors: list[torch.Tensor]) -> Any: if hasattr(arg, "get_this"): arg = arg.get_this() @@ -191,6 +198,7 @@ class TorchFunction(torch.nn.Module): def __init__(self, function: Function): super().__init__() + check_cuda_enabled(function.module.device) self.function = function.return_type(WrappedTensor) def forward(self, *args: Any, **kwargs: Any): diff --git a/slangpy/torchintegration/torchmodule.py b/slangpy/torchintegration/torchmodule.py index bb29fee..3f5f005 100644 --- a/slangpy/torchintegration/torchmodule.py +++ b/slangpy/torchintegration/torchmodule.py @@ -7,7 +7,7 @@ from slangpy.backend import SlangModule, Device from slangpy.core.module import Module -from slangpy.torchintegration.torchfunction import TorchFunction +from slangpy.torchintegration.torchfunction import TorchFunction, check_cuda_enabled from slangpy.torchintegration.torchstruct import TorchStruct @@ -18,6 +18,7 @@ class TorchModule: def __init__(self, module: 'Module'): super().__init__() + check_cuda_enabled(module.device) self.module = module @staticmethod diff --git a/slangpy/torchintegration/torchstruct.py b/slangpy/torchintegration/torchstruct.py index fd4520f..9ff0c3f 100644 --- a/slangpy/torchintegration/torchstruct.py +++ b/slangpy/torchintegration/torchstruct.py @@ -3,7 +3,7 @@ from slangpy.core.function import Function from slangpy.core.struct import Struct -from slangpy.torchintegration.torchfunction import TorchFunction +from slangpy.torchintegration.torchfunction import TorchFunction, check_cuda_enabled class TorchStruct: @@ -14,6 +14,7 @@ class TorchStruct: def __init__(self, struct: Struct) -> None: super().__init__() + check_cuda_enabled(struct.module.device) self.struct = struct @property diff --git a/slangpy/torchintegration/wrappedtensor.py b/slangpy/torchintegration/wrappedtensor.py index 843a6a7..21067a2 100644 --- a/slangpy/torchintegration/wrappedtensor.py +++ 
b/slangpy/torchintegration/wrappedtensor.py @@ -123,7 +123,11 @@ def create_calldata(self, context: CallContext, binding: 'BoundVariableRuntime', return result def create_output(self, context: CallContext, binding: BoundVariableRuntime) -> Any: - return WrappedTensor(torch.empty(context.call_shape.as_tuple(), dtype=self.torch_dtype, device=torch.device('cuda'))) + # Overall shape of tensor must contain the call, plus the shape of the slang datatype + # i.e. if a float tensor is to store 4x4 matrix results, it needs the shape to be + # extended by (4,4) + combined_shape = context.call_shape.as_tuple() + self.slang_dtype.shape.as_tuple() + return WrappedTensor(torch.empty(combined_shape, dtype=self.torch_dtype, device=torch.device('cuda'))) def read_output(self, context: CallContext, binding: BoundVariableRuntime, data: Any) -> Any: return data @@ -135,7 +139,7 @@ def create_tensor_marshall(layout: SlangProgramLayout, value: Any): return tr.get_or_create_type(layout, ValueRef, value) else: slang_dtype = value.slang_type - torch_dtype = _slang_dtype_to_torch(slang_dtype) + torch_dtype = _slang_dtype_to_torch(innermost_type(slang_dtype)) if torch_dtype is None: raise ValueError(f"Unsupported slang type {value.slang_type}") marshall = WrappedTensorMarshall(layout, torch_dtype, slang_dtype,