Skip to content

Commit

Permalink
Expose is_sub_type reflection api (#121)
Browse files Browse the repository at this point in the history
* support for is_sub_type api

* Tweak subtype test
  • Loading branch information
ccummingsNV authored Oct 22, 2024
1 parent 2515418 commit cfcca48
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/sgl/device/python/reflection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ SGL_PY_EXPORT(device_reflection)
D_NA(ProgramLayout, find_function_by_name_in_type)
)
.def("get_type_layout", &ProgramLayout::get_type_layout, "type"_a, D_NA(ProgramLayout, get_type_layout))
.def("is_sub_type", &ProgramLayout::is_sub_type, "sub_type"_a, "super_type"_a, D_NA(ProgramLayout, is_sub_type))
.def_prop_ro("hashed_strings", &ProgramLayout::hashed_strings, D(ProgramLayout, hashed_strings))
.def("__repr__", &ProgramLayout::to_string);

Expand Down
11 changes: 9 additions & 2 deletions src/sgl/device/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -1014,13 +1014,20 @@ class SGL_API ProgramLayout : public BaseReflectionObjectImpl<slang::ProgramLayo

/// Find a given function in a type by name. Handles generic specilization if generic
/// variable values are provided.
ref<const FunctionReflection> find_function_by_name_in_type(TypeReflection* type, const char* name)
ref<const FunctionReflection> find_function_by_name_in_type(const TypeReflection* type, const char* name)
{
return detail::from_slang(m_owner, m_target->findFunctionByNameInType(type->slang_target(), name));
}

/// Test whether a type is a sub type of another type. Handles both
/// struct inheritance and interface implementation.
bool is_sub_type(const TypeReflection* sub_type, const TypeReflection* super_type)
{
return m_target->isSubType(sub_type->slang_target(), super_type->slang_target());
}

/// Get corresponding type layout from a given type.
ref<const TypeLayoutReflection> get_type_layout(TypeReflection* type)
ref<const TypeLayoutReflection> get_type_layout(const TypeReflection* type)
{
// TODO: Once device is available via session reference, pass metal layout rules for metal target
return detail::from_slang(m_owner, m_target->getTypeLayout(type->slang_target(), slang::LayoutRules::Default));
Expand Down
42 changes: 42 additions & 0 deletions src/sgl/device/tests/test_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,5 +669,47 @@ def test_get_type_layout(test_id: str, device_type: sgl.DeviceType):
assert tl.size == 4


@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES)
def test_is_sub_type(test_id: str, device_type: sgl.DeviceType):
device = helpers.get_device(type=device_type)

# Create a session, and within it a module.
session = helpers.create_session(device, {})
module = session.load_module_from_source(
module_name=f"module_from_source_{test_id}",
source="""
interface IHello {
}
struct Hello : IHello {
int a;
}
interface IHello2 {
}
extension Hello : IHello2 {
}
""",
)

t = module.layout.find_type_by_name("Hello")
assert t is not None
assert t.name == "Hello"

# t should be a sub type of itself.
assert module.layout.is_sub_type(t, t)

# t should be a sub type of its interface.
i = module.layout.find_type_by_name("IHello")
assert i is not None
assert i.name == "IHello"
assert module.layout.is_sub_type(t, i)

# t should be a sub type of its extension.
i2 = module.layout.find_type_by_name("IHello2")
assert i2 is not None
assert i2.name == "IHello2"
assert module.layout.is_sub_type(t, i)


if __name__ == "__main__":
pytest.main([__file__, "-v"])

0 comments on commit cfcca48

Please sign in to comment.