diff --git a/pacman/data/pacman_data_view.py b/pacman/data/pacman_data_view.py index 9fd12d9df..588cc0a25 100644 --- a/pacman/data/pacman_data_view.py +++ b/pacman/data/pacman_data_view.py @@ -13,8 +13,8 @@ # limitations under the License. from __future__ import annotations import logging -from typing import (Iterable, List, Optional, Sequence, Type, TypeVar, - TYPE_CHECKING) +from typing import (Iterable, List, Optional, Sequence, Tuple, Type, TypeVar, + TYPE_CHECKING, Union) from spinn_utilities.log import FormatAdapter from spinn_utilities.typing.coords import XY @@ -352,11 +352,13 @@ def iterate_placemements(cls) -> Iterable[Placement]: @classmethod def iterate_placements_by_vertex_type( - cls, vertex_type: Type[VTX]) -> Iterable[Placement]: + cls, vertex_type: Union[ + type, Tuple[type, ...]]) -> Iterable[Placement]: """ Iterate over placements on any chip with this vertex_type. - :param type vertex_type: Class of vertex to find + :param vertex_type: Class of vertex to find + :type vertex_type: type or tuple(type) :rtype: iterable(Placement) :raises ~spinn_utilities.exceptions.SpiNNUtilsException: If the placements are currently unavailable @@ -382,11 +384,14 @@ def iterate_placements_on_core(cls, xy: XY) -> Iterable[Placement]: @classmethod def iterate_placements_by_xy_and_type( - cls, xy: XY, vertex_type: Type[VTX]) -> Iterable[Placement]: + cls, xy: XY, vertex_type: Union[ + type, Tuple[type, ...]]) -> Iterable[Placement]: """ Iterate over placements with this x, y and type. :param tuple(int, int) xy: x and y coordinates to find placements for. + :param vertex_type: Class of vertex to find + :type vertex_type: type or tuple(type) :param type vertex_type: Class of vertex to find :rtype: iterable(Placement) :raises ~spinn_utilities.exceptions.SpiNNUtilsException: diff --git a/pacman/model/placements/placements.py b/pacman/model/placements/placements.py index b90cd6e0f..6ec12418f 100644 --- a/pacman/model/placements/placements.py +++ b/pacman/model/placements/placements.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import defaultdict -from typing import Collection, Dict, Iterable, Iterator +from typing import Collection, Dict, Iterable, Iterator, Tuple, Union from spinn_utilities.typing.coords import XY @@ -152,12 +152,14 @@ def iterate_placements_on_core(self, xy: XY) -> Iterable[Placement]: return self._placements[xy].values() def iterate_placements_by_xy_and_type( - self, xy: XY, vertex_type: type) -> Iterable[Placement]: + self, xy: XY, vertex_type: Union[ + type, Tuple[type, ...]]) -> Iterable[Placement]: """ Iterate over placements with this x, y and this vertex_type. :param tuple(int, int) xy: x and y coordinate to find placements for. - :param class vertex_type: Class of vertex to find + :param vertex_type: Class of vertex to find + :type vertex_type: class or tuple(class) :rtype: iterable(Placement) """ for placement in self._placements[xy].values(): @@ -165,11 +167,13 @@ def iterate_placements_by_xy_and_type( yield placement def iterate_placements_by_vertex_type( - self, vertex_type: type) -> Iterable[Placement]: + self, vertex_type: Union[ + type, Tuple[type, ...]]) -> Iterable[Placement]: """ Iterate over placements on any chip with this vertex_type. - :param class vertex_type: Class of vertex to find + :param vertex_type: Class of vertex to find + :type vertex_type: type or tuple(type) :rtype: iterable(Placement) """ for placement in self._machine_vertices.values(): diff --git a/unittests/model_tests/placement_tests/test_placement_object.py b/unittests/model_tests/placement_tests/test_placement_object.py index 72a9e5abb..a8fd3933a 100644 --- a/unittests/model_tests/placement_tests/test_placement_object.py +++ b/unittests/model_tests/placement_tests/test_placement_object.py @@ -17,11 +17,24 @@ """ import unittest from pacman.config_setup import unittest_setup +from pacman.data.pacman_data_writer import PacmanDataWriter from pacman.exceptions import PacmanAlreadyPlacedError from pacman.model.graphs.machine import SimpleMachineVertex from pacman.model.placements import Placement, Placements +class ExtendedVertex1(SimpleMachineVertex): + pass + + +class ExtendedVertex2(SimpleMachineVertex): + pass + + +class ExtendedVertex3(SimpleMachineVertex): + pass + + class TestPlacement(unittest.TestCase): """ tester for placement object in pacman.model.placements.placement @@ -58,6 +71,37 @@ def test_create_new_placements_duplicate_vertex(self): with self.assertRaises(PacmanAlreadyPlacedError): Placements(pl) + def test_iterate_by_type(self): + v1a = ExtendedVertex1(None) + p1a = Placement(v1a, 0, 1, 1) + v1b = ExtendedVertex1(None) + p1b = Placement(v1b, 0, 2, 1) + v2 = ExtendedVertex2(None) + p2 = Placement(v2, 0, 0, 2) + v3 = ExtendedVertex3(None) + p3 = Placement(v3, 0, 0, 3) + placements = Placements([p1a, p1b, p2, p3]) + l1 = list(placements.iterate_placements_by_vertex_type( + ExtendedVertex1)) + self.assertListEqual(l1, [p1a, p1b]) + l2 = list(placements.iterate_placements_by_vertex_type( + (ExtendedVertex2, ExtendedVertex1))) + self.assertListEqual(l2, [p1a, p1b, p2]) + l3 = list(placements.iterate_placements_by_xy_and_type( + (0, 0), (ExtendedVertex3, ExtendedVertex2))) + self.assertListEqual(l3, [p2, p3]) + writer = PacmanDataWriter.setup() + writer.set_placements(placements) + l1 = list(writer.iterate_placements_by_vertex_type( + ExtendedVertex1)) + self.assertListEqual(l1, [p1a, p1b]) + l2 = list(writer.iterate_placements_by_vertex_type( + (ExtendedVertex2, ExtendedVertex1))) + self.assertListEqual(l2, [p1a, p1b, p2]) + l3 = list(writer.iterate_placements_by_xy_and_type( + (0, 0), (ExtendedVertex3, ExtendedVertex2))) + self.assertListEqual(l3, [p2, p3]) + if __name__ == '__main__': unittest.main()