Skip to content

Commit

Permalink
All slangpy tests now working
Browse files Browse the repository at this point in the history
  • Loading branch information
ccummingsNV committed Jan 20, 2025
1 parent b1b0b44 commit 21e3794
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ if(SGL_BUILD_PYTHON)
PYTHON_PATH ${SGL_OUTPUT_DIRECTORY}/python
OUTPUT ${SGL_OUTPUT_DIRECTORY}/python/sgl/__init__.pyi
DEPENDS sgl_ext
INCLUDE_PRIVATE # allow us to have functions/variables that start with _
)

# Post-process the main stub file.
Expand All @@ -485,6 +486,7 @@ if(SGL_BUILD_PYTHON)
PYTHON_PATH ${SGL_OUTPUT_DIRECTORY}/python
OUTPUT ${SGL_OUTPUT_DIRECTORY}/python/sgl/${submodule_path}/__init__.pyi
DEPENDS sgl_ext
INCLUDE_PRIVATE # allow us to have functions/variables that start with _
)
endforeach()

Expand Down
61 changes: 56 additions & 5 deletions src/sgl/utils/python/slangpy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ extern void write_shader_cursor(ShaderCursor& cursor, nb::object value);

namespace sgl::slangpy {


void NativeBoundVariableRuntime::populate_call_shape(std::vector<int>& call_shape, nb::object value)
{
if (m_children) {
Expand Down Expand Up @@ -325,6 +326,17 @@ nb::object NativeCallData::exec(

// Copy user provided vars and insert call data.
nb::dict vars = nb::dict();
nb::list uniforms = opts->get_uniforms();
if (uniforms) {
for (auto u : uniforms) {
if (nb::isinstance<nb::dict>(u)) {
vars.update(nb::cast<nb::dict>(u));
} else {
vars.update(nb::cast<nb::dict>(u(this)));
}
}
}

vars["call_data"] = call_data;

// Dispatch the kernel.
Expand Down Expand Up @@ -357,7 +369,7 @@ nb::object NativeCallData::exec(
return nb::none();
}

nb::list NativeCallData::unpack_args(nb::args args)
nb::list unpack_args(nb::args args)
{
nb::list unpacked;
for (auto arg : args) {
Expand All @@ -366,7 +378,7 @@ nb::list NativeCallData::unpack_args(nb::args args)
return unpacked;
}

nb::dict NativeCallData::unpack_kwargs(nb::kwargs kwargs)
nb::dict unpack_kwargs(nb::kwargs kwargs)
{
nb::dict unpacked;
for (const auto& [k, v] : kwargs) {
Expand All @@ -375,7 +387,7 @@ nb::dict NativeCallData::unpack_kwargs(nb::kwargs kwargs)
return unpacked;
}

nb::object NativeCallData::unpack_arg(nb::object arg)
nb::object unpack_arg(nb::object arg)
{
auto obj = arg;

Expand Down Expand Up @@ -408,7 +420,7 @@ nb::object NativeCallData::unpack_arg(nb::object arg)
return obj;
}

void NativeCallData::pack_arg(nanobind::object arg, nanobind::object unpacked_arg)
void pack_arg(nanobind::object arg, nanobind::object unpacked_arg)
{
// If object has 'update_this', update it.
if (nb::hasattr(arg, "update_this")) {
Expand Down Expand Up @@ -523,6 +535,32 @@ SGL_PY_EXPORT(utils_slangpy)
D_NA(slangpy, hash_signature)
);

slangpy.def(
"unpack_args",
[](nb::args args) { return unpack_args(args); },
"args"_a,
D_NA(slangpy, unpack_args)
);
slangpy.def(
"unpack_kwargs",
[](nb::kwargs kwargs) { return unpack_kwargs(kwargs); },
"kwargs"_a,
D_NA(slangpy, unpack_kwargs)
);
slangpy.def(
"unpack_arg",
[](nb::object arg) { return unpack_arg(arg); },
"arg"_a,
D_NA(slangpy, unpack_arg)
);
slangpy.def(
"pack_arg",
[](nb::object arg, nb::object unpacked_arg) { pack_arg(arg, unpacked_arg); },
"arg"_a,
"unpacked_arg"_a,
D_NA(slangpy, pack_arg)
);

nb::register_exception_translator(
[](const std::exception_ptr& p, void* /* unused */)
{
Expand Down Expand Up @@ -623,6 +661,11 @@ SGL_PY_EXPORT(utils_slangpy)
&NativeBoundVariableRuntime::read_call_data_post_dispatch,
D_NA(NativeBoundVariableRuntime, read_call_data_post_dispatch)
)
.def(
"write_raw_dispatch_data",
&NativeBoundVariableRuntime::write_raw_dispatch_data,
D_NA(NativeBoundVariableRuntime, write_raw_dispatch_data)
)
.def("read_output", &NativeBoundVariableRuntime::read_output, D_NA(NativeBoundVariableRuntime, read_output));

nb::class_<NativeBoundCallRuntime, Object>(slangpy, "NativeBoundCallRuntime") //
Expand Down Expand Up @@ -654,6 +697,11 @@ SGL_PY_EXPORT(utils_slangpy)
"read_call_data_post_dispatch",
&NativeBoundCallRuntime::read_call_data_post_dispatch,
D_NA(NativeBoundCallRuntime, read_call_data_post_dispatch)
)
.def(
"write_raw_dispatch_data",
&NativeBoundCallRuntime::write_raw_dispatch_data,
D_NA(NativeBoundCallRuntime, write_raw_dispatch_data)
);

nb::class_<NativeCallRuntimeOptions, Object>(slangpy, "NativeCallRuntimeOptions") //
Expand Down Expand Up @@ -815,5 +863,8 @@ SGL_PY_EXPORT(utils_slangpy)
&CallContext::call_shape,
nb::rv_policy::reference_internal,
D_NA(CallContext, call_shape)
);
)
.def_prop_ro("call_mode", &CallContext::call_mode, D_NA(CallContext, call_mode))

;
}
18 changes: 6 additions & 12 deletions src/sgl/utils/python/slangpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,7 @@ class NativeBoundVariableRuntime : public Object {
ref<NativeSlangType> get_vector_type() const { return m_vector_type; }

/// Set the vector slang type.
void set_vector_type(const ref<NativeSlangType>& vector_type)
{
SGL_UNUSED(vector_type); /* m_vector_type = vector_type*/
;
}
void set_vector_type(ref<NativeSlangType> vector_type) { m_vector_type = vector_type; }

/// Get the shape being used for the current call.
Shape get_shape() const { return m_shape; }
Expand Down Expand Up @@ -356,14 +352,12 @@ class NativeCallData : Object {

nb::object
exec(ref<NativeCallRuntimeOptions> opts, CommandBuffer* command_buffer, nb::args args, nb::kwargs kwargs);
};

nb::list unpack_args(nb::args args);

nb::dict unpack_kwargs(nb::kwargs kwargs);

nb::object unpack_arg(nanobind::object arg);

void pack_arg(nb::object arg, nb::object unpacked_arg);
};
nb::list unpack_args(nb::args args);
nb::dict unpack_kwargs(nb::kwargs kwargs);
nb::object unpack_arg(nanobind::object arg);
void pack_arg(nb::object arg, nb::object unpacked_arg);

} // namespace sgl::slangpy

0 comments on commit 21e3794

Please sign in to comment.