Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose is_sub_type reflection api #121

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sgl/device/python/reflection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ SGL_PY_EXPORT(device_reflection)
D_NA(ProgramLayout, find_function_by_name_in_type)
)
.def("get_type_layout", &ProgramLayout::get_type_layout, "type"_a, D_NA(ProgramLayout, get_type_layout))
.def("is_sub_type", &ProgramLayout::is_sub_type, "sub_type"_a, "super_type"_a, D_NA(ProgramLayout, is_sub_type))
.def_prop_ro("hashed_strings", &ProgramLayout::hashed_strings, D(ProgramLayout, hashed_strings))
.def("__repr__", &ProgramLayout::to_string);

Expand Down
11 changes: 9 additions & 2 deletions src/sgl/device/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -1014,13 +1014,20 @@ class SGL_API ProgramLayout : public BaseReflectionObjectImpl<slang::ProgramLayo

/// Find a given function in a type by name. Handles generic specilization if generic
/// variable values are provided.
ref<const FunctionReflection> find_function_by_name_in_type(TypeReflection* type, const char* name)
ref<const FunctionReflection> find_function_by_name_in_type(const TypeReflection* type, const char* name)
{
return detail::from_slang(m_owner, m_target->findFunctionByNameInType(type->slang_target(), name));
}

/// Test whether a type is a sub type of another type. Handles both
/// struct inheritance and interface implementation.
bool is_sub_type(const TypeReflection* sub_type, const TypeReflection* super_type)
{
return m_target->isSubType(sub_type->slang_target(), super_type->slang_target());
}

/// Get corresponding type layout from a given type.
ref<const TypeLayoutReflection> get_type_layout(TypeReflection* type)
ref<const TypeLayoutReflection> get_type_layout(const TypeReflection* type)
{
// TODO: Once device is available via session reference, pass metal layout rules for metal target
return detail::from_slang(m_owner, m_target->getTypeLayout(type->slang_target(), slang::LayoutRules::Default));
Expand Down
42 changes: 42 additions & 0 deletions src/sgl/device/tests/test_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,5 +669,47 @@ def test_get_type_layout(test_id: str, device_type: sgl.DeviceType):
assert tl.size == 4


@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES)
def test_is_sub_type(test_id: str, device_type: sgl.DeviceType):
device = helpers.get_device(type=device_type)

# Create a session, and within it a module.
session = helpers.create_session(device, {})
module = session.load_module_from_source(
module_name=f"module_from_source_{test_id}",
source="""
interface IHello {
}
struct Hello : IHello {
int a;
}

interface IHello2 {
}
extension Hello : IHello2 {
}
""",
)

t = module.layout.find_type_by_name("Hello")
assert t is not None
assert t.name == "Hello"

# t should be a sub type of itself.
assert module.layout.is_sub_type(t, t)

# t should be a sub type of its interface.
i = module.layout.find_type_by_name("IHello")
assert i is not None
assert i.name == "IHello"
assert module.layout.is_sub_type(t, i)

# t should be a sub type of its extension.
i2 = module.layout.find_type_by_name("IHello2")
assert i2 is not None
assert i2.name == "IHello2"
assert module.layout.is_sub_type(t, i)


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading