From 75428d850e50634e8bf83215223eeb7f35fdd16d Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Tue, 21 Jan 2025 16:03:24 +0000 Subject: [PATCH] Started on native value and buffer marshalling --- src/CMakeLists.txt | 4 ++ src/sgl/utils/python/slangpy.cpp | 53 +++++++++++--- src/sgl/utils/python/slangpy.h | 15 ++++ src/sgl/utils/python/slangpybuffer.cpp | 27 +++++++ src/sgl/utils/python/slangpybuffer.h | 97 ++++++++++++++++++++++++++ src/sgl/utils/python/slangpyvalue.cpp | 28 ++++++++ src/sgl/utils/python/slangpyvalue.h | 41 +++++++++++ src/sgl/utils/slangpy.h | 10 +++ 8 files changed, 264 insertions(+), 11 deletions(-) create mode 100644 src/sgl/utils/python/slangpybuffer.cpp create mode 100644 src/sgl/utils/python/slangpybuffer.h create mode 100644 src/sgl/utils/python/slangpyvalue.cpp create mode 100644 src/sgl/utils/python/slangpyvalue.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e916086e..972e37a4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -433,6 +433,10 @@ if(SGL_BUILD_PYTHON) sgl/utils/python/renderdoc.cpp sgl/utils/python/slangpy.h sgl/utils/python/slangpy.cpp + sgl/utils/python/slangpyvalue.h + sgl/utils/python/slangpyvalue.cpp + sgl/utils/python/slangpybuffer.h + sgl/utils/python/slangpybuffer.cpp sgl/utils/python/tev.cpp sgl/utils/python/texture_loader.cpp ) diff --git a/src/sgl/utils/python/slangpy.cpp b/src/sgl/utils/python/slangpy.cpp index 6e81b6d9..10c9b881 100644 --- a/src/sgl/utils/python/slangpy.cpp +++ b/src/sgl/utils/python/slangpy.cpp @@ -12,6 +12,8 @@ #include "sgl/device/command.h" #include "sgl/utils/python/slangpy.h" +#include "sgl/utils/python/slangpyvalue.h" +#include "sgl/utils/python/slangpybuffer.h" namespace sgl { extern void write_shader_cursor(ShaderCursor& cursor, nb::object value); @@ -678,9 +680,20 @@ SGL_PY_EXPORT(utils_slangpy) &NativeMarshall::set_concrete_shape, D_NA(NativeMarshall, concrete_shape) ) + .def_prop_rw( + "slang_type", + &NativeMarshall::get_slang_type, + &NativeMarshall::set_slang_type, + D_NA(NativeMarshall, slang_type) + ) .def( "write_shader_cursor_pre_dispatch", &NativeMarshall::write_shader_cursor_pre_dispatch, + "context"_a, + "binding"_a, + "cursor"_a, + "value"_a, + "read_back"_a, D_NA(NativeMarshall, write_shader_cursor_pre_dispatch) ) .def("create_calldata", &NativeMarshall::create_calldata, D_NA(NativeMarshall, create_calldata)) @@ -688,6 +701,13 @@ SGL_PY_EXPORT(utils_slangpy) .def("create_output", &NativeMarshall::create_output, D_NA(NativeMarshall, create_output)) .def("read_output", &NativeMarshall::read_output, D_NA(NativeMarshall, read_output)); + nb::class_(slangpy, "NativeValueMarshall") // + .def( + "__init__", + [](NativeValueMarshall& self) { new (&self) NativeValueMarshall(); }, + D_NA(NativeValueMarshall, NativeValueMarshall) + ); + nb::class_(slangpy, "NativeBoundVariableRuntime") // .def(nb::init<>(), D_NA(NativeBoundVariableRuntime, NativeBoundVariableRuntime)) .def_prop_rw( @@ -923,14 +943,6 @@ SGL_PY_EXPORT(utils_slangpy) D_NA(Shape, operator==) ); - /* def __eq__(self, value - : object) - ->bool : if isinstance (value, Shape) - : return self.shape - == value.shape else : return self.shape - == value*/ - - nb::class_(slangpy, "CallContext") // .def( nb::init, const Shape&, CallMode>(), @@ -950,7 +962,26 @@ SGL_PY_EXPORT(utils_slangpy) nb::rv_policy::reference_internal, D_NA(CallContext, call_shape) ) - .def_prop_ro("call_mode", &CallContext::call_mode, D_NA(CallContext, call_mode)) - - ; + .def_prop_ro("call_mode", &CallContext::call_mode, D_NA(CallContext, call_mode)); + + nb::class_(slangpy, "NativeNDBufferDesc") + .def(nb::init<>()) + .def_rw("dtype", &NativeNDBufferDesc::dtype) + .def_rw("element_stride", &NativeNDBufferDesc::element_stride) + .def_rw("shape", &NativeNDBufferDesc::shape) + .def_rw("strides", &NativeNDBufferDesc::strides) + .def_rw("usage", &NativeNDBufferDesc::usage) + .def_rw("memory_type", &NativeNDBufferDesc::memory_type); + + nb::class_(slangpy, "NativeNDBuffer") + .def(nb::init, NativeNDBufferDesc>()) + .def_prop_ro("device", &NativeNDBuffer::device) + .def_prop_rw("slangpy_signature", &NativeNDBuffer::slangpy_signature, &NativeNDBuffer::set_slagpy_signature) + .def_prop_ro("dtype", &NativeNDBuffer::dtype) + .def_prop_ro("shape", &NativeNDBuffer::shape) + .def_prop_ro("strides", &NativeNDBuffer::strides) + .def_prop_ro("element_count", &NativeNDBuffer::element_count) + .def_prop_ro("usage", &NativeNDBuffer::usage) + .def_prop_ro("memory_type", &NativeNDBuffer::memory_type) + .def_prop_ro("storage", &NativeNDBuffer::storage); } diff --git a/src/sgl/utils/python/slangpy.h b/src/sgl/utils/python/slangpy.h index 017b7e5b..1d4c0186 100644 --- a/src/sgl/utils/python/slangpy.h +++ b/src/sgl/utils/python/slangpy.h @@ -11,6 +11,7 @@ #include "sgl/core/fwd.h" #include "sgl/core/object.h" #include "sgl/device/fwd.h" +#include "sgl/device/shader_cursor.h" #include "sgl/utils/slangpy.h" namespace sgl::slangpy { @@ -59,6 +60,13 @@ class NativeSlangType : public Object { /// Base class for a marshal to a slangpy supported type. class NativeMarshall : public Object { public: + NativeMarshall() = default; + + NativeMarshall(ref slang_type) + : m_slang_type(std::move(slang_type)) + { + } + virtual ~NativeMarshall() = default; /// Get the concrete shape of the type. For none-concrete types such as buffers, @@ -75,6 +83,12 @@ class NativeMarshall : public Object { return Shape(); } + /// Get the slang type. + ref get_slang_type() const { return m_slang_type; } + + /// Set the slang type. + void set_slang_type(const ref& slang_type) { m_slang_type = slang_type; } + /// Writes call data to a shader cursor before dispatch, optionally writing data for /// read back after the kernel has executed. By default, this calls through to /// create_calldata, which is typically overridden python side to generate a dictionary. @@ -133,6 +147,7 @@ class NativeMarshall : public Object { private: Shape m_concrete_shape; + ref m_slang_type; }; /// Nanobind trampoline class for NativeMarshall diff --git a/src/sgl/utils/python/slangpybuffer.cpp b/src/sgl/utils/python/slangpybuffer.cpp new file mode 100644 index 00000000..bd75e69a --- /dev/null +++ b/src/sgl/utils/python/slangpybuffer.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "sgl/device/device.h" + +#include "sgl/utils/python/slangpybuffer.h" + +namespace sgl { +extern void write_shader_cursor(ShaderCursor& cursor, nb::object value); +} + +namespace sgl::slangpy { + +NativeNDBuffer::NativeNDBuffer(ref device, NativeNDBufferDesc desc) + : m_desc(desc) +{ + + BufferDesc buffer_desc; + buffer_desc.element_count = desc.shape.element_count(); + buffer_desc.struct_size = desc.element_stride; + buffer_desc.usage = desc.usage; + buffer_desc.memory_type = desc.memory_type; + m_storage = device->create_buffer(buffer_desc); + + m_signature = fmt::format("[{},{},{}]", desc.dtype->get_type_reflection()->name(), desc.shape.size(), desc.usage); +} + +} // namespace sgl::slangpy diff --git a/src/sgl/utils/python/slangpybuffer.h b/src/sgl/utils/python/slangpybuffer.h new file mode 100644 index 00000000..efdc85f9 --- /dev/null +++ b/src/sgl/utils/python/slangpybuffer.h @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "nanobind.h" + +#include "sgl/core/macros.h" +#include "sgl/core/fwd.h" +#include "sgl/core/object.h" + +#include "sgl/device/fwd.h" +#include "sgl/device/resource.h" + +#include "sgl/utils/python/slangpy.h" + +namespace sgl::slangpy { + +struct NativeNDBufferDesc { + ref dtype; + int element_stride; + Shape shape; + Shape strides; + ResourceUsage usage{ResourceUsage::shader_resource | ResourceUsage::unordered_access}; + MemoryType memory_type{MemoryType::device_local}; +}; + +class NativeNDBuffer : public Object { +public: + NativeNDBuffer(ref device, NativeNDBufferDesc desc); + + Device* device() const { return storage()->device(); } + std::string_view slangpy_signature() const { return m_signature; } + void set_slagpy_signature(std::string_view signature) { m_signature = signature; } + ref dtype() const { return m_desc.dtype; } + Shape shape() const { return m_desc.shape; } + Shape strides() const { return m_desc.strides; } + size_t element_count() const { return m_desc.shape.element_count(); } + ResourceUsage usage() const { return m_desc.usage; } + MemoryType memory_type() const { return m_desc.memory_type; } + ref storage() const { return m_storage; } + +private: + NativeNDBufferDesc m_desc; + ref m_storage; + std::string m_signature; +}; + + +class NativeNDBufferMarshall : public NativeMarshall { +public: + NativeNDBufferMarshall( + int dims, + bool writable, + ref slang_type, + ref slang_element_type + ) + : NativeMarshall(slang_type) + , m_dims(dims) + , m_writable(writable) + , m_slang_element_type(slang_element_type) + { + } + /* + /// Writes call data to a shader cursor before dispatch, optionally writing data for + /// read back after the kernel has executed. By default, this calls through to + /// create_calldata, which is typically overridden python side to generate a dictionary. + void write_shader_cursor_pre_dispatch( + CallContext* context, + NativeBoundVariableRuntime* binding, + ShaderCursor cursor, + nb::object value, + nb::list read_back + ) const override; + + + /// Dispatch data is just the value. + nb::object create_dispatchdata(nb::object data) const override { return data; } + + /// If requested, output is just the input value (as it can't have changed). + nb::object read_output(CallContext* context, NativeBoundVariableRuntime* binding, nb::object data) const override + { + SGL_UNUSED(context); + SGL_UNUSED(binding); + return data; + }; + */ + +private: + int m_dims; + bool m_writable; + ref m_slang_element_type; +}; + +} // namespace sgl::slangpy diff --git a/src/sgl/utils/python/slangpyvalue.cpp b/src/sgl/utils/python/slangpyvalue.cpp new file mode 100644 index 00000000..145082fd --- /dev/null +++ b/src/sgl/utils/python/slangpyvalue.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "sgl/utils/python/slangpyvalue.h" + +namespace sgl { +extern void write_shader_cursor(ShaderCursor& cursor, nb::object value); +} + +namespace sgl::slangpy { + +void NativeValueMarshall::write_shader_cursor_pre_dispatch( + CallContext* context, + NativeBoundVariableRuntime* binding, + ShaderCursor cursor, + nb::object value, + nb::list read_back +) const +{ + AccessType primal_access = binding->get_access().first; + if (!value.is_none() && (primal_access == AccessType::read || primal_access == AccessType::readwrite)) { + SGL_UNUSED(binding); + SGL_UNUSED(context); + ShaderCursor field = cursor[binding->get_variable_name()]["value"]; + write_shader_cursor(field, value); + } +} + +} // namespace sgl::slangpy diff --git a/src/sgl/utils/python/slangpyvalue.h b/src/sgl/utils/python/slangpyvalue.h new file mode 100644 index 00000000..b7609646 --- /dev/null +++ b/src/sgl/utils/python/slangpyvalue.h @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "nanobind.h" + +#include "sgl/utils/python/slangpy.h" + +namespace sgl::slangpy { + +/// Base class for marshalling simple value types between Python and Slang. +class NativeValueMarshall : public NativeMarshall { +public: + /// Writes call data to a shader cursor before dispatch, optionally writing data for + /// read back after the kernel has executed. By default, this calls through to + /// create_calldata, which is typically overridden python side to generate a dictionary. + void write_shader_cursor_pre_dispatch( + CallContext* context, + NativeBoundVariableRuntime* binding, + ShaderCursor cursor, + nb::object value, + nb::list read_back + ) const override; + + + /// Dispatch data is just the value. + nb::object create_dispatchdata(nb::object data) const override { return data; } + + /// If requested, output is just the input value (as it can't have changed). + nb::object read_output(CallContext* context, NativeBoundVariableRuntime* binding, nb::object data) const override + { + SGL_UNUSED(context); + SGL_UNUSED(binding); + return data; + }; +}; + +} // namespace sgl::slangpy diff --git a/src/sgl/utils/slangpy.h b/src/sgl/utils/slangpy.h index 0182c31f..f033764a 100644 --- a/src/sgl/utils/slangpy.h +++ b/src/sgl/utils/slangpy.h @@ -130,6 +130,16 @@ class SGL_API Shape { return fmt::format("[{}]", fmt::join(as_vector(), ", ")); } + /// Total element count (if this represented contiguous array) + size_t element_count() const + { + size_t result = 1; + for (auto dim : as_vector()) { + result *= dim; + } + return result; + } + private: std::optional> m_shape; };