Skip to content

Commit

Permalink
Extra tests for shader access via compute
Browse files Browse the repository at this point in the history
  • Loading branch information
ccummingsNV committed Oct 21, 2024
1 parent daca0f8 commit 92e703c
Showing 1 changed file with 106 additions and 0 deletions.
106 changes: 106 additions & 0 deletions src/sgl/device/tests/test_texture_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,111 @@ def test_read_write_texture(
assert np.allclose(data, mip_data)


@pytest.mark.parametrize(
"type",
[
sgl.ResourceType.texture_1d,
sgl.ResourceType.texture_2d,
sgl.ResourceType.texture_3d,
],
)
@pytest.mark.parametrize("slices", [1, 4])
@pytest.mark.parametrize("mips", [0, 1, 4])
@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES)
def test_shader_read_write_texture(
device_type: sgl.DeviceType, slices: int, mips: int, type: sgl.ResourceType
):
device = helpers.get_device(device_type)
assert device is not None

# No 3d texture arrays.
if type == sgl.ResourceType.texture_3d and slices > 1:
return

# Skip 1d texture arrays until slang fix is in
if type == sgl.ResourceType.texture_1d and slices > 1:
pytest.skip("Pending slang crash using 1d texture array as UAV")

# Skip 3d textures with mips until slang fix is in
if type == sgl.ResourceType.texture_3d and mips != 1:
pytest.skip("Pending slang fix for 3d textures with mips")

# Create texture and build random data
src_tex = device.create_texture(**make_args(type, slices, mips))
dest_tex = device.create_texture(**make_args(type, slices, mips))
rand_data = make_rand_data(src_tex.type, src_tex.array_size, src_tex.mip_count)

# Write random data to texture
for slice_idx, slice_data in enumerate(rand_data):
for mip_idx, mip_data in enumerate(slice_data):
src_tex.from_numpy(mip_data, array_slice=slice_idx, mip_level=mip_idx)

for mip in range(src_tex.mip_count):
dims = len(rand_data[0][0].shape) - 1
if slices == 1:

COPY_TEXTURE_SHADER = f"""
[shader("compute")]
[numthreads(1, 1, 1)]
void copy_color(
uint{dims} tid: SV_DispatchThreadID,
Texture{dims}D<float4> src,
RWTexture{dims}D<float4> dest
)
{{
dest[tid] = src[tid];
}}
"""
module = device.load_module_from_source(
module_name=f"test_shader_read_write_texture_{slices}_{dims}",
source=COPY_TEXTURE_SHADER,
)
copy_kernel = device.create_compute_kernel(
device.link_program([module], [module.entry_point("copy_color")])
)

copy_kernel.dispatch(
[src_tex.width, src_tex.height, src_tex.depth],
src=src_tex.get_srv(mip),
dest=dest_tex.get_uav(mip),
)
else:

COPY_TEXTURE_SHADER = f"""
[shader("compute")]
[numthreads(1, 1, 1)]
void copy_color(
uint{dims} tid: SV_DispatchThreadID,
Texture{dims}DArray<float4> src,
RWTexture{dims}DArray<float4> dest,
uniform uint slice
)
{{
uint{dims+1} idx = uint{dims+1}(tid, slice);
dest[idx] = src[idx];
}}
"""
module = device.load_module_from_source(
module_name=f"test_shader_read_write_texture_{slices}_{dims}",
source=COPY_TEXTURE_SHADER,
)
copy_kernel = device.create_compute_kernel(
device.link_program([module], [module.entry_point("copy_color")])
)
for i in range(0, slices):
copy_kernel.dispatch(
[src_tex.width, src_tex.height, src_tex.depth],
src=src_tex.get_srv(mip),
dest=dest_tex.get_uav(mip),
slice=i,
)

# Read back data and compare
for slice_idx, slice_data in enumerate(rand_data):
for mip_idx, mip_data in enumerate(slice_data):
data = dest_tex.to_numpy(array_slice=slice_idx, mip_level=mip_idx)
assert np.allclose(data, mip_data)


if __name__ == "__main__":
pytest.main([__file__, "-v"])

0 comments on commit 92e703c

Please sign in to comment.