From cfcca48cdf665d7a97e2d69ef656340ffe913bf4 Mon Sep 17 00:00:00 2001 From: ccummingsNV Date: Tue, 22 Oct 2024 12:23:59 +0100 Subject: [PATCH] Expose is_sub_type reflection api (#121) * support for is_sub_type api * Tweak subtype test --- src/sgl/device/python/reflection.cpp | 1 + src/sgl/device/reflection.h | 11 +++++-- src/sgl/device/tests/test_reflection.py | 42 +++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/src/sgl/device/python/reflection.cpp b/src/sgl/device/python/reflection.cpp index 443de886..bfe1fbe7 100644 --- a/src/sgl/device/python/reflection.cpp +++ b/src/sgl/device/python/reflection.cpp @@ -207,6 +207,7 @@ SGL_PY_EXPORT(device_reflection) D_NA(ProgramLayout, find_function_by_name_in_type) ) .def("get_type_layout", &ProgramLayout::get_type_layout, "type"_a, D_NA(ProgramLayout, get_type_layout)) + .def("is_sub_type", &ProgramLayout::is_sub_type, "sub_type"_a, "super_type"_a, D_NA(ProgramLayout, is_sub_type)) .def_prop_ro("hashed_strings", &ProgramLayout::hashed_strings, D(ProgramLayout, hashed_strings)) .def("__repr__", &ProgramLayout::to_string); diff --git a/src/sgl/device/reflection.h b/src/sgl/device/reflection.h index ccef522e..c5270ebc 100644 --- a/src/sgl/device/reflection.h +++ b/src/sgl/device/reflection.h @@ -1014,13 +1014,20 @@ class SGL_API ProgramLayout : public BaseReflectionObjectImpl find_function_by_name_in_type(TypeReflection* type, const char* name) + ref find_function_by_name_in_type(const TypeReflection* type, const char* name) { return detail::from_slang(m_owner, m_target->findFunctionByNameInType(type->slang_target(), name)); } + /// Test whether a type is a sub type of another type. Handles both + /// struct inheritance and interface implementation. + bool is_sub_type(const TypeReflection* sub_type, const TypeReflection* super_type) + { + return m_target->isSubType(sub_type->slang_target(), super_type->slang_target()); + } + /// Get corresponding type layout from a given type. - ref get_type_layout(TypeReflection* type) + ref get_type_layout(const TypeReflection* type) { // TODO: Once device is available via session reference, pass metal layout rules for metal target return detail::from_slang(m_owner, m_target->getTypeLayout(type->slang_target(), slang::LayoutRules::Default)); diff --git a/src/sgl/device/tests/test_reflection.py b/src/sgl/device/tests/test_reflection.py index b3f28d2b..966c47a1 100644 --- a/src/sgl/device/tests/test_reflection.py +++ b/src/sgl/device/tests/test_reflection.py @@ -669,5 +669,47 @@ def test_get_type_layout(test_id: str, device_type: sgl.DeviceType): assert tl.size == 4 +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_is_sub_type(test_id: str, device_type: sgl.DeviceType): + device = helpers.get_device(type=device_type) + + # Create a session, and within it a module. + session = helpers.create_session(device, {}) + module = session.load_module_from_source( + module_name=f"module_from_source_{test_id}", + source=""" + interface IHello { + } + struct Hello : IHello { + int a; + } + + interface IHello2 { + } + extension Hello : IHello2 { + } + """, + ) + + t = module.layout.find_type_by_name("Hello") + assert t is not None + assert t.name == "Hello" + + # t should be a sub type of itself. + assert module.layout.is_sub_type(t, t) + + # t should be a sub type of its interface. + i = module.layout.find_type_by_name("IHello") + assert i is not None + assert i.name == "IHello" + assert module.layout.is_sub_type(t, i) + + # t should be a sub type of its extension. + i2 = module.layout.find_type_by_name("IHello2") + assert i2 is not None + assert i2.name == "IHello2" + assert module.layout.is_sub_type(t, i) + + if __name__ == "__main__": pytest.main([__file__, "-v"])