Skip to content

Commit

Permalink
iterate over more than one Vertex type
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian-B committed Aug 13, 2024
1 parent 23c51d2 commit 0ea4cb3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 10 deletions.
15 changes: 10 additions & 5 deletions pacman/data/pacman_data_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
from __future__ import annotations
import logging
from typing import (Iterable, List, Optional, Sequence, Type, TypeVar,
TYPE_CHECKING)
from typing import (Iterable, List, Optional, Sequence, Tuple, Type, TypeVar,
TYPE_CHECKING, Union)

from spinn_utilities.log import FormatAdapter
from spinn_utilities.typing.coords import XY
Expand Down Expand Up @@ -352,11 +352,13 @@ def iterate_placemements(cls) -> Iterable[Placement]:

@classmethod
def iterate_placements_by_vertex_type(
cls, vertex_type: Type[VTX]) -> Iterable[Placement]:
cls, vertex_type: Union[
Type[VTX], Tuple[Type[VTX]]]) -> Iterable[Placement]:
"""
Iterate over placements on any chip with this vertex_type.
:param type vertex_type: Class of vertex to find
:param vertex_type: Class of vertex to find
:type vertex_type: type or tuple(type)
:rtype: iterable(Placement)
:raises ~spinn_utilities.exceptions.SpiNNUtilsException:
If the placements are currently unavailable
Expand All @@ -382,11 +384,14 @@ def iterate_placements_on_core(cls, xy: XY) -> Iterable[Placement]:

@classmethod
def iterate_placements_by_xy_and_type(
cls, xy: XY, vertex_type: Type[VTX]) -> Iterable[Placement]:
cls, xy: XY, vertex_type: Union[
Type[VTX], Tuple[Type[VTX]]]) -> Iterable[Placement]:
"""
Iterate over placements with this x, y and type.
:param tuple(int, int) xy: x and y coordinates to find placements for.
:param vertex_type: Class of vertex to find
:type vertex_type: type or tuple(type)
:param type vertex_type: Class of vertex to find
:rtype: iterable(Placement)
:raises ~spinn_utilities.exceptions.SpiNNUtilsException:
Expand Down
14 changes: 9 additions & 5 deletions pacman/model/placements/placements.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from collections import defaultdict
from typing import Collection, Dict, Iterable, Iterator
from typing import Collection, Dict, Iterable, Iterator, Tuple, Union

from spinn_utilities.typing.coords import XY

Expand Down Expand Up @@ -152,24 +152,28 @@ def iterate_placements_on_core(self, xy: XY) -> Iterable[Placement]:
return self._placements[xy].values()

def iterate_placements_by_xy_and_type(
self, xy: XY, vertex_type: type) -> Iterable[Placement]:
self, xy: XY,
vertex_type: Union[type, Tuple[type]]) -> Iterable[Placement]:
"""
Iterate over placements with this x, y and this vertex_type.
:param tuple(int, int) xy: x and y coordinate to find placements for.
:param class vertex_type: Class of vertex to find
:param vertex_type: Class of vertex to find
:type vertex_type: class or tuple(class)
:rtype: iterable(Placement)
"""
for placement in self._placements[xy].values():
if isinstance(placement.vertex, vertex_type):
yield placement

def iterate_placements_by_vertex_type(
self, vertex_type: type) -> Iterable[Placement]:
self,
vertex_type: Union[type, Tuple[type]]) -> Iterable[Placement]:
"""
Iterate over placements on any chip with this vertex_type.
:param class vertex_type: Class of vertex to find
:param vertex_type: Class of vertex to find
:type vertex_type: type or tuple(type)
:rtype: iterable(Placement)
"""
for placement in self._machine_vertices.values():
Expand Down
41 changes: 41 additions & 0 deletions unittests/model_tests/placement_tests/test_placement_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,19 @@
"""
import unittest
from pacman.config_setup import unittest_setup
from pacman.data.pacman_data_writer import PacmanDataWriter
from pacman.exceptions import PacmanAlreadyPlacedError
from pacman.model.graphs.machine import SimpleMachineVertex
from pacman.model.placements import Placement, Placements

class ExtendedVertex1(SimpleMachineVertex):
pass

class ExtendedVertex2(SimpleMachineVertex):
pass

class ExtendedVertex3(SimpleMachineVertex):
pass

class TestPlacement(unittest.TestCase):
"""
Expand Down Expand Up @@ -59,5 +68,37 @@ def test_create_new_placements_duplicate_vertex(self):
Placements(pl)


def test_iterate_by_type(self):
v1a = ExtendedVertex1(None)
p1a = Placement(v1a, 0, 1, 1)
v1b = ExtendedVertex1(None)
p1b = Placement(v1b, 0, 2, 1)
v2 = ExtendedVertex2(None)
p2 = Placement(v2, 0, 0, 2)
v3 = ExtendedVertex3(None)
p3 = Placement(v3, 0, 0, 3)
placements = Placements([p1a, p1b, p2, p3])
l1 = list(placements.iterate_placements_by_vertex_type(
ExtendedVertex1))
self.assertListEqual(l1, [p1a, p1b])
l2 = list(placements.iterate_placements_by_vertex_type(
(ExtendedVertex2, ExtendedVertex1)))
self.assertListEqual(l2, [p1a, p1b, p2])
l3 = list(placements.iterate_placements_by_xy_and_type(
(0,0), (ExtendedVertex3, ExtendedVertex2)))
self.assertListEqual(l3, [p2, p3])
writer = PacmanDataWriter.setup()
writer.set_placements(placements)
l1 = list(writer.iterate_placements_by_vertex_type(
ExtendedVertex1))
self.assertListEqual(l1, [p1a, p1b])
l2 = list(writer.iterate_placements_by_vertex_type(
(ExtendedVertex2, ExtendedVertex1)))
self.assertListEqual(l2, [p1a, p1b, p2])
l3 = list(writer.iterate_placements_by_xy_and_type(
(0,0), (ExtendedVertex3, ExtendedVertex2)))
self.assertListEqual(l3, [p2, p3])


if __name__ == '__main__':
unittest.main()

0 comments on commit 0ea4cb3

Please sign in to comment.