From 0ea4cb39b53a0e74fedad3a17ea70b958e6d8d55 Mon Sep 17 00:00:00 2001 From: "Christian Y. Brenninkmeijer" Date: Tue, 13 Aug 2024 15:41:41 +0100 Subject: [PATCH] iterate over more than one Vertex type --- pacman/data/pacman_data_view.py | 15 ++++--- pacman/model/placements/placements.py | 14 ++++--- .../placement_tests/test_placement_object.py | 41 +++++++++++++++++++ 3 files changed, 60 insertions(+), 10 deletions(-) diff --git a/pacman/data/pacman_data_view.py b/pacman/data/pacman_data_view.py index 9fd12d9df..abb987521 100644 --- a/pacman/data/pacman_data_view.py +++ b/pacman/data/pacman_data_view.py @@ -13,8 +13,8 @@ # limitations under the License. from __future__ import annotations import logging -from typing import (Iterable, List, Optional, Sequence, Type, TypeVar, - TYPE_CHECKING) +from typing import (Iterable, List, Optional, Sequence, Tuple, Type, TypeVar, + TYPE_CHECKING, Union) from spinn_utilities.log import FormatAdapter from spinn_utilities.typing.coords import XY @@ -352,11 +352,13 @@ def iterate_placemements(cls) -> Iterable[Placement]: @classmethod def iterate_placements_by_vertex_type( - cls, vertex_type: Type[VTX]) -> Iterable[Placement]: + cls, vertex_type: Union[ + Type[VTX], Tuple[Type[VTX]]]) -> Iterable[Placement]: """ Iterate over placements on any chip with this vertex_type. - :param type vertex_type: Class of vertex to find + :param vertex_type: Class of vertex to find + :type vertex_type: type or tuple(type) :rtype: iterable(Placement) :raises ~spinn_utilities.exceptions.SpiNNUtilsException: If the placements are currently unavailable @@ -382,11 +384,14 @@ def iterate_placements_on_core(cls, xy: XY) -> Iterable[Placement]: @classmethod def iterate_placements_by_xy_and_type( - cls, xy: XY, vertex_type: Type[VTX]) -> Iterable[Placement]: + cls, xy: XY, vertex_type: Union[ + Type[VTX], Tuple[Type[VTX]]]) -> Iterable[Placement]: """ Iterate over placements with this x, y and type. :param tuple(int, int) xy: x and y coordinates to find placements for. + :param vertex_type: Class of vertex to find + :type vertex_type: type or tuple(type) :param type vertex_type: Class of vertex to find :rtype: iterable(Placement) :raises ~spinn_utilities.exceptions.SpiNNUtilsException: diff --git a/pacman/model/placements/placements.py b/pacman/model/placements/placements.py index b90cd6e0f..96661c95d 100644 --- a/pacman/model/placements/placements.py +++ b/pacman/model/placements/placements.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import defaultdict -from typing import Collection, Dict, Iterable, Iterator +from typing import Collection, Dict, Iterable, Iterator, Tuple, Union from spinn_utilities.typing.coords import XY @@ -152,12 +152,14 @@ def iterate_placements_on_core(self, xy: XY) -> Iterable[Placement]: return self._placements[xy].values() def iterate_placements_by_xy_and_type( - self, xy: XY, vertex_type: type) -> Iterable[Placement]: + self, xy: XY, + vertex_type: Union[type, Tuple[type]]) -> Iterable[Placement]: """ Iterate over placements with this x, y and this vertex_type. :param tuple(int, int) xy: x and y coordinate to find placements for. - :param class vertex_type: Class of vertex to find + :param vertex_type: Class of vertex to find + :type vertex_type: class or tuple(class) :rtype: iterable(Placement) """ for placement in self._placements[xy].values(): @@ -165,11 +167,13 @@ def iterate_placements_by_xy_and_type( yield placement def iterate_placements_by_vertex_type( - self, vertex_type: type) -> Iterable[Placement]: + self, + vertex_type: Union[type, Tuple[type]]) -> Iterable[Placement]: """ Iterate over placements on any chip with this vertex_type. - :param class vertex_type: Class of vertex to find + :param vertex_type: Class of vertex to find + :type vertex_type: type or tuple(type) :rtype: iterable(Placement) """ for placement in self._machine_vertices.values(): diff --git a/unittests/model_tests/placement_tests/test_placement_object.py b/unittests/model_tests/placement_tests/test_placement_object.py index 72a9e5abb..fe3e32f7e 100644 --- a/unittests/model_tests/placement_tests/test_placement_object.py +++ b/unittests/model_tests/placement_tests/test_placement_object.py @@ -17,10 +17,19 @@ """ import unittest from pacman.config_setup import unittest_setup +from pacman.data.pacman_data_writer import PacmanDataWriter from pacman.exceptions import PacmanAlreadyPlacedError from pacman.model.graphs.machine import SimpleMachineVertex from pacman.model.placements import Placement, Placements +class ExtendedVertex1(SimpleMachineVertex): + pass + +class ExtendedVertex2(SimpleMachineVertex): + pass + +class ExtendedVertex3(SimpleMachineVertex): + pass class TestPlacement(unittest.TestCase): """ @@ -59,5 +68,37 @@ def test_create_new_placements_duplicate_vertex(self): Placements(pl) + def test_iterate_by_type(self): + v1a = ExtendedVertex1(None) + p1a = Placement(v1a, 0, 1, 1) + v1b = ExtendedVertex1(None) + p1b = Placement(v1b, 0, 2, 1) + v2 = ExtendedVertex2(None) + p2 = Placement(v2, 0, 0, 2) + v3 = ExtendedVertex3(None) + p3 = Placement(v3, 0, 0, 3) + placements = Placements([p1a, p1b, p2, p3]) + l1 = list(placements.iterate_placements_by_vertex_type( + ExtendedVertex1)) + self.assertListEqual(l1, [p1a, p1b]) + l2 = list(placements.iterate_placements_by_vertex_type( + (ExtendedVertex2, ExtendedVertex1))) + self.assertListEqual(l2, [p1a, p1b, p2]) + l3 = list(placements.iterate_placements_by_xy_and_type( + (0,0), (ExtendedVertex3, ExtendedVertex2))) + self.assertListEqual(l3, [p2, p3]) + writer = PacmanDataWriter.setup() + writer.set_placements(placements) + l1 = list(writer.iterate_placements_by_vertex_type( + ExtendedVertex1)) + self.assertListEqual(l1, [p1a, p1b]) + l2 = list(writer.iterate_placements_by_vertex_type( + (ExtendedVertex2, ExtendedVertex1))) + self.assertListEqual(l2, [p1a, p1b, p2]) + l3 = list(writer.iterate_placements_by_xy_and_type( + (0,0), (ExtendedVertex3, ExtendedVertex2))) + self.assertListEqual(l3, [p2, p3]) + + if __name__ == '__main__': unittest.main()