Skip to content

Commit

Permalink
Native NDBuffer largely working
Browse files Browse the repository at this point in the history
  • Loading branch information
ccummingsNV committed Jan 21, 2025
1 parent 75428d8 commit b7e0548
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/sgl/utils/python/slangpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,21 @@ SGL_PY_EXPORT(utils_slangpy)
D_NA(NativeValueMarshall, NativeValueMarshall)
);

nb::class_<NativeNDBufferMarshall, PyNativeNDBufferMarshall, NativeMarshall>(slangpy, "NativeNDBufferMarshall") //
.def(
"__init__",
[](NativeNDBufferMarshall& self,
int dims,
bool writable,
ref<NativeSlangType> slang_type,
ref<NativeSlangType> slang_element_type)
{ new (&self) PyNativeNDBufferMarshall(dims, writable, slang_type, slang_element_type); },
D_NA(NativeNDBufferMarshall, NativeNDBufferMarshall)
)
.def_prop_ro("dims", &sgl::slangpy::NativeNDBufferMarshall::dims)
.def_prop_ro("writable", &sgl::slangpy::NativeNDBufferMarshall::writable)
.def_prop_ro("slang_element_type", &sgl::slangpy::NativeNDBufferMarshall::slang_element_type);

nb::class_<NativeBoundVariableRuntime, Object>(slangpy, "NativeBoundVariableRuntime") //
.def(nb::init<>(), D_NA(NativeBoundVariableRuntime, NativeBoundVariableRuntime))
.def_prop_rw(
Expand Down
60 changes: 60 additions & 0 deletions src/sgl/utils/python/slangpybuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,70 @@ class NativeNDBufferMarshall : public NativeMarshall {
};
*/

int dims() const { return m_dims; }
bool writable() const { return m_writable; }
ref<NativeSlangType> slang_element_type() const { return m_slang_element_type; }

Shape get_shape(nb::object data) const override
{
auto buffer = nb::cast<NativeNDBuffer*>(data);
return buffer->shape();
}

void write_shader_cursor_pre_dispatch(
CallContext* context,
NativeBoundVariableRuntime* binding,
ShaderCursor cursor,
nb::object value,
nb::list read_back
) const override
{
SGL_UNUSED(context);
SGL_UNUSED(read_back);

auto buffer = nb::cast<NativeNDBuffer*>(value);
ShaderCursor field = cursor[binding->get_variable_name()];
field["buffer"] = buffer->storage();

auto shape_vec = buffer->shape().as_vector();
field["shape"]
._set_array(&shape_vec[0], shape_vec.size() * 4, TypeReflection::ScalarType::int32, shape_vec.size());

auto strides_vec = buffer->strides().as_vector();
field["strides"]
._set_array(&strides_vec[0], strides_vec.size() * 4, TypeReflection::ScalarType::int32, strides_vec.size());
}

void read_calldata(CallContext* context, NativeBoundVariableRuntime* binding, nb::object data, nb::object result)
const override
{
SGL_UNUSED(context);
SGL_UNUSED(binding);
SGL_UNUSED(data);
SGL_UNUSED(result);
}

nb::object read_output(CallContext* context, NativeBoundVariableRuntime* binding, nb::object data) const override
{
SGL_UNUSED(context);
SGL_UNUSED(binding);
return data;
}

private:
int m_dims;
bool m_writable;
ref<NativeSlangType> m_slang_element_type;
};

/// Nanobind trampoline class for NativeNDBufferMarshall as can't currently implement create_output in native.
struct PyNativeNDBufferMarshall : public NativeNDBufferMarshall {
NB_TRAMPOLINE(NativeNDBufferMarshall, 1);

nb::object create_output(CallContext* context, NativeBoundVariableRuntime* binding) const override
{
NB_OVERRIDE(create_output, context, binding);
}
};

} // namespace sgl::slangpy

0 comments on commit b7e0548

Please sign in to comment.