diff --git a/src/sgl/device/device.cpp b/src/sgl/device/device.cpp index 75920969..bfa35909 100644 --- a/src/sgl/device/device.cpp +++ b/src/sgl/device/device.cpp @@ -961,9 +961,13 @@ void Device::upload_buffer_data(Buffer* buffer, const void* data, size_t size, s std::memcpy(alloc->data, data, size); - CommandBuffer* command_buffer = _begin_shared_command_buffer(); - command_buffer->copy_buffer_region(buffer, offset, alloc->buffer, alloc->offset, size); - _end_shared_command_buffer(false); + if (m_shared_command_buffer) { + m_shared_command_buffer->copy_buffer_region(buffer, offset, alloc->buffer, alloc->offset, size); + } else { + CommandBuffer* command_buffer = _begin_shared_command_buffer(); + command_buffer->copy_buffer_region(buffer, offset, alloc->buffer, alloc->offset, size); + _end_shared_command_buffer(false); + } } void Device::read_buffer_data(const Buffer* buffer, void* data, size_t size, size_t offset) diff --git a/src/sgl/utils/python/slangpy.cpp b/src/sgl/utils/python/slangpy.cpp index b4732ba9..14047828 100644 --- a/src/sgl/utils/python/slangpy.cpp +++ b/src/sgl/utils/python/slangpy.cpp @@ -19,6 +19,22 @@ extern void write_shader_cursor(ShaderCursor& cursor, nb::object value); namespace sgl::slangpy { +void NativeType::write_shader_cursor_pre_dispatch( + CallContext* context, + NativeBoundVariableRuntime* binding, + ShaderCursor cursor, + nb::object value, + nb::list read_back +) const +{ + // We are a leaf node, so generate and store call data for this node. + nb::object cd_val = create_calldata(context, binding, value); + if (!cd_val.is_none()) { + ShaderCursor child_field = cursor[binding->get_variable_name()]; + write_shader_cursor(child_field, cd_val); + read_back.append(nb::make_tuple(binding, value, cd_val)); + } +} void NativeBoundVariableRuntime::populate_call_shape(std::vector& call_shape, nb::object value) { @@ -42,8 +58,8 @@ void NativeBoundVariableRuntime::populate_call_shape(std::vector& call_shap // Get the shape of the value. In the case of none-concrete types, // only the container shape is needed, as we never map elements. - if (m_python_type->concrete_shape().valid()) - m_shape = m_python_type->concrete_shape(); + if (m_python_type->get_concrete_shape().valid()) + m_shape = m_python_type->get_concrete_shape(); else m_shape = m_python_type->get_shape(value); @@ -109,6 +125,29 @@ void NativeBoundVariableRuntime::write_call_data_pre_dispatch( } } +void NativeBoundVariableRuntime::write_shader_cursor_pre_dispatch( + CallContext* context, + ShaderCursor cursor, + nb::object value, + nb::list read_back +) +{ + if (m_children) { + // We have children, so generate call data for each child and + // store in a dictionary, then store the dictionary as the call data. + ShaderCursor child_field = cursor[m_variable_name.c_str()]; + for (const auto& [name, child_ref] : *m_children) { + if (child_ref) { + nb::object child_value = value[name.c_str()]; + child_ref->write_shader_cursor_pre_dispatch(context, child_field, child_value, read_back); + } + } + } else { + // We are a leaf node, so generate and store call data for this node. + m_python_type->write_shader_cursor_pre_dispatch(context, this, cursor, value, read_back); + } +} + void NativeBoundVariableRuntime::read_call_data_post_dispatch( CallContext* context, nb::dict call_data, @@ -228,6 +267,30 @@ void NativeBoundCallRuntime::write_calldata_pre_dispatch( } } +void NativeBoundCallRuntime::write_shader_cursor_pre_dispatch( + CallContext* context, + ShaderCursor cursor, + nb::list args, + nb::dict kwargs, + nb::list read_back + +) +{ + // Write call data for each positional argument. + for (size_t idx = 0; idx < args.size(); ++idx) { + m_args[idx]->write_shader_cursor_pre_dispatch(context, cursor, args[idx], read_back); + } + + // Write call data for each keyword argument. + for (auto [key, value] : kwargs) { + auto it = m_kwargs.find(nb::str(key).c_str()); + if (it != m_kwargs.end()) { + it->second->write_shader_cursor_pre_dispatch(context, cursor, nb::cast(value), read_back); + } + } +} + + void NativeBoundCallRuntime::read_call_data_post_dispatch( CallContext* context, nb::dict call_data, @@ -305,8 +368,8 @@ nb::object NativeCallData::exec( } // Write uniforms to call data. - nb::dict call_data; - m_runtime->write_calldata_pre_dispatch(context, call_data, unpacked_args, unpacked_kwargs); + // nb::dict call_data; + // m_runtime->write_calldata_pre_dispatch(context, call_data, unpacked_args, unpacked_kwargs); // Calculate total threads and strides. int total_threads = 1; @@ -318,29 +381,40 @@ nb::object NativeCallData::exec( } std::reverse(strides.begin(), strides.end()); - if (!strides.empty()) { - call_data["_call_stride"] = nb::cast(strides); - call_data["_call_dim"] = nb::cast(cs); - } - call_data["_thread_count"] = uint3(total_threads, 1, 1); // Copy user provided vars and insert call data. - nb::dict vars = nb::dict(); - nb::list uniforms = opts->get_uniforms(); - if (uniforms) { - for (auto u : uniforms) { - if (nb::isinstance(u)) { - vars.update(nb::cast(u)); - } else { - vars.update(nb::cast(u(this))); - } - } - } - vars["call_data"] = call_data; + // vars["call_data"] = call_data; + + nb::list read_back; // Dispatch the kernel. - auto bind_vars = [&](ShaderCursor cursor) { write_shader_cursor(cursor, vars); }; + auto bind_vars = [&](ShaderCursor cursor) + { + auto call_data_cursor = cursor.find_field("call_data"); + + if (!strides.empty()) { + call_data_cursor["_call_stride"] + ._set_array(&strides[0], strides.size() * 4, TypeReflection::ScalarType::int32, strides.size()); + call_data_cursor["_call_dim"] + ._set_array(&cs[0], cs.size() * 4, TypeReflection::ScalarType::int32, cs.size()); + } + call_data_cursor["_thread_count"] = uint3(total_threads, 1, 1); + + m_runtime + ->write_shader_cursor_pre_dispatch(context, call_data_cursor, unpacked_args, unpacked_kwargs, read_back); + + nb::list uniforms = opts->get_uniforms(); + if (uniforms) { + for (auto u : uniforms) { + if (nb::isinstance(u)) { + write_shader_cursor(cursor, nb::cast(u)); + } else { + write_shader_cursor(cursor, nb::cast(u(this))); + } + } + } + }; m_kernel->dispatch(uint3(total_threads, 1, 1), bind_vars, command_buffer); // If command_buffer is not null, return early. @@ -349,7 +423,14 @@ nb::object NativeCallData::exec( } // Read call data post dispatch. - m_runtime->read_call_data_post_dispatch(context, call_data, unpacked_args, unpacked_kwargs); + // m_runtime->read_call_data_post_dispatch(context, call_data, unpacked_args, unpacked_kwargs); + for (auto val : read_back) { + auto t = nb::cast(val); + auto bvr = nb::cast>(t[0]); + auto rb_val = t[1]; + auto rb_data = t[2]; + bvr->get_python_type()->read_calldata(context, bvr.get(), rb_val, rb_data); + } // Pack updated 'this' values back. for (size_t i = 0; i < args.size(); ++i) { @@ -578,10 +659,10 @@ SGL_PY_EXPORT(utils_slangpy) nb::class_(slangpy, "NativeSlangType") // .def(nb::init<>(), D_NA(NativeSlangType, NativeSlangType)) .def_prop_rw( - "_reflection", - &NativeSlangType::reflection, - &NativeSlangType::set_reflection, - D_NA(NativeSlangType, reflection) + "type_reflection", + &NativeSlangType::get_type_reflection, + &NativeSlangType::set_type_reflection, + D_NA(NativeSlangType, type_reflection) ); nb::class_(slangpy, "NativeType") // @@ -593,10 +674,15 @@ SGL_PY_EXPORT(utils_slangpy) .def_prop_rw( "concrete_shape", - &NativeType::concrete_shape, + &NativeType::get_concrete_shape, &NativeType::set_concrete_shape, D_NA(NativeType, concrete_shape) ) + .def( + "write_shader_cursor_pre_dispatch", + &NativeType::write_shader_cursor_pre_dispatch, + D_NA(NativeType, write_shader_cursor_pre_dispatch) + ) .def("create_calldata", &NativeType::create_calldata, D_NA(NativeType, create_calldata)) .def("read_calldata", &NativeType::read_calldata, D_NA(NativeType, read_calldata)) .def("create_output", &NativeType::create_output, D_NA(NativeType, create_output)) diff --git a/src/sgl/utils/python/slangpy.h b/src/sgl/utils/python/slangpy.h index 920701c7..bbc9d768 100644 --- a/src/sgl/utils/python/slangpy.h +++ b/src/sgl/utils/python/slangpy.h @@ -47,13 +47,13 @@ class NativeSlangType : public Object { }; /// Get the reflection type. - ref reflection() const { return m_reflection; } + ref get_type_reflection() const { return m_type_reflection; } /// Set the reflection type. - void set_reflection(const ref& reflection) { m_reflection = reflection; } + void set_type_reflection(const ref& reflection) { m_type_reflection = reflection; } private: - ref m_reflection; + ref m_type_reflection; }; /// Base class for a marshal to a slangpy supported type. @@ -63,7 +63,7 @@ class NativeType : public Object { /// Get the concrete shape of the type. For none-concrete types such as buffers, /// this will return an invalid shape. - Shape concrete_shape() const { return m_concrete_shape; } + Shape get_concrete_shape() const { return m_concrete_shape; } /// Set the concrete shape of the type. void set_concrete_shape(const Shape& concrete_shape) { m_concrete_shape = concrete_shape; } @@ -75,6 +75,17 @@ class NativeType : public Object { return Shape(); } + /// Writes call data to a shader cursor before dispatch, optionally writing data for + /// read back after the kernel has executed. By default, this calls through to + /// create_calldata, which is typically overridden python side to generate a dictionary. + virtual void write_shader_cursor_pre_dispatch( + CallContext* context, + NativeBoundVariableRuntime* binding, + ShaderCursor cursor, + nb::object value, + nb::list read_back + ) const; + /// Create call data (uniform values) to be passed to a compute kernel. virtual nb::object create_calldata(CallContext* context, NativeBoundVariableRuntime* binding, nb::object data) const { @@ -126,10 +137,21 @@ class NativeType : public Object { /// Nanobind trampoline class for NativeType struct PyNativeType : public NativeType { - NB_TRAMPOLINE(NativeType, 9); + NB_TRAMPOLINE(NativeType, 10); Shape get_shape(nb::object data) const override { NB_OVERRIDE(get_shape, data); } + void write_shader_cursor_pre_dispatch( + CallContext* context, + NativeBoundVariableRuntime* binding, + ShaderCursor cursor, + nb::object value, + nb::list read_back + ) const override + { + NB_OVERRIDE(write_shader_cursor_pre_dispatch, context, binding, cursor, value, read_back); + } + nb::object create_calldata(CallContext* context, NativeBoundVariableRuntime* binding, nb::object data) const override { @@ -218,6 +240,9 @@ class NativeBoundVariableRuntime : public Object { /// Write call data to be passed to a compute kernel by calling create_calldata on the marshal. void write_call_data_pre_dispatch(CallContext* context, nb::dict call_data, nb::object value); + void + write_shader_cursor_pre_dispatch(CallContext* context, ShaderCursor cursor, nb::object value, nb::list read_back); + /// Read back changes from call data after a kernel has been executed by calling read_calldata on the marshal. void read_call_data_post_dispatch(CallContext* context, nb::dict call_data, nb::object value); @@ -272,6 +297,14 @@ class NativeBoundCallRuntime : Object { /// Write call data to be passed to a compute kernel by calling create_calldata on the argument marshals. void write_calldata_pre_dispatch(CallContext* context, nb::dict call_data, nb::list args, nb::dict kwargs); + void write_shader_cursor_pre_dispatch( + CallContext* context, + ShaderCursor cursor, + nb::list args, + nb::dict kwargs, + nb::list read_back + ); + /// Read back changes from call data after a kernel has been executed by calling read_calldata on the argument /// marshals. void read_call_data_post_dispatch(CallContext* context, nb::dict call_data, nb::list args, nb::dict kwargs);