From 95ba9d1d5da512791abe788377de5ad840c48f9d Mon Sep 17 00:00:00 2001 From: Andrew Rowley Date: Tue, 31 Oct 2023 09:44:11 +0000 Subject: [PATCH] Fix and test --- pacman/model/graphs/common/mdslice.py | 12 ++++++------ unittests/model_tests/test_mdslice.py | 7 +++++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pacman/model/graphs/common/mdslice.py b/pacman/model/graphs/common/mdslice.py index c157f501b..75cf3cbef 100644 --- a/pacman/model/graphs/common/mdslice.py +++ b/pacman/model/graphs/common/mdslice.py @@ -159,23 +159,23 @@ def from_string(cls, as_str): def get_relative_indices(self, app_vertex_indices): n_dims = len(self._atoms_shape) remainders = app_vertex_indices - last_per_core = 0 + cum_last_core = 1 rel_index = numpy.zeros(len(app_vertex_indices)) - for n in n_dims: + for n in range(n_dims): # Work out the index in this dimension global_index_d = remainders % self._atoms_shape[n] # Work out the index in this dimension relative to the core start - rel_index_d = global_index_d - self._starts[n] + rel_index_d = global_index_d - self._start[n] # Update the total relative index using the position in this # dimension - rel_index = (rel_index * last_per_core) + rel_index_d + rel_index += rel_index_d * cum_last_core # Prepare for next round of the loop by removing what we used # of the global index and remembering the sizes in this # dimension remainders = remainders // self._atoms_shape[n] - last_per_core = self._shape[n] + cum_last_core *= self._shape[n] - return rel_index + return rel_index.astype(numpy.uint32) diff --git a/unittests/model_tests/test_mdslice.py b/unittests/model_tests/test_mdslice.py index 3352f59cb..d00a10218 100644 --- a/unittests/model_tests/test_mdslice.py +++ b/unittests/model_tests/test_mdslice.py @@ -76,3 +76,10 @@ def test_3b(self): list(s.get_raster_ids())) s2 = MDSlice.from_string(str(s)) self.assertEqual(s, s2) + + def test_get_relative_indices(self): + s = MDSlice(22, 89, (2, 3, 2), (4, 3, 0), (6, 9, 4)) + # Going over the raster IDs should result in a line over the core + self.assertListEqual(list(range(2 * 3 * 2)), + list((s.get_relative_indices( + s.get_raster_ids()))))