Skip to content

Commit

Permalink
Start optimizations to shift slangpy to using shadercursor for uniforms
Browse files Browse the repository at this point in the history
  • Loading branch information
ccummingsNV committed Jan 21, 2025
1 parent 21e3794 commit 18dd74f
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 36 deletions.
10 changes: 7 additions & 3 deletions src/sgl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -961,9 +961,13 @@ void Device::upload_buffer_data(Buffer* buffer, const void* data, size_t size, s

std::memcpy(alloc->data, data, size);

CommandBuffer* command_buffer = _begin_shared_command_buffer();
command_buffer->copy_buffer_region(buffer, offset, alloc->buffer, alloc->offset, size);
_end_shared_command_buffer(false);
if (m_shared_command_buffer) {
m_shared_command_buffer->copy_buffer_region(buffer, offset, alloc->buffer, alloc->offset, size);
} else {
CommandBuffer* command_buffer = _begin_shared_command_buffer();
command_buffer->copy_buffer_region(buffer, offset, alloc->buffer, alloc->offset, size);
_end_shared_command_buffer(false);
}
}

void Device::read_buffer_data(const Buffer* buffer, void* data, size_t size, size_t offset)
Expand Down
142 changes: 114 additions & 28 deletions src/sgl/utils/python/slangpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@ extern void write_shader_cursor(ShaderCursor& cursor, nb::object value);

namespace sgl::slangpy {

void NativeType::write_shader_cursor_pre_dispatch(
    CallContext* context,
    NativeBoundVariableRuntime* binding,
    ShaderCursor cursor,
    nb::object value,
    nb::list read_back
) const
{
    // Leaf node: ask the marshal to generate call data for this value
    // (typically overridden Python-side to produce a dictionary).
    nb::object call_data = create_calldata(context, binding, value);
    if (call_data.is_none())
        return;

    // Write the generated call data into the cursor field for this variable.
    ShaderCursor field = cursor[binding->get_variable_name()];
    write_shader_cursor(field, call_data);

    // Record (binding, original value, call data) so changes can be read
    // back after the kernel has executed.
    read_back.append(nb::make_tuple(binding, value, call_data));
}

void NativeBoundVariableRuntime::populate_call_shape(std::vector<int>& call_shape, nb::object value)
{
Expand All @@ -42,8 +58,8 @@ void NativeBoundVariableRuntime::populate_call_shape(std::vector<int>& call_shap

// Get the shape of the value. In the case of non-concrete types,
// only the container shape is needed, as we never map elements.
if (m_python_type->concrete_shape().valid())
m_shape = m_python_type->concrete_shape();
if (m_python_type->get_concrete_shape().valid())
m_shape = m_python_type->get_concrete_shape();
else
m_shape = m_python_type->get_shape(value);

Expand Down Expand Up @@ -109,6 +125,29 @@ void NativeBoundVariableRuntime::write_call_data_pre_dispatch(
}
}

void NativeBoundVariableRuntime::write_shader_cursor_pre_dispatch(
    CallContext* context,
    ShaderCursor cursor,
    nb::object value,
    nb::list read_back
)
{
    if (!m_children) {
        // Leaf node: delegate to the type marshal, which writes the value
        // for this binding directly into the shader cursor.
        m_python_type->write_shader_cursor_pre_dispatch(context, this, cursor, value, read_back);
        return;
    }

    // Interior node: descend into the cursor field for this variable and
    // write each bound child from the matching attribute of the value.
    ShaderCursor field = cursor[m_variable_name.c_str()];
    for (const auto& [child_name, child] : *m_children) {
        if (!child)
            continue;
        nb::object child_value = value[child_name.c_str()];
        child->write_shader_cursor_pre_dispatch(context, field, child_value, read_back);
    }
}

void NativeBoundVariableRuntime::read_call_data_post_dispatch(
CallContext* context,
nb::dict call_data,
Expand Down Expand Up @@ -228,6 +267,30 @@ void NativeBoundCallRuntime::write_calldata_pre_dispatch(
}
}

void NativeBoundCallRuntime::write_shader_cursor_pre_dispatch(
    CallContext* context,
    ShaderCursor cursor,
    nb::list args,
    nb::dict kwargs,
    nb::list read_back
)
{
    // Positional arguments map one-to-one onto the bound argument list.
    // NOTE(review): assumes args.size() <= m_args.size() — confirm callers
    // validate argument counts before reaching this point.
    const size_t arg_count = args.size();
    for (size_t i = 0; i < arg_count; ++i) {
        m_args[i]->write_shader_cursor_pre_dispatch(context, cursor, args[i], read_back);
    }

    // Keyword arguments are matched by name; keys with no corresponding
    // binding are silently skipped.
    for (auto [key, value] : kwargs) {
        auto found = m_kwargs.find(nb::str(key).c_str());
        if (found == m_kwargs.end())
            continue;
        found->second->write_shader_cursor_pre_dispatch(context, cursor, nb::cast<nb::object>(value), read_back);
    }
}


void NativeBoundCallRuntime::read_call_data_post_dispatch(
CallContext* context,
nb::dict call_data,
Expand Down Expand Up @@ -305,8 +368,8 @@ nb::object NativeCallData::exec(
}

// Write uniforms to call data.
nb::dict call_data;
m_runtime->write_calldata_pre_dispatch(context, call_data, unpacked_args, unpacked_kwargs);
// nb::dict call_data;
// m_runtime->write_calldata_pre_dispatch(context, call_data, unpacked_args, unpacked_kwargs);

// Calculate total threads and strides.
int total_threads = 1;
Expand All @@ -318,29 +381,40 @@ nb::object NativeCallData::exec(
}
std::reverse(strides.begin(), strides.end());

if (!strides.empty()) {
call_data["_call_stride"] = nb::cast(strides);
call_data["_call_dim"] = nb::cast(cs);
}
call_data["_thread_count"] = uint3(total_threads, 1, 1);

// Copy user provided vars and insert call data.
nb::dict vars = nb::dict();
nb::list uniforms = opts->get_uniforms();
if (uniforms) {
for (auto u : uniforms) {
if (nb::isinstance<nb::dict>(u)) {
vars.update(nb::cast<nb::dict>(u));
} else {
vars.update(nb::cast<nb::dict>(u(this)));
}
}
}

vars["call_data"] = call_data;
// vars["call_data"] = call_data;

nb::list read_back;

// Dispatch the kernel.
auto bind_vars = [&](ShaderCursor cursor) { write_shader_cursor(cursor, vars); };
auto bind_vars = [&](ShaderCursor cursor)
{
auto call_data_cursor = cursor.find_field("call_data");

if (!strides.empty()) {
call_data_cursor["_call_stride"]
._set_array(&strides[0], strides.size() * 4, TypeReflection::ScalarType::int32, strides.size());
call_data_cursor["_call_dim"]
._set_array(&cs[0], cs.size() * 4, TypeReflection::ScalarType::int32, cs.size());
}
call_data_cursor["_thread_count"] = uint3(total_threads, 1, 1);

m_runtime
->write_shader_cursor_pre_dispatch(context, call_data_cursor, unpacked_args, unpacked_kwargs, read_back);

nb::list uniforms = opts->get_uniforms();
if (uniforms) {
for (auto u : uniforms) {
if (nb::isinstance<nb::dict>(u)) {
write_shader_cursor(cursor, nb::cast<nb::dict>(u));
} else {
write_shader_cursor(cursor, nb::cast<nb::dict>(u(this)));
}
}
}
};
m_kernel->dispatch(uint3(total_threads, 1, 1), bind_vars, command_buffer);

// If command_buffer is not null, return early.
Expand All @@ -349,7 +423,14 @@ nb::object NativeCallData::exec(
}

// Read call data post dispatch.
m_runtime->read_call_data_post_dispatch(context, call_data, unpacked_args, unpacked_kwargs);
// m_runtime->read_call_data_post_dispatch(context, call_data, unpacked_args, unpacked_kwargs);
for (auto val : read_back) {
auto t = nb::cast<nb::tuple>(val);
auto bvr = nb::cast<ref<NativeBoundVariableRuntime>>(t[0]);
auto rb_val = t[1];
auto rb_data = t[2];
bvr->get_python_type()->read_calldata(context, bvr.get(), rb_val, rb_data);
}

// Pack updated 'this' values back.
for (size_t i = 0; i < args.size(); ++i) {
Expand Down Expand Up @@ -578,10 +659,10 @@ SGL_PY_EXPORT(utils_slangpy)
nb::class_<NativeSlangType, Object>(slangpy, "NativeSlangType") //
.def(nb::init<>(), D_NA(NativeSlangType, NativeSlangType))
.def_prop_rw(
"_reflection",
&NativeSlangType::reflection,
&NativeSlangType::set_reflection,
D_NA(NativeSlangType, reflection)
"type_reflection",
&NativeSlangType::get_type_reflection,
&NativeSlangType::set_type_reflection,
D_NA(NativeSlangType, type_reflection)
);

nb::class_<NativeType, PyNativeType, Object>(slangpy, "NativeType") //
Expand All @@ -593,10 +674,15 @@ SGL_PY_EXPORT(utils_slangpy)

.def_prop_rw(
"concrete_shape",
&NativeType::concrete_shape,
&NativeType::get_concrete_shape,
&NativeType::set_concrete_shape,
D_NA(NativeType, concrete_shape)
)
.def(
"write_shader_cursor_pre_dispatch",
&NativeType::write_shader_cursor_pre_dispatch,
D_NA(NativeType, write_shader_cursor_pre_dispatch)
)
.def("create_calldata", &NativeType::create_calldata, D_NA(NativeType, create_calldata))
.def("read_calldata", &NativeType::read_calldata, D_NA(NativeType, read_calldata))
.def("create_output", &NativeType::create_output, D_NA(NativeType, create_output))
Expand Down
43 changes: 38 additions & 5 deletions src/sgl/utils/python/slangpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ class NativeSlangType : public Object {
};

/// Get the reflection type.
ref<TypeReflection> reflection() const { return m_reflection; }
ref<TypeReflection> get_type_reflection() const { return m_type_reflection; }

/// Set the reflection type.
void set_reflection(const ref<TypeReflection>& reflection) { m_reflection = reflection; }
void set_type_reflection(const ref<TypeReflection>& reflection) { m_type_reflection = reflection; }

private:
ref<TypeReflection> m_reflection;
ref<TypeReflection> m_type_reflection;
};

/// Base class for a marshal to a slangpy supported type.
Expand All @@ -63,7 +63,7 @@ class NativeType : public Object {

/// Get the concrete shape of the type. For non-concrete types such as buffers,
/// this will return an invalid shape.
Shape concrete_shape() const { return m_concrete_shape; }
Shape get_concrete_shape() const { return m_concrete_shape; }

/// Set the concrete shape of the type.
void set_concrete_shape(const Shape& concrete_shape) { m_concrete_shape = concrete_shape; }
Expand All @@ -75,6 +75,17 @@ class NativeType : public Object {
return Shape();
}

/// Writes call data to a shader cursor before dispatch, optionally writing data for
/// read back after the kernel has executed. By default, this calls through to
/// create_calldata, which is typically overridden python side to generate a dictionary.
virtual void write_shader_cursor_pre_dispatch(
CallContext* context,
NativeBoundVariableRuntime* binding,
ShaderCursor cursor,
nb::object value,
nb::list read_back
) const;

/// Create call data (uniform values) to be passed to a compute kernel.
virtual nb::object create_calldata(CallContext* context, NativeBoundVariableRuntime* binding, nb::object data) const
{
Expand Down Expand Up @@ -126,10 +137,21 @@ class NativeType : public Object {

/// Nanobind trampoline class for NativeType
struct PyNativeType : public NativeType {
NB_TRAMPOLINE(NativeType, 9);
NB_TRAMPOLINE(NativeType, 10);

Shape get_shape(nb::object data) const override { NB_OVERRIDE(get_shape, data); }

void write_shader_cursor_pre_dispatch(
CallContext* context,
NativeBoundVariableRuntime* binding,
ShaderCursor cursor,
nb::object value,
nb::list read_back
) const override
{
NB_OVERRIDE(write_shader_cursor_pre_dispatch, context, binding, cursor, value, read_back);
}

nb::object
create_calldata(CallContext* context, NativeBoundVariableRuntime* binding, nb::object data) const override
{
Expand Down Expand Up @@ -218,6 +240,9 @@ class NativeBoundVariableRuntime : public Object {
/// Write call data to be passed to a compute kernel by calling create_calldata on the marshal.
void write_call_data_pre_dispatch(CallContext* context, nb::dict call_data, nb::object value);

void
write_shader_cursor_pre_dispatch(CallContext* context, ShaderCursor cursor, nb::object value, nb::list read_back);

/// Read back changes from call data after a kernel has been executed by calling read_calldata on the marshal.
void read_call_data_post_dispatch(CallContext* context, nb::dict call_data, nb::object value);

Expand Down Expand Up @@ -272,6 +297,14 @@ class NativeBoundCallRuntime : Object {
/// Write call data to be passed to a compute kernel by calling create_calldata on the argument marshals.
void write_calldata_pre_dispatch(CallContext* context, nb::dict call_data, nb::list args, nb::dict kwargs);

void write_shader_cursor_pre_dispatch(
CallContext* context,
ShaderCursor cursor,
nb::list args,
nb::dict kwargs,
nb::list read_back
);

/// Read back changes from call data after a kernel has been executed by calling read_calldata on the argument
/// marshals.
void read_call_data_post_dispatch(CallContext* context, nb::dict call_data, nb::list args, nb::dict kwargs);
Expand Down

0 comments on commit 18dd74f

Please sign in to comment.