Skip to content

Commit

Permalink
Started on native value and buffer marshalling
Browse files Browse the repository at this point in the history
  • Loading branch information
ccummingsNV committed Jan 21, 2025
1 parent 6a517a4 commit 75428d8
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 11 deletions.
4 changes: 4 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ if(SGL_BUILD_PYTHON)
sgl/utils/python/renderdoc.cpp
sgl/utils/python/slangpy.h
sgl/utils/python/slangpy.cpp
sgl/utils/python/slangpyvalue.h
sgl/utils/python/slangpyvalue.cpp
sgl/utils/python/slangpybuffer.h
sgl/utils/python/slangpybuffer.cpp
sgl/utils/python/tev.cpp
sgl/utils/python/texture_loader.cpp
)
Expand Down
53 changes: 42 additions & 11 deletions src/sgl/utils/python/slangpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include "sgl/device/command.h"

#include "sgl/utils/python/slangpy.h"
#include "sgl/utils/python/slangpyvalue.h"
#include "sgl/utils/python/slangpybuffer.h"

namespace sgl {
extern void write_shader_cursor(ShaderCursor& cursor, nb::object value);
Expand Down Expand Up @@ -678,16 +680,34 @@ SGL_PY_EXPORT(utils_slangpy)
&NativeMarshall::set_concrete_shape,
D_NA(NativeMarshall, concrete_shape)
)
.def_prop_rw(
"slang_type",
&NativeMarshall::get_slang_type,
&NativeMarshall::set_slang_type,
D_NA(NativeMarshall, slang_type)
)
.def(
"write_shader_cursor_pre_dispatch",
&NativeMarshall::write_shader_cursor_pre_dispatch,
"context"_a,
"binding"_a,
"cursor"_a,
"value"_a,
"read_back"_a,
D_NA(NativeMarshall, write_shader_cursor_pre_dispatch)
)
.def("create_calldata", &NativeMarshall::create_calldata, D_NA(NativeMarshall, create_calldata))
.def("read_calldata", &NativeMarshall::read_calldata, D_NA(NativeMarshall, read_calldata))
.def("create_output", &NativeMarshall::create_output, D_NA(NativeMarshall, create_output))
.def("read_output", &NativeMarshall::read_output, D_NA(NativeMarshall, read_output));

nb::class_<NativeValueMarshall, NativeMarshall>(slangpy, "NativeValueMarshall") //
.def(
"__init__",
[](NativeValueMarshall& self) { new (&self) NativeValueMarshall(); },
D_NA(NativeValueMarshall, NativeValueMarshall)
);

nb::class_<NativeBoundVariableRuntime, Object>(slangpy, "NativeBoundVariableRuntime") //
.def(nb::init<>(), D_NA(NativeBoundVariableRuntime, NativeBoundVariableRuntime))
.def_prop_rw(
Expand Down Expand Up @@ -923,14 +943,6 @@ SGL_PY_EXPORT(utils_slangpy)
D_NA(Shape, operator==)
);

/* def __eq__(self, value
: object)
->bool : if isinstance (value, Shape)
: return self.shape
== value.shape else : return self.shape
== value*/


nb::class_<CallContext, Object>(slangpy, "CallContext") //
.def(
nb::init<ref<Device>, const Shape&, CallMode>(),
Expand All @@ -950,7 +962,26 @@ SGL_PY_EXPORT(utils_slangpy)
nb::rv_policy::reference_internal,
D_NA(CallContext, call_shape)
)
.def_prop_ro("call_mode", &CallContext::call_mode, D_NA(CallContext, call_mode))

;
.def_prop_ro("call_mode", &CallContext::call_mode, D_NA(CallContext, call_mode));

nb::class_<NativeNDBufferDesc>(slangpy, "NativeNDBufferDesc")
.def(nb::init<>())
.def_rw("dtype", &NativeNDBufferDesc::dtype)
.def_rw("element_stride", &NativeNDBufferDesc::element_stride)
.def_rw("shape", &NativeNDBufferDesc::shape)
.def_rw("strides", &NativeNDBufferDesc::strides)
.def_rw("usage", &NativeNDBufferDesc::usage)
.def_rw("memory_type", &NativeNDBufferDesc::memory_type);

nb::class_<NativeNDBuffer, Object>(slangpy, "NativeNDBuffer")
.def(nb::init<ref<Device>, NativeNDBufferDesc>())
.def_prop_ro("device", &NativeNDBuffer::device)
.def_prop_rw("slangpy_signature", &NativeNDBuffer::slangpy_signature, &NativeNDBuffer::set_slagpy_signature)
.def_prop_ro("dtype", &NativeNDBuffer::dtype)
.def_prop_ro("shape", &NativeNDBuffer::shape)
.def_prop_ro("strides", &NativeNDBuffer::strides)
.def_prop_ro("element_count", &NativeNDBuffer::element_count)
.def_prop_ro("usage", &NativeNDBuffer::usage)
.def_prop_ro("memory_type", &NativeNDBuffer::memory_type)
.def_prop_ro("storage", &NativeNDBuffer::storage);
}
15 changes: 15 additions & 0 deletions src/sgl/utils/python/slangpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "sgl/core/fwd.h"
#include "sgl/core/object.h"
#include "sgl/device/fwd.h"
#include "sgl/device/shader_cursor.h"
#include "sgl/utils/slangpy.h"

namespace sgl::slangpy {
Expand Down Expand Up @@ -59,6 +60,13 @@ class NativeSlangType : public Object {
/// Base class for a marshal to a slangpy supported type.
class NativeMarshall : public Object {
public:
NativeMarshall() = default;

NativeMarshall(ref<NativeSlangType> slang_type)
: m_slang_type(std::move(slang_type))
{
}

virtual ~NativeMarshall() = default;

/// Get the concrete shape of the type. For none-concrete types such as buffers,
Expand All @@ -75,6 +83,12 @@ class NativeMarshall : public Object {
return Shape();
}

/// Get the slang type.
ref<NativeSlangType> get_slang_type() const { return m_slang_type; }

/// Set the slang type.
void set_slang_type(const ref<NativeSlangType>& slang_type) { m_slang_type = slang_type; }

/// Writes call data to a shader cursor before dispatch, optionally writing data for
/// read back after the kernel has executed. By default, this calls through to
/// create_calldata, which is typically overridden python side to generate a dictionary.
Expand Down Expand Up @@ -133,6 +147,7 @@ class NativeMarshall : public Object {

private:
Shape m_concrete_shape;
ref<NativeSlangType> m_slang_type;
};

/// Nanobind trampoline class for NativeMarshall
Expand Down
27 changes: 27 additions & 0 deletions src/sgl/utils/python/slangpybuffer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-License-Identifier: Apache-2.0

#include "sgl/device/device.h"

#include "sgl/utils/python/slangpybuffer.h"

namespace sgl {
extern void write_shader_cursor(ShaderCursor& cursor, nb::object value);
}

namespace sgl::slangpy {

NativeNDBuffer::NativeNDBuffer(ref<Device> device, NativeNDBufferDesc desc)
: m_desc(desc)
{

BufferDesc buffer_desc;
buffer_desc.element_count = desc.shape.element_count();
buffer_desc.struct_size = desc.element_stride;
buffer_desc.usage = desc.usage;
buffer_desc.memory_type = desc.memory_type;
m_storage = device->create_buffer(buffer_desc);

m_signature = fmt::format("[{},{},{}]", desc.dtype->get_type_reflection()->name(), desc.shape.size(), desc.usage);
}

} // namespace sgl::slangpy
97 changes: 97 additions & 0 deletions src/sgl/utils/python/slangpybuffer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <vector>
#include <map>

#include "nanobind.h"

#include "sgl/core/macros.h"
#include "sgl/core/fwd.h"
#include "sgl/core/object.h"

#include "sgl/device/fwd.h"
#include "sgl/device/resource.h"

#include "sgl/utils/python/slangpy.h"

namespace sgl::slangpy {

struct NativeNDBufferDesc {
ref<NativeSlangType> dtype;
int element_stride;
Shape shape;
Shape strides;
ResourceUsage usage{ResourceUsage::shader_resource | ResourceUsage::unordered_access};
MemoryType memory_type{MemoryType::device_local};
};

class NativeNDBuffer : public Object {
public:
NativeNDBuffer(ref<Device> device, NativeNDBufferDesc desc);

Device* device() const { return storage()->device(); }
std::string_view slangpy_signature() const { return m_signature; }
void set_slagpy_signature(std::string_view signature) { m_signature = signature; }
ref<NativeSlangType> dtype() const { return m_desc.dtype; }
Shape shape() const { return m_desc.shape; }
Shape strides() const { return m_desc.strides; }
size_t element_count() const { return m_desc.shape.element_count(); }
ResourceUsage usage() const { return m_desc.usage; }
MemoryType memory_type() const { return m_desc.memory_type; }
ref<Buffer> storage() const { return m_storage; }

private:
NativeNDBufferDesc m_desc;
ref<Buffer> m_storage;
std::string m_signature;
};


class NativeNDBufferMarshall : public NativeMarshall {
public:
NativeNDBufferMarshall(
int dims,
bool writable,
ref<NativeSlangType> slang_type,
ref<NativeSlangType> slang_element_type
)
: NativeMarshall(slang_type)
, m_dims(dims)
, m_writable(writable)
, m_slang_element_type(slang_element_type)
{
}
/*
/// Writes call data to a shader cursor before dispatch, optionally writing data for
/// read back after the kernel has executed. By default, this calls through to
/// create_calldata, which is typically overridden python side to generate a dictionary.
void write_shader_cursor_pre_dispatch(
CallContext* context,
NativeBoundVariableRuntime* binding,
ShaderCursor cursor,
nb::object value,
nb::list read_back
) const override;
/// Dispatch data is just the value.
nb::object create_dispatchdata(nb::object data) const override { return data; }
/// If requested, output is just the input value (as it can't have changed).
nb::object read_output(CallContext* context, NativeBoundVariableRuntime* binding, nb::object data) const override
{
SGL_UNUSED(context);
SGL_UNUSED(binding);
return data;
};
*/

private:
int m_dims;
bool m_writable;
ref<NativeSlangType> m_slang_element_type;
};

} // namespace sgl::slangpy
28 changes: 28 additions & 0 deletions src/sgl/utils/python/slangpyvalue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// SPDX-License-Identifier: Apache-2.0

#include "sgl/utils/python/slangpyvalue.h"

namespace sgl {
extern void write_shader_cursor(ShaderCursor& cursor, nb::object value);
}

namespace sgl::slangpy {

void NativeValueMarshall::write_shader_cursor_pre_dispatch(
CallContext* context,
NativeBoundVariableRuntime* binding,
ShaderCursor cursor,
nb::object value,
nb::list read_back
) const
{
AccessType primal_access = binding->get_access().first;
if (!value.is_none() && (primal_access == AccessType::read || primal_access == AccessType::readwrite)) {
SGL_UNUSED(binding);
SGL_UNUSED(context);
ShaderCursor field = cursor[binding->get_variable_name()]["value"];
write_shader_cursor(field, value);
}
}

} // namespace sgl::slangpy
41 changes: 41 additions & 0 deletions src/sgl/utils/python/slangpyvalue.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <vector>
#include <map>

#include "nanobind.h"

#include "sgl/utils/python/slangpy.h"

namespace sgl::slangpy {

/// Base class for marshalling simple value types between Python and Slang.
class NativeValueMarshall : public NativeMarshall {
public:
/// Writes call data to a shader cursor before dispatch, optionally writing data for
/// read back after the kernel has executed. By default, this calls through to
/// create_calldata, which is typically overridden python side to generate a dictionary.
void write_shader_cursor_pre_dispatch(
CallContext* context,
NativeBoundVariableRuntime* binding,
ShaderCursor cursor,
nb::object value,
nb::list read_back
) const override;


/// Dispatch data is just the value.
nb::object create_dispatchdata(nb::object data) const override { return data; }

/// If requested, output is just the input value (as it can't have changed).
nb::object read_output(CallContext* context, NativeBoundVariableRuntime* binding, nb::object data) const override
{
SGL_UNUSED(context);
SGL_UNUSED(binding);
return data;
};
};

} // namespace sgl::slangpy
10 changes: 10 additions & 0 deletions src/sgl/utils/slangpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ class SGL_API Shape {
return fmt::format("[{}]", fmt::join(as_vector(), ", "));
}

/// Total element count (if this represented contiguous array)
size_t element_count() const
{
size_t result = 1;
for (auto dim : as_vector()) {
result *= dim;
}
return result;
}

private:
std::optional<std::vector<int>> m_shape;
};
Expand Down

0 comments on commit 75428d8

Please sign in to comment.