All torch tests passing
ccummingsNV committed Jan 19, 2025
1 parent d7ba3ef commit e7e2c76
Showing 5 changed files with 21 additions and 8 deletions.
slangpy/tests/test_pytorch.py: 6 changes (3 additions, 3 deletions)
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import pytest
 from slangpy.backend import DeviceType, Device
-from slangpy.builtin.torch import TorchModule
+from slangpy.torchintegration import TorchModule
 import slangpy.tests.helpers as helpers
 import hashlib
 import os
@@ -35,7 +35,7 @@ def get_module(device: Device):
     module = device.load_module_from_source(
         hashlib.sha256(module_source.encode()).hexdigest()[0:16], module_source
     )
-    return TorchModule(module)
+    return TorchModule.load_from_module(device, module)
 
 
 def compare_tensors(a: torch.Tensor, b: torch.Tensor):
@@ -61,7 +61,7 @@ def test_missing_torch_context(device_type: DeviceType):
     module = helpers.create_module(device, TEST_CODE)
 
     a = torch.randn((8, 5), dtype=torch.float32, device=torch.device('cuda'), requires_grad=True)
-    with pytest.raises(RuntimeError, match=r"Failed to access current torch context.*"):
+    with pytest.raises(ValueError, match=r"Tensor types can not be directly passed to SlangPy"):
         b = module.square(a)


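Downstream code migrates the same way the updated test does. A minimal sketch of the change,
using only names that appear in this diff (device and module as constructed in get_module above):

    # Before this commit:
    #   from slangpy.builtin.torch import TorchModule
    #   tm = TorchModule(module)
    # After this commit:
    from slangpy.torchintegration import TorchModule
    tm = TorchModule.load_from_module(device, module)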
slangpy/torchintegration/torchfunction.py: 9 changes (8 additions, 1 deletion)
@@ -7,14 +7,20 @@
 from slangpy.torchintegration.wrappedtensor import WrappedTensor
 from slangpy.core.function import Function, IThis
 import slangpy.reflection as kfr
-from slangpy.backend import (FunctionReflection, TypeConformance)
+from slangpy.backend import (FunctionReflection, TypeConformance, Device)
 
 if TYPE_CHECKING:
     from slangpy.core.module import Module
     from slangpy.core.struct import Struct
     from slangpy.torchintegration.torchstruct import TorchStruct
 
 
+def check_cuda_enabled(device: Device):
+    if not device.supports_cuda_interop:
+        raise RuntimeError("Cuda interop must be enabled for torch support "
+                           "create SGL device with Device..., enable_cuda_interop=True")
+
+
 def unpack_arg(arg: Any, tensors: list[torch.Tensor]) -> Any:
     if hasattr(arg, "get_this"):
         arg = arg.get_this()
@@ -191,6 +197,7 @@ class TorchFunction(torch.nn.Module):

     def __init__(self, function: Function):
         super().__init__()
+        check_cuda_enabled(function.module.device)
         self.function = function.return_type(WrappedTensor)
 
     def forward(self, *args: Any, **kwargs: Any):
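The new check_cuda_enabled guard runs in the constructors of TorchFunction, TorchModule and
TorchStruct (see the hunks below), so torch integration now fails fast on a device created
without CUDA interop. A minimal sketch of the calling side, assuming the Device constructor
accepts enable_cuda_interop as the error message suggests (backend choice is illustrative):

    from slangpy.backend import Device, DeviceType
    from slangpy.torchintegration import TorchModule

    # Assumed constructor kwargs: the diff only confirms device.supports_cuda_interop
    # and the enable_cuda_interop=True hint in the new error message.
    device = Device(type=DeviceType.d3d12, enable_cuda_interop=True)
    module = device.load_module_from_source("example", "float square(float x) { return x * x; }")
    tm = TorchModule.load_from_module(device, module)  # raises RuntimeError without interop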
slangpy/torchintegration/torchmodule.py: 3 changes (2 additions, 1 deletion)
@@ -7,7 +7,7 @@
 from slangpy.backend import SlangModule, Device
 from slangpy.core.module import Module
 
-from slangpy.torchintegration.torchfunction import TorchFunction
+from slangpy.torchintegration.torchfunction import TorchFunction, check_cuda_enabled
 from slangpy.torchintegration.torchstruct import TorchStruct
 
 
@@ -18,6 +18,7 @@ class TorchModule:

     def __init__(self, module: 'Module'):
         super().__init__()
+        check_cuda_enabled(module.device)
         self.module = module
 
     @staticmethod
slangpy/torchintegration/torchstruct.py: 3 changes (2 additions, 1 deletion)
@@ -3,7 +3,7 @@

 from slangpy.core.function import Function
 from slangpy.core.struct import Struct
-from slangpy.torchintegration.torchfunction import TorchFunction
+from slangpy.torchintegration.torchfunction import TorchFunction, check_cuda_enabled
 
 
 class TorchStruct:
@@ -14,6 +14,7 @@ class TorchStruct:

     def __init__(self, struct: Struct) -> None:
         super().__init__()
+        check_cuda_enabled(struct.module.device)
         self.struct = struct
 
     @property
slangpy/torchintegration/wrappedtensor.py: 8 changes (6 additions, 2 deletions)
@@ -123,7 +123,11 @@ def create_calldata(self, context: CallContext, binding: 'BoundVariableRuntime',
         return result
 
     def create_output(self, context: CallContext, binding: BoundVariableRuntime) -> Any:
-        return WrappedTensor(torch.empty(context.call_shape.as_tuple(), dtype=self.torch_dtype, device=torch.device('cuda')))
+        # The overall tensor shape must contain the call shape plus the shape of the slang
+        # datatype, i.e. if a float tensor is to store 4x4 matrix results, its shape needs
+        # to be extended by (4, 4).
+        combined_shape = context.call_shape.as_tuple() + self.slang_dtype.shape.as_tuple()
+        return WrappedTensor(torch.empty(combined_shape, dtype=self.torch_dtype, device=torch.device('cuda')))
 
     def read_output(self, context: CallContext, binding: BoundVariableRuntime, data: Any) -> Any:
         return data
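To make the shape arithmetic concrete, a worked example with hypothetical values, assuming a
dispatch over an (8, 5) call grid whose Slang result type is a 4x4 matrix:

    call_shape = (8, 5)       # context.call_shape.as_tuple()
    dtype_shape = (4, 4)      # self.slang_dtype.shape.as_tuple() for a float4x4
    combined_shape = call_shape + dtype_shape
    assert combined_shape == (8, 5, 4, 4)  # one 4x4 result per call element

Previously torch.empty received only the call shape, so outputs with non-scalar element types
were allocated too small, presumably part of what this commit fixes.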
@@ -135,7 +139,7 @@ def create_tensor_marshall(layout: SlangProgramLayout, value: Any):
         return tr.get_or_create_type(layout, ValueRef, value)
     else:
         slang_dtype = value.slang_type
-        torch_dtype = _slang_dtype_to_torch(slang_dtype)
+        torch_dtype = _slang_dtype_to_torch(innermost_type(slang_dtype))
         if torch_dtype is None:
             raise ValueError(f"Unsupported slang type {value.slang_type}")
         marshall = WrappedTensorMarshall(layout, torch_dtype, slang_dtype,
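The innermost_type change pairs with the create_output fix above: the torch dtype is now looked
up from the innermost scalar of the slang type, while the element dimensions travel in the tensor
shape. An illustrative stand-in, not the library's actual lookup (names and table are assumptions):

    import torch

    # Hypothetical scalar table; the real _slang_dtype_to_torch covers more types.
    SCALAR_TO_TORCH = {"float": torch.float32, "half": torch.float16, "int": torch.int32}

    def innermost_scalar(slang_type_name: str) -> str:
        # "float4x4" -> "float", "int3" -> "int": drop trailing dimension suffixes.
        # The real innermost_type walks Slang reflection data rather than strings.
        return slang_type_name.rstrip("0123456789x")

    assert SCALAR_TO_TORCH[innermost_scalar("float4x4")] is torch.float32
    assert SCALAR_TO_TORCH[innermost_scalar("int3")] is torch.int32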
