diff --git a/.github/workflows/python_actions.yml b/.github/workflows/python_actions.yml index 511bd214..4c389c90 100644 --- a/.github/workflows/python_actions.yml +++ b/.github/workflows/python_actions.yml @@ -26,5 +26,6 @@ jobs: coverage-package: spinn_utilities flake8-packages: spinn_utilities unittests pylint-packages: spinn_utilities - mypy-packages: spinn_utilities unittests + mypy-full_packages: spinn_utilities + mypy-packages: unittests secrets: inherit diff --git a/mypy.bash b/mypy.bash new file mode 100755 index 00000000..fe943d61 --- /dev/null +++ b/mypy.bash @@ -0,0 +1,23 @@ +#!/bin/bash + +# Copyright (c) 2024 The University of Manchester +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This bash assumes that other repositories are installed in parallel + +# requires the latest mypy +# pip install --upgrade mypy + + +mypy --python-version 3.8 --disallow-untyped-defs spinn_utilities unittests diff --git a/mypyd.bash b/mypyd.bash new file mode 100755 index 00000000..25be44d9 --- /dev/null +++ b/mypyd.bash @@ -0,0 +1,23 @@ +#!/bin/bash + +# Copyright (c) 2024 The University of Manchester +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This bash assumes that other repositories are installed in parallel + +# requires the latest mypy +# pip install --upgrade mypy + + +mypy --python-version 3.8 --disallow-untyped-defs spinn_utilities diff --git a/spinn_utilities/abstract_base.py b/spinn_utilities/abstract_base.py index 8116c79f..fecc73e1 100644 --- a/spinn_utilities/abstract_base.py +++ b/spinn_utilities/abstract_base.py @@ -14,7 +14,7 @@ """ A trimmed down version of standard Python Abstract Base classes. """ -from typing import TypeVar +from typing import Any, Dict, Type, TypeVar, Tuple #: :meta private: T = TypeVar("T") @@ -58,7 +58,8 @@ def my_abstract_method(self, ...): ... 
""" - def __new__(mcs, name, bases, namespace, **kwargs): + def __new__(mcs, name: str, bases: Tuple[Type, ...], + namespace: Dict[str, Any], **kwargs: Any) -> "AbstractBase": # Actually make the class abs_cls = super().__new__(mcs, name, bases, namespace, **kwargs) @@ -74,5 +75,6 @@ def __new__(mcs, name, bases, namespace, **kwargs): abstracts.add(nm) # Lock down the set - abs_cls.__abstractmethods__ = frozenset(abstracts) + abs_cls.__abstractmethods__ = frozenset( # type: ignore[attr-defined] + abstracts) return abs_cls diff --git a/spinn_utilities/citation/citation_aggregator.py b/spinn_utilities/citation/citation_aggregator.py index 0cd8de0a..5cf250b2 100644 --- a/spinn_utilities/citation/citation_aggregator.py +++ b/spinn_utilities/citation/citation_aggregator.py @@ -17,6 +17,8 @@ import io import importlib import argparse +from types import ModuleType +from typing import Any, Dict, List, Optional, Set, Union import sys from .citation_updater_and_doi_generator import CitationUpdaterAndDoiGenerator @@ -43,6 +45,8 @@ # pylint: skip-file +_SEEN_TYPE = Set[Union[ModuleType, str, None]] + class CitationAggregator(object): """ @@ -51,38 +55,40 @@ class CitationAggregator(object): """ def create_aggregated_citation_file( - self, module_to_start_at, aggregated_citation_file): + self, module_to_start_at: ModuleType, + aggregated_citation_file: str) -> None: """ Entrance method for building the aggregated citation file. 
:param module_to_start_at: the top level module to figure out its citation file for - :type module_to_start_at: python module - :param str aggregated_citation_file: + :param aggregated_citation_file: file name of aggregated citation file """ # get the top citation file to add references to + module_file: Optional[str] = module_to_start_at.__file__ + assert module_file is not None top_citation_file_path = os.path.join(os.path.dirname(os.path.dirname( - os.path.abspath(module_to_start_at.__file__))), CITATION_FILE) - modules_seen_so_far = set() - modules_seen_so_far.add("") # Make sure the empty entry is absent + os.path.abspath(module_file))), CITATION_FILE) + modules_seen_so_far: _SEEN_TYPE = set() + modules_seen_so_far.add("") with open(top_citation_file_path, encoding=ENCODING) as stream: - top_citation_file = yaml.safe_load(stream) + top_citation_file: Dict[str, Any] = yaml.safe_load( + stream) top_citation_file[REFERENCES_YAML_POINTER] = list() # get the dependency list requirements_file_path = os.path.join( os.path.dirname(os.path.dirname(os.path.abspath( - module_to_start_at.__file__))), REQUIREMENTS_FILE) + module_file))), REQUIREMENTS_FILE) c_requirements_file_path = os.path.join( os.path.dirname(os.path.dirname(os.path.abspath( - module_to_start_at.__file__))), C_REQUIREMENTS_FILE) + module_file))), C_REQUIREMENTS_FILE) # attempt to get python PYPI to import command map pypi_to_import_map_file = os.path.join( os.path.dirname(os.path.dirname(os.path.abspath( - module_to_start_at.__file__))), - PYPI_TO_IMPORT_FILE) + module_file))), PYPI_TO_IMPORT_FILE) pypi_to_import_map = None if os.path.isfile(pypi_to_import_map_file): pypi_to_import_map = self._read_pypi_import_map( @@ -95,6 +101,7 @@ def create_aggregated_citation_file( if module.startswith("#"): continue if module not in modules_seen_so_far: + assert pypi_to_import_map is not None import_name = pypi_to_import_map.get(module, module) # pylint: disable=broad-except try: @@ -125,7 +132,7 @@ def 
create_aggregated_citation_file( allow_unicode=True) @staticmethod - def _read_pypi_import_map(aggregated_citation_file): + def _read_pypi_import_map(aggregated_citation_file: str) -> Dict[str, str]: """ Read the PYPI to import name map. @@ -133,7 +140,7 @@ def _read_pypi_import_map(aggregated_citation_file): :return: map between PYPI names and import names :rtype: dict(str,str) """ - pypi_to_import_map = dict() + pypi_to_import_map: Dict[str, str] = dict() with open(aggregated_citation_file, encoding=ENCODING) as f: for line in f: [pypi, import_command] = line.split(":") @@ -141,7 +148,8 @@ def _read_pypi_import_map(aggregated_citation_file): return pypi_to_import_map def _handle_c_dependency( - self, top_citation_file, module, modules_seen_so_far): + self, top_citation_file: Dict[str, Any], module: str, + modules_seen_so_far: _SEEN_TYPE) -> None: """ Handle a C code dependency. @@ -164,7 +172,7 @@ def _handle_c_dependency( print(f"Could not find C dependency {module}") @staticmethod - def locate_path_for_c_dependency(true_software_name): + def locate_path_for_c_dependency(true_software_name: str) -> Optional[str]: """ :param str true_software_name: :rtype: str or None @@ -187,15 +195,16 @@ def locate_path_for_c_dependency(true_software_name): return None def _search_for_other_c_references( - self, reference_entry, software_path, modules_seen_so_far): + self, reference_entry: Dict[str, Any], software_path: str, + modules_seen_so_far: _SEEN_TYPE) -> None: """ - Go though the top level path and tries to locate other CFF + Go through the top level path and tries to locate other CFF files that need to be added to the references pile. - :param dict(str,list(str)) reference_entry: + :param reference_entry: The reference entry to add new dependencies as references for. 
- :param str software_path: the path to search in - :param set(str) modules_seen_so_far: + :param software_path: the path to search in + :param modules_seen_so_far: """ for possible_extra_citation_file in os.listdir(software_path): if possible_extra_citation_file.endswith(".cff"): @@ -210,23 +219,26 @@ def _search_for_other_c_references( possible_extra_citation_file.split(".")[0]) def _handle_python_dependency( - self, top_citation_file, imported_module, modules_seen_so_far, - module_name): + self, top_citation_file: Dict[str, Any], + imported_module: ModuleType, modules_seen_so_far: _SEEN_TYPE, + module_name: str) -> None: """ Handle a python dependency. - :param dict(str,list(str)) top_citation_file: + :param top_citation_file: YAML file for the top citation file :param imported_module: the actual imported module :type imported_module: ModuleType - :param set(str) modules_seen_so_far: + :param modules_seen_so_far: list of names of dependencies already processed - :param str module_name: + :param module_name: the name of this module to consider as a dependency :raises FileNotFoundError: """ # get modules citation file - citation_level_dir = os.path.abspath(imported_module.__file__) + module_path = imported_module.__file__ + assert module_path is not None + citation_level_dir = os.path.abspath(module_path) m_path = module_name.replace(".", os.sep) last_citation_level_dir = None while (not citation_level_dir.endswith(m_path) and @@ -247,19 +259,19 @@ def _handle_python_dependency( top_citation_file[REFERENCES_YAML_POINTER].append(reference_entry) def _process_reference( - self, citation_level_dir, imported_module, modules_seen_so_far, - module_name): + self, citation_level_dir: str, + imported_module: Optional[ModuleType], + modules_seen_so_far: _SEEN_TYPE, + module_name: str) -> Dict[str, Any]: """ Take a module level and tries to locate and process a citation file. 
- :param str citation_level_dir: + :param citation_level_dir: the expected level where the ``CITATION.cff`` should be :param imported_module: the module after being imported - :type imported_module: python module - :param set(str) modules_seen_so_far: + :param modules_seen_so_far: list of dependencies already processed :return: the reference entry in JSON format - :rtype: dict """ # if it exists, add it as a reference to the top one if os.path.isfile(os.path.join(citation_level_dir, CITATION_FILE)): @@ -285,7 +297,9 @@ def _process_reference( return reference_entry @staticmethod - def _try_to_find_version(imported_module, module_name): + def _try_to_find_version( + imported_module: Optional[ModuleType], + module_name: str) -> Dict[str, Any]: """ Try to locate a version file or version data to auto-generate minimal citation data. @@ -296,7 +310,7 @@ def _try_to_find_version(imported_module, module_name): :return: reference entry for this python module :rtype: dict """ - reference_entry = dict() + reference_entry: Dict[str, Any] = dict() reference_entry[REFERENCES_TYPE_TYPE] = REFERENCES_SOFTWARE_TYPE reference_entry[REFERENCES_TITLE_TYPE] = module_name if imported_module is None: @@ -323,15 +337,15 @@ def _try_to_find_version(imported_module, module_name): return reference_entry @staticmethod - def _read_and_process_reference_entry(dependency_citation_file_path): + def _read_and_process_reference_entry( + dependency_citation_file_path: str) -> Dict[str, Any]: """ Read a ``CITATION.cff`` and makes it a reference for a higher level citation file. 
- :param str dependency_citation_file_path: + :param dependency_citation_file_path: path to a `CITATION.cff` file :return: reference entry for the higher level `CITATION.cff` - :rtype: dict """ reference_entry = dict() @@ -357,7 +371,7 @@ def _read_and_process_reference_entry(dependency_citation_file_path): return reference_entry -def generate_aggregate(arguments=None): +def generate_aggregate(arguments: Optional[List[str]] = None) -> None: """ Command-line tool to generate a single ``citation.cff`` from others. diff --git a/spinn_utilities/citation/citation_updater_and_doi_generator.py b/spinn_utilities/citation/citation_updater_and_doi_generator.py index e7a1a407..f084dfe7 100644 --- a/spinn_utilities/citation/citation_updater_and_doi_generator.py +++ b/spinn_utilities/citation/citation_updater_and_doi_generator.py @@ -19,6 +19,9 @@ import unicodedata import os from time import strptime +from typing import Any, cast, Dict, List, Optional, Tuple, Union + +from spinn_utilities.typing.json import JsonObject CITATION_FILE_VERSION_FIELD = "version" CITATION_FILE_DATE_FIELD = "date-released" @@ -52,7 +55,8 @@ class _ZenodoException(Exception): Exception from a call to Zenodo. """ - def __init__(self, operation, expected, request): + def __init__( + self, operation: str, expected: int, request: requests.Response): super().__init__( "don't know what went wrong. got wrong status code when trying " f"to {operation}. 
Got error code {request.status_code} " @@ -82,17 +86,18 @@ class _Zenodo(object): _VALID_STATUS_REQUEST_POST = 201 _VALID_STATUS_REQUEST_PUBLISH = 202 - def __init__(self, token): + def __init__(self, token: str): self.__zenodo_token = token @staticmethod - def _json(r): + def _json(r: requests.Response) -> Optional[JsonObject]: try: return r.json() except Exception: # pylint: disable=broad-except return None - def get_verify(self, related): + def get_verify( + self, related: List[Dict[str, str]]) -> Optional[JsonObject]: r = requests.get( self._DEPOSIT_GET_URL, timeout=10, params={self._ACCESS_TOKEN: self.__zenodo_token, @@ -103,7 +108,8 @@ def get_verify(self, related): "request a DOI", self._VALID_STATUS_REQUEST_GET, r) return self._json(r) - def post_create(self, related): + def post_create( + self, related: List[Dict[str, str]]) -> Optional[JsonObject]: r = requests.post( self._DEPOSIT_GET_URL, timeout=10, params={self._ACCESS_TOKEN: self.__zenodo_token, @@ -114,7 +120,9 @@ def post_create(self, related): "get an empty upload", self._VALID_STATUS_REQUEST_POST, r) return self._json(r) - def post_upload(self, deposit_id, data, files): + def post_upload( + self, deposit_id: str, data: Dict[str, Any], + files: Dict[str, io.BufferedReader]) -> Optional[JsonObject]: r = requests.post( self._DEPOSIT_PUT_URL.format(deposit_id), timeout=10, params={self._ACCESS_TOKEN: self.__zenodo_token}, @@ -125,7 +133,7 @@ def post_upload(self, deposit_id, data, files): self._VALID_STATUS_REQUEST_POST, r) return self._json(r) - def post_publish(self, deposit_id): + def post_publish(self, deposit_id: str) -> Optional[JsonObject]: r = requests.post( self._PUBLISH_URL.format(deposit_id), timeout=10, params={self._ACCESS_TOKEN: self.__zenodo_token}) @@ -136,31 +144,32 @@ def post_publish(self, deposit_id): class CitationUpdaterAndDoiGenerator(object): - def __init__(self): - self.__zenodo = None + def __init__(self) -> None: + self.__zenodo: Optional[_Zenodo] = None def 
update_citation_file_and_create_doi( - self, citation_file_path, doi_title, create_doi, publish_doi, - previous_doi, zenodo_access_token, module_path): + self, citation_file_path: str, doi_title: str, create_doi: bool, + publish_doi: bool, previous_doi: str, zenodo_access_token: str, + module_path: str) -> None: """ Take a CITATION.cff file and updates the version and date-released fields, and rewrites the ``CITATION.cff`` file. - :param str citation_file_path: File path to the ``CITATION.cff`` file - :param bool create_doi: + :param citation_file_path: File path to the ``CITATION.cff`` file + :param create_doi: Whether to use Zenodo DOI interface to grab a DOI - :param str zenodo_access_token: Access token for Zenodo - :param bool publish_doi: Whether to publish the DOI on Zenodo - :param str previous_doi: DOI to append the created DOI to - :param str doi_title: Title for the created DOI - :param str module_path: Path to the module to zip up - :param bool update_version: + :param zenodo_access_token: Access token for Zenodo + :param publish_doi: Whether to publish the DOI on Zenodo + :param previous_doi: DOI to append the created DOI to + :param doi_title: Title for the created DOI + :param module_path: Path to the module to zip up + :param update_version: Whether we should update the citation version """ self.__zenodo = _Zenodo(zenodo_access_token) # data holders - deposit_id = None + deposit_id: Optional[str] = None # read in YAML file with open(citation_file_path, 'r', encoding="utf-8") as stream: @@ -178,17 +187,17 @@ def update_citation_file_and_create_doi( # if creating a DOI, finish the request and possibly publish it if create_doi: + assert deposit_id is not None self._finish_doi( deposit_id, publish_doi, doi_title, yaml_file[CITATION_FILE_DESCRIPTION], yaml_file, module_path) - def _request_doi(self, previous_doi): + def _request_doi(self, previous_doi: str) -> Tuple[bytes, Any]: """ Go to Zenodo and requests a DOI. 
- :param str previous_doi: the previous DOI for this module, if exists + :param previous_doi: the previous DOI for this module, if exists :return: the DOI id, and deposit id - :rtype: tuple(str, str) """ # create link to previous version (if applicable) related = list() @@ -197,34 +206,40 @@ def _request_doi(self, previous_doi): IDENTIFIER: previous_doi}) # get a request for a DOI + assert self.__zenodo is not None self.__zenodo.get_verify(related) # get empty upload request_data = self.__zenodo.post_create(related) + assert request_data is not None # get DOI and deposit id + metadata = cast(Dict[str, Dict[str, str]], + request_data[ZENODO_METADATA]) doi_id = unicodedata.normalize( 'NFKD', - (request_data[ZENODO_METADATA][ZENODO_PRE_RESERVED_DOI] + (metadata[ZENODO_PRE_RESERVED_DOI] [ZENODO_DOI_VALUE])).encode('ascii', 'ignore') deposition_id = request_data[ZENODO_DEPOSIT_ID] return doi_id, deposition_id def _finish_doi( - self, deposit_id, publish_doi, title, - doi_description, yaml_file, module_path): + self, deposit_id: str, publish_doi: bool, title: str, + doi_description: str, yaml_file: Dict[str, Any], + module_path: str) -> None: """ Finishes the DOI on Zenodo. 
- :param str deposit_id: the deposit id to publish - :param bool publish_doi: whether we should publish the DOI - :param str title: the title of this DOI - :param str doi_description: the description for the DOI + :param deposit_id: the deposit id to publish + :param publish_doi: whether we should publish the DOI + :param title: the title of this DOI + :param doi_description: the description for the DOI :param yaml_file: the citation file after its been read it :param module_path: the path to the module to DOI """ zipped_file = None + assert self.__zenodo is not None try: zipped_file = self._zip_up_module(module_path) with open(zipped_file, "rb") as zipped_open_file: @@ -239,11 +254,11 @@ def _finish_doi( if publish_doi: self.__zenodo.post_publish(deposit_id) - def _zip_up_module(self, module_path): + def _zip_up_module(self, module_path: str) -> str: """ Zip up a module. - :param str module_path: the path to the module to zip up + :param module_path: the path to the module to zip up :return: the filename to the zip file """ if os.path.isfile('module.zip'): @@ -259,14 +274,15 @@ def _zip_up_module(self, module_path): return 'module.zip' @staticmethod - def _zip_walker(module_path, avoids, module_zip_file): + def _zip_walker(module_path: str, avoids: List[str], + module_zip_file: zipfile.ZipFile) -> None: """ Traverse the module and its sub-directories and only adds to the files to the zip which are not within a avoid directory that. 
- :param str module_path: the path to start the search at - :param set(str) avoids: the set of avoids to avoid - :param ~zipfile.ZipFile module_zip_file: the zip file to put into + :param module_path: the path to start the search at + :param avoids: the set of avoids to avoid + :param module_zip_file: the zip file to put into """ for directory_path, _, files in os.walk(module_path): for directory_name in directory_path.split(os.sep): @@ -280,7 +296,8 @@ def _zip_walker(module_path, avoids, module_zip_file): os.path.join(directory_path, potential_zip_file)) @staticmethod - def _fill_in_data(doi_title, doi_description, yaml_file): + def _fill_in_data(doi_title: str, doi_description: str, + yaml_file: Dict[str, Any]) -> Dict[str, Any]: """ Add in data to the Zenodo metadata. @@ -291,7 +308,7 @@ def _fill_in_data(doi_title, doi_description, yaml_file): :rtype: dict """ # add basic meta data - metadata = { + metadata: Dict[str, Any] = { ZENODO_METADATA_TITLE: doi_title, ZENODO_METATDATA_DESC: doi_description, ZENODO_METADATA_CREATORS: [] @@ -313,16 +330,15 @@ def _fill_in_data(doi_title, doi_description, yaml_file): @staticmethod def convert_text_date_to_date( - version_month, version_year, version_day): + version_month: Union[int, str], version_year: Union[int, str], + version_day: Union[int, str]) -> str: """ Convert the 3 components of a date into a CFF date. 
:param version_month: version month, in text form - :type version_month: str or int - :param int version_year: version year - :param int version_day: version day of month + :param version_year: version year + :param version_day: version day of month :return: the string representation for the CFF file - :rtype: str """ return "{}-{}-{}".format( version_year, @@ -331,14 +347,12 @@ def convert_text_date_to_date( version_day) @staticmethod - def convert_month_name_to_number(version_month): + def convert_month_name_to_number(version_month: Union[int, str]) -> int: """ Convert a python month in text form to a number form. :param version_month: the text form of the month - :type version_month: str or int :return: the month int value - :rtype: int :raises ValueError: when the month name is not recognised """ if isinstance(version_month, int): diff --git a/spinn_utilities/classproperty.py b/spinn_utilities/classproperty.py index fc259ce2..c5a845b8 100644 --- a/spinn_utilities/classproperty.py +++ b/spinn_utilities/classproperty.py @@ -12,22 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Callable, Optional, Type + class _ClassPropertyDescriptor(object): """ A class to handle the management of class properties. """ - def __init__(self, fget): + def __init__(self, fget: Callable) -> None: self.fget = fget - def __get__(self, obj, klass=None): + def __get__( + self, obj: Optional[Any], klass: Optional[Type] = None) -> Any: if klass is None: klass = type(obj) return self.fget.__get__(obj, klass)() -def classproperty(func): +def classproperty(func: Callable) -> _ClassPropertyDescriptor: """ Defines a property at the class-level. @@ -41,6 +44,7 @@ def my_property(cls): return cls._my_property """ if not isinstance(func, (classmethod, staticmethod)): - func = classmethod(func) + # mypy claims expression has type "classmethod ... 
+ func = classmethod(func) # type: ignore[assignment] return _ClassPropertyDescriptor(func) diff --git a/spinn_utilities/conf_loader.py b/spinn_utilities/conf_loader.py index f1391541..0dd9d097 100644 --- a/spinn_utilities/conf_loader.py +++ b/spinn_utilities/conf_loader.py @@ -41,13 +41,13 @@ def install_cfg_and_error( It will create a file in the users home directory based on the defaults. Then it prints a helpful message and throws an error with the same message. - :param str filename: + :param filename: Name under which to save the new configuration file - :param list(str) defaults: + :param defaults: List of full paths to the default configuration files. Each of which *must* have an associated template file with exactly the same path plus `.template`. - :param list(str) config_locations: + :param config_locations: List of paths where the user configuration files were looked for. Only used for the message :raise spinn_utilities.configs.NoConfigFoundException: @@ -98,8 +98,8 @@ def install_cfg_and_error( return NoConfigFoundException(msg) -def _check_config( - cfg_file: str, default_configs: CamelCaseConfigParser, strict: bool): +def _check_config(cfg_file: str, default_configs: CamelCaseConfigParser, + strict: bool) -> None: """ Checks the configuration read up to this point to see if it is outdated. @@ -115,10 +115,10 @@ def _check_config( These are specific values in specific options no longer supported. For example old algorithm names. 
- :param str cfg_file: Path of last file read in - :param CamelCaseConfigParser default_configs: + :param cfg_file: Path of last file read in + :param default_configs: configuration with just the default files in - :param bool strict: Flag to say an exception should be raised + :param strict: Flag to say an exception should be raised """ if not default_configs.sections(): # empty logger.warning("Can not validate cfg files as no default.") @@ -143,16 +143,16 @@ def _check_config( def _read_a_config( configuration: CamelCaseConfigParser, cfg_file: str, - default_configs: CamelCaseConfigParser, strict: bool): + default_configs: CamelCaseConfigParser, strict: bool) -> None: """ Reads in a configuration file and then directly its `machine_spec_file`. - :param CamelCaseConfigParser configuration: + :param configuration: configuration to be updated by the reading of a file - :param str cfg_file: path to file which should be read in - :param CamelCaseConfigParser default_configs: + :param cfg_file: path to file which should be read in + :param default_configs: configuration with just the default files in - :param bool strict: Flag to say checker should raise an exception + :param strict: Flag to say checker should raise an exception """ _check_config(cfg_file, default_configs, strict) configuration.read(cfg_file) @@ -167,10 +167,9 @@ def _config_locations(filename: str) -> List[str]: """ Defines the list of places we can get configuration files from. - :param str filename: + :param filename: The local name of the configuration file, e.g., 'spynnaker.cfg' :return: list of fully-qualified filenames - :rtype: list(str) """ dotname = "." + filename @@ -191,10 +190,10 @@ def load_config( """ Load the configuration. - :param str filename: + :param filename: The base name of the configuration file(s). Should not include any path components. - :param list(str) defaults: + :param defaults: The list of files to get default configurations from. 
:param config_parsers: The parsers to parse the sections of the configuration file with, as @@ -203,9 +202,7 @@ def load_config( be parsed if the section_name is found in the configuration files already loaded. The standard logging parser is appended to (a copy of) this. - :type config_parsers: list(tuple(str, ~configparser.RawConfigParser)) :return: the fully-loaded and checked configuration - :rtype: ~configparser.RawConfigParser """ configs = CamelCaseConfigParser() diff --git a/spinn_utilities/config_holder.py b/spinn_utilities/config_holder.py index 73535ebb..1ac06cbe 100644 --- a/spinn_utilities/config_holder.py +++ b/spinn_utilities/config_holder.py @@ -32,7 +32,7 @@ __unittest_mode: bool = False -def add_default_cfg(default: str): +def add_default_cfg(default: str) -> None: """ Adds an extra default configuration file to be read after earlier ones. @@ -42,7 +42,7 @@ def add_default_cfg(default: str): __default_config_files.append(default) -def clear_cfg_files(unittest_mode: bool): +def clear_cfg_files(unittest_mode: bool) -> None: """ Clears any previous set configurations and configuration files. @@ -58,14 +58,14 @@ def clear_cfg_files(unittest_mode: bool): __unittest_mode = unittest_mode -def set_cfg_files(config_file: str, default: str): +def set_cfg_files(config_file: str, default: str) -> None: """ Adds the configuration files to be loaded. - :param str config_file: + :param config_file: The base name of the configuration file(s). Should not include any path components. - :param str default: + :param default: Full path to the extra file to get default configurations from. """ global __config_file @@ -86,7 +86,7 @@ def _pre_load_config() -> CamelCaseConfigParser: return load_config() -def logging_parser(config: CamelCaseConfigParser): +def logging_parser(config: CamelCaseConfigParser) -> None: """ Create the root logger with the given level. 
@@ -128,27 +128,25 @@ def load_config() -> CamelCaseConfigParser: return __config -def is_config_none(section, option) -> bool: +def is_config_none(section: str, option: str) -> bool: """ Check if the value of a configuration option would be considered None - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. :return: True if and only if the value would be considered None - :rtype: bool """ value = get_config_str_or_none(section, option) return value is None -def get_config_str(section, option) -> str: +def get_config_str(section: str, option: str) -> str: """ Get the string value of a configuration option. - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. :return: The option value - :rtype: str :raises ConfigException: if the Value would be None """ value = get_config_str_or_none(section, option) @@ -157,14 +155,13 @@ def get_config_str(section, option) -> str: return value -def get_config_str_or_none(section, option) -> Optional[str]: +def get_config_str_or_none(section: str, option: str) -> Optional[str]: """ Get the string value of a configuration option. - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. :return: The option value - :rtype: str or None :raises ConfigException: if the Value would be None """ if __config is None: @@ -178,11 +175,10 @@ def get_config_str_list( """ Get the string value of a configuration option split into a list. - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. 
:param token: The token to split the string into a list :return: The list (possibly empty) of the option values - :rtype: list(str) """ if __config is None: return _pre_load_config().get_str_list(section, option, token) @@ -194,10 +190,9 @@ def get_config_int(section: str, option: str) -> int: """ Get the integer value of a configuration option. - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. :return: The option value - :rtype: int :raises ConfigException: if the Value would be None """ value = get_config_int_or_none(section, option) @@ -206,14 +201,13 @@ def get_config_int(section: str, option: str) -> int: return value -def get_config_int_or_none(section, option) -> Optional[int]: +def get_config_int_or_none(section: str, option: str) -> Optional[int]: """ Get the integer value of a configuration option. - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. :return: The option value - :rtype: int or None :raises ConfigException: if the Value would be None """ if __config is None: @@ -226,10 +220,9 @@ def get_config_float(section: str, option: str) -> float: """ Get the float value of a configuration option. - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. :return: The option value. 
- :rtype: float :raises ConfigException: if the Value would be None """ value = get_config_float_or_none(section, option) @@ -238,14 +231,13 @@ def get_config_float(section: str, option: str) -> float: return value -def get_config_float_or_none(section, option) -> Optional[float]: +def get_config_float_or_none(section: str, option: str) -> Optional[float]: """ Get the float value of a configuration option. - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. :return: The option value. - :rtype: float or None """ if __config is None: return _pre_load_config().get_float(section, option) @@ -257,10 +249,9 @@ def get_config_bool(section: str, option: str) -> bool: """ Get the Boolean value of a configuration option. - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. :return: The option value. - :rtype: bool :raises ConfigException: if the Value would be None """ value = get_config_bool_or_none(section, option) @@ -269,17 +260,16 @@ def get_config_bool(section: str, option: str) -> bool: return value -def get_config_bool_or_none(section, option, +def get_config_bool_or_none(section: str, option: str, special_nones: Optional[List[str]] = None ) -> Optional[bool]: """ Get the Boolean value of a configuration option. - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. :param special_nones: What special values to except as None :return: The option value. 
- :rtype: bool :raises ConfigException: if the Value would be None """ if __config is None: @@ -288,15 +278,15 @@ def get_config_bool_or_none(section, option, return __config.get_bool(section, option, special_nones) -def set_config(section: str, option: str, value: Optional[str]): +def set_config(section: str, option: str, value: Optional[str]) -> None: """ Sets the value of a configuration option. This method should only be called by the simulator or by unit tests. - :param str section: What section to set the option in. - :param str option: What option to set. - :param object value: Value to set option to + :param section: What section to set the option in. + :param option: What option to set. + :param value: Value to set option to :raises ConfigException: If called unexpectedly """ if __config is None: @@ -309,9 +299,8 @@ def has_config_option(section: str, option: str) -> bool: """ Check if the section has this configuration option. - :param str section: What section to check - :param str option: What option to check. - :rtype: bool + :param section: What section to check + :param option: What option to check. :return: True if and only if the option is defined. It may be `None` """ if __config is None: @@ -324,7 +313,7 @@ def config_options(section: str) -> List[str]: """ Return a list of option names for the given section name. - :param str section: What section to list options for. + :param section: What section to list options for. 
""" if __config is None: raise ConfigException("configuration not loaded") @@ -335,16 +324,16 @@ def config_options(section: str) -> List[str]: # Union[Callable[[str, str], Any], # Callable[[str, str, Optional[List[str]]], Any]] def _check_lines(py_path: str, line: str, lines: List[str], index: int, - method: Callable, used_cfgs: Dict[str, Set[str]], start, - special_nones: Optional[List[str]] = None): + method: Callable, used_cfgs: Dict[str, Set[str]], start: str, + special_nones: Optional[List[str]] = None) -> None: """ Support for `_check_python_file`. Gets section and option name. - :param str line: Line with get_config call - :param list(str) lines: All lines in the file - :param int index: index of line with `get_config` call + :param line: Line with get_config call + :param lines: All lines in the file + :param index: index of line with `get_config` call :param method: Method to call to check cfg - :param dict(str), set(str) used_cfgs: + :param used_cfgs: Dict of used cfg options to be added to :param special_nones: What special values to except as None :raises ConfigException: If an unexpected or uncovered `get_config` found @@ -382,11 +371,11 @@ def _check_lines(py_path: str, line: str, lines: List[str], index: int, def _check_python_file(py_path: str, used_cfgs: Dict[str, Set[str]], - special_nones: Optional[List[str]] = None): + special_nones: Optional[List[str]] = None) -> None: """ A testing function to check that all the `get_config` calls work. 
- :param str py_path: path to file to be checked + :param py_path: path to file to be checked :param used_cfgs: dict of cfg options found :param special_nones: What special values to except as None :raises ConfigException: If an unexpected or uncovered `get_config` found @@ -426,13 +415,12 @@ def _check_python_file(py_path: str, used_cfgs: Dict[str, Set[str]], get_config_str_list, used_cfgs, "get_config") -def _find_double_defaults(repeaters: Optional[Collection[str]] = ()): +def _find_double_defaults(repeaters: Optional[Collection[str]] = ()) -> None: """ Testing function to identify any configuration options in multiple default files. :param repeaters: List of options that are expected to be repeated. - :type repeaters: list(str) :raises ConfigException: If two defaults configuration files set the same value """ @@ -454,12 +442,12 @@ def _find_double_defaults(repeaters: Optional[Collection[str]] = ()): f"repeats [{section}]{option}") -def _check_cfg_file(config1: CamelCaseConfigParser, cfg_path: str): +def _check_cfg_file(config1: CamelCaseConfigParser, cfg_path: str) -> None: """ Support method for :py:func:`check_cfgs`. - :param CamelCaseConfigParser config1: - :param str cfg_path: + :param config1: + :param cfg_path: :raises ConfigException: If an unexpected option is found """ config2 = CamelCaseConfigParser() @@ -475,7 +463,7 @@ def _check_cfg_file(config1: CamelCaseConfigParser, cfg_path: str): f"has unexpected options [{section}]{option}") -def _check_cfgs(path: str): +def _check_cfgs(path: str) -> None: """ A testing function check local configuration files against the defaults. @@ -504,7 +492,7 @@ def run_config_checks(directories: Union[str, Collection[str]], *, exceptions: Union[str, Collection[str]] = (), repeaters: Optional[Collection[str]] = (), check_all_used: bool = True, - special_nones: Optional[List[str]] = None): + special_nones: Optional[List[str]] = None) -> None: """ Master test. 
diff --git a/spinn_utilities/configs/camel_case_config_parser.py b/spinn_utilities/configs/camel_case_config_parser.py index 47268eef..7e4f02a8 100644 --- a/spinn_utilities/configs/camel_case_config_parser.py +++ b/spinn_utilities/configs/camel_case_config_parser.py @@ -11,15 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +from collections.abc import Iterable import configparser -from typing import List, Optional +from typing import List, Optional, TYPE_CHECKING, Union NONES = ("none", ) TRUES = ('y', 'yes', 't', 'true', 'on', '1') FALSES = ('n', 'no', 'f', 'false', 'off', '0') +# Type support +if TYPE_CHECKING: + _Path = Union[Union[str, bytes, os.PathLike], + Iterable[Union[str, bytes, os.PathLike]]] +else: + # Python 3.8 does not support above typing + _Path = str + class CamelCaseConfigParser(configparser.RawConfigParser): """ @@ -35,11 +45,12 @@ def optionxform(self, optionstr: str) -> str: lower = optionstr.lower() return lower.replace("_", "") - def __init__(self): + def __init__(self) -> None: super().__init__() - self._read_files = list() + self._read_files: List[str] = list() - def read(self, filenames, encoding=None): + def read(self, filenames: _Path, + encoding: Optional[str] = None) -> List[str]: """ Read and parse a filename or a list of filenames. """ @@ -48,7 +59,7 @@ def read(self, filenames, encoding=None): return new_files @property - def read_files(self): + def read_files(self) -> List[str]: """ The configuration files that have been actually read. """ @@ -58,10 +69,9 @@ def get_str(self, section: str, option: str) -> Optional[str]: """ Get the string value of an option. - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. 
:return: The option value - :rtype: str or None """ value = self.get(section, option) if value.lower() in NONES: @@ -73,11 +83,10 @@ def get_str_list( """ Get the string value of an option split into a list. - :param str section: What section to get the option from. - :param str option: What option to read. + :param section: What section to get the option from. + :param option: What option to read. :param token: The token to split the string into a list :return: The list (possibly empty) of the option values - :rtype: list(str) """ value = self.get(section, option) if value.lower() in NONES: diff --git a/spinn_utilities/data/data_status.py b/spinn_utilities/data/data_status.py index 8c1929c7..190d3896 100644 --- a/spinn_utilities/data/data_status.py +++ b/spinn_utilities/data/data_status.py @@ -13,7 +13,7 @@ # limitations under the License. from enum import Enum -from typing import Type +from typing import Type, Tuple from spinn_utilities.exceptions import ( DataNotMocked, DataNotYetAvialable, NotSetupException, ShutdownException, SpiNNUtilsException) @@ -34,20 +34,19 @@ class DataStatus(Enum): #: The system has been shut down. SHUTDOWN = (3, ShutdownException) - def __new__(cls, *args) -> 'DataStatus': + def __new__(cls, *args: Tuple[int, SpiNNUtilsException]) -> 'DataStatus': obj = object.__new__(cls) obj._value_ = args[0] return obj - def __init__(self, value, exception: Type[SpiNNUtilsException]): + def __init__(self, value: int, exception: Type[SpiNNUtilsException]): # pylint: disable=unused-argument self._exception = exception - def exception(self, data) -> SpiNNUtilsException: + def exception(self, data: str) -> SpiNNUtilsException: """ Returns an instance of the most suitable data-not-available exception. :param data: Parameter to pass to the relevant constructor. 
- :rtype: ~spinn_utilities.exceptions.SpiNNUtilsException - """ + """ return self._exception(data) diff --git a/spinn_utilities/data/utils_data_view.py b/spinn_utilities/data/utils_data_view.py index 0791e013..2ed3c71b 100644 --- a/spinn_utilities/data/utils_data_view.py +++ b/spinn_utilities/data/utils_data_view.py @@ -88,7 +88,7 @@ def _hard_reset(self) -> None: self._temporary_directory: Optional[TemporaryDirectory] = None self._soft_reset() - def _soft_reset(self): + def _soft_reset(self) -> None: """ Puts all data back into the state expected at `sim.reset` but not graph changed. @@ -180,8 +180,7 @@ def _exception(cls, data: str) -> SpiNNUtilsException: """ The most suitable no data Exception based on the status. - :param str data: Name of the data not found - :rtype: ~spinn_utilities.exceptions.SpiNNUtilsException + :param data: Name of the data not found """ return cls.__data._data_status.exception(data) @@ -191,8 +190,6 @@ def _exception(cls, data: str) -> SpiNNUtilsException: def _is_mocked(cls) -> bool: """ Checks if the view is in mocked state. - - :rtype: bool """ return cls.__data._data_status == DataStatus.MOCKED @@ -206,8 +203,6 @@ def is_hard_reset(cls) -> bool: During the first run after reset this continues to return True! Returns False after a reset that was considered soft. - - :rtype: bool """ return cls.__data._reset_status == ResetStatus.HARD_RESET @@ -220,8 +215,6 @@ def is_soft_reset(cls) -> bool: During the first run after reset this continues to return True! Returns False after a reset that was considered hard. - - :rtype: bool """ return cls.__data._reset_status == ResetStatus.SOFT_RESET @@ -230,7 +223,6 @@ def is_ran_ever(cls) -> bool: """ Check if the simulation has run at least once, ignoring resets. - :rtype: bool :raises NotImplementedError: If this is called from an unexpected state """ @@ -249,7 +241,6 @@ def is_ran_last(cls) -> bool: """ Checks if the simulation has run and not been reset. 
- :rtype: bool :raises NotImplementedError: If this is called from an unexpected state """ @@ -273,7 +264,6 @@ def is_reset_last(cls) -> bool: It also returns False after a `sim.stop` or `sim.end` call starts - :rtype: bool :raises NotImplementedError: If this is called from an unexpected state """ @@ -302,7 +292,6 @@ def is_no_stop_requested(cls) -> bool: """ Checks that a stop request has not been sent. - :rtype: bool :raises NotImplementedError: If this is called from an unexpected state """ @@ -321,7 +310,6 @@ def is_running(cls) -> bool: That is a call to run has started but not yet stopped. - :rtype: bool """ return cls.__data._run_status in [ RunStatus.IN_RUN, RunStatus.STOP_REQUESTED] @@ -377,7 +365,6 @@ def is_setup(cls) -> bool: """ Checks to see if there is already a simulator. - :rtype: bool :raises NotImplementedError: If this is called from an unexpected state """ @@ -399,7 +386,6 @@ def is_user_mode(cls) -> bool: This returns False in the Mocked state. - :rtype: bool :raises NotImplementedError: If the data has not yet been set up or on an unexpected run_status """ @@ -425,7 +411,6 @@ def is_stop_already_requested(cls) -> bool: :return: True if the stop has already been requested or if the system is stopping or has already stopped False if the stop request makes sense. - :rtype: bool :raises NotImplementedError: If this is called from an unexpected state :raises SpiNNUtilsException: @@ -448,8 +433,6 @@ def is_shutdown(cls) -> bool: Determines if simulator has already been shutdown. This returns False in the Mocked state - - :rtype: bool """ return cls.__data._run_status == RunStatus.SHUTDOWN @@ -480,7 +463,6 @@ def get_run_dir_path(cls) -> str: In unit test mode this returns a temporary directory shared by all path methods. 
- :rtype: str :raises ~spinn_utilities.exceptions.SpiNNUtilsException: If the run_dir_path is currently unavailable """ @@ -495,18 +477,17 @@ def get_executable_finder(cls) -> ExecutableFinder: """ The ExcutableFinder object created at time code is imported. - :rtype: ExcutableFinder """ return cls.__data._executable_finder @classmethod - def register_binary_search_path(cls, search_path: str): + def register_binary_search_path(cls, search_path: str) -> None: """ Register an additional binary search path for executables. Syntactic sugar for `get_executable_finder().add_path()` - :param str search_path: absolute search path for binaries + :param search_path: absolute search path for binaries """ cls.__data._executable_finder.add_path(search_path) @@ -518,9 +499,8 @@ def get_executable_path(cls, executable_name: str) -> str: Syntactic sugar for `get_executable_finder().get_executable_path()` - :param str executable_name: The name of the executable to find + :param executable_name: The name of the executable to find :return: The full path of the discovered executable - :rtype: str :raises KeyError: If no executable was found in the set of folders """ return cls.__data._executable_finder.get_executable_path( @@ -539,12 +519,11 @@ def get_executable_paths(cls, executable_names: str) -> List[str]: Syntactic sugar for `get_executable_finder().get_executable_paths()` - :param str executable_names: The name of the executable to find. + :param executable_names: The name of the executable to find. Assumed to be comma separated. 
:return: The full path of the discovered executable, or ``None`` if no executable was found in the set of folders - :rtype: list(str) """ return cls.__data._executable_finder.get_executable_paths( executable_names) @@ -559,7 +538,6 @@ def get_requires_data_generation(cls) -> bool: Remains True during the first run after a data change Only set to False at the *end* of the first run - :rtype: bool """ return cls.__data._requires_data_generation @@ -582,8 +560,6 @@ def get_requires_mapping(cls) -> bool: any mapping stage to be called Remains True during the first run after a requires mapping. Only set to False at the *end* of the first run - - :rtype: bool """ return cls.__data._requires_mapping @@ -621,7 +597,7 @@ def raise_skiptest(cls, reason: str, """ Sets the status as shutdown and raises a SkipTest - :param str reason: Message for the Skip + :param reason: Message for the Skip :param parent: Exception which triggered the skip if any :type parent: Exception or None :raises: SkipTest very time called diff --git a/spinn_utilities/data/utils_data_writer.py b/spinn_utilities/data/utils_data_writer.py index 0966cd6a..79a71370 100644 --- a/spinn_utilities/data/utils_data_writer.py +++ b/spinn_utilities/data/utils_data_writer.py @@ -63,8 +63,7 @@ class UtilsDataWriter(UtilsDataView): def __init__(self, state: DataStatus): """ - :param ~spinn_utilities.data.DataStatus state: - State writer should be in + :param state: State writer should be in """ if state == DataStatus.MOCKED: self._mock() @@ -92,7 +91,6 @@ def mock(cls) -> Self: then set that value. :return: A Data Writer - :rtype: UtilsDataWriter """ return cls(DataStatus.MOCKED) @@ -104,7 +102,6 @@ def setup(cls) -> Self: All previous data will be cleared :return: A Data Writer - :rtype: UtilsDataWriter """ return cls(DataStatus.SETUP) @@ -259,7 +256,6 @@ def get_report_dir_path(self) -> str: As it is only accessed to create `timestamp` directories and remove old reports, this is not a view method. 
- :rtype: str :raises SpiNNUtilsException: If the `simulation_time_step` is currently unavailable """ @@ -267,11 +263,11 @@ def get_report_dir_path(self) -> str: return self.__data._report_dir_path raise self._exception("report_dir_path") - def set_run_dir_path(self, run_dir_path: str): + def set_run_dir_path(self, run_dir_path: str) -> None: """ Checks and sets the `run_dir_path`. - :param str run_dir_path: + :param run_dir_path: :raises InvalidDirectory: if the `run_dir_path` is not a directory """ if os.path.isdir(run_dir_path): @@ -280,11 +276,11 @@ def set_run_dir_path(self, run_dir_path: str): self.__data._run_dir_path = None raise InvalidDirectory("run_dir_path", run_dir_path) - def set_report_dir_path(self, reports_dir_path: str): + def set_report_dir_path(self, reports_dir_path: str) -> None: """ Checks and sets the `reports_dir_path`. - :param str reports_dir_path: + :param reports_dir_path: :raises InvalidDirectory: if the `reports_dir_path` is not a directory """ if os.path.isdir(reports_dir_path): @@ -293,7 +289,8 @@ def set_report_dir_path(self, reports_dir_path: str): self.__data._report_dir_path = None raise InvalidDirectory("run_dir_path", reports_dir_path) - def _set_executable_finder(self, executable_finder: ExecutableFinder): + def _set_executable_finder( + self, executable_finder: ExecutableFinder) -> None: """ Only usable by unit tests! diff --git a/spinn_utilities/exceptions.py b/spinn_utilities/exceptions.py index bbdd716d..c9ebbe9c 100644 --- a/spinn_utilities/exceptions.py +++ b/spinn_utilities/exceptions.py @@ -30,7 +30,7 @@ class NotSetupException(SpiNNUtilsException): Raised when trying to get data before simulator has been setup. """ - def __init__(self, data): + def __init__(self, data: str) -> None: super().__init__(f"Requesting {data} is not valid before setup") @@ -38,7 +38,7 @@ class InvalidDirectory(SpiNNUtilsException): """ Raised when trying to set an invalid directory. 
""" - def __init__(self, name, value): + def __init__(self, name: str, value: str) -> None: super().__init__(f"Unable to set {name} has {value} is not a dir.") @@ -46,7 +46,7 @@ class DataNotYetAvialable(SpiNNUtilsException): """ Raised when trying to get data before simulator has created it. """ - def __init__(self, data): + def __init__(self, data: str) -> None: super().__init__(f"{data} has not yet been created.") @@ -54,7 +54,7 @@ class DataNotMocked(DataNotYetAvialable): """ Raised when trying to get data before a mocked simulator has created it. """ - def __init__(self, data): + def __init__(self, data: str) -> None: super().__init__(f"MOCK {data}") @@ -62,7 +62,7 @@ class ShutdownException(SpiNNUtilsException): """ Raised when trying to get simulator data after it has been shut down. """ - def __init__(self, data): + def __init__(self, data: str) -> None: super().__init__(f"Requesting {data} is not valid after end") diff --git a/spinn_utilities/helpful_functions.py b/spinn_utilities/helpful_functions.py index 8a59515d..d95e7276 100644 --- a/spinn_utilities/helpful_functions.py +++ b/spinn_utilities/helpful_functions.py @@ -53,7 +53,7 @@ def lcm(number: int, /, *numbers: int) -> int: ... -def lcm(*numbers) -> int: +def lcm(*numbers) -> int: # type: ignore[no-untyped-def] """ Lowest common multiple of 0, 1 or more integers. @@ -90,7 +90,7 @@ def gcd(number: int, /, *numbers: int) -> int: ... -def gcd(*numbers) -> int: +def gcd(*numbers) -> int: # type: ignore[no-untyped-def] """ Greatest common divisor of 1 or more integers. diff --git a/spinn_utilities/log.py b/spinn_utilities/log.py index 92edac7c..edba1504 100644 --- a/spinn_utilities/log.py +++ b/spinn_utilities/log.py @@ -13,12 +13,14 @@ # limitations under the License. 
import atexit +import configparser from datetime import datetime import logging import re import sys -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, KeysView, List, Mapping, Optional, Tuple from inspect import getfullargspec +from spinn_utilities.configs import CamelCaseConfigParser from .log_store import LogStore from .overrides import overrides @@ -38,13 +40,13 @@ class ConfiguredFilter(object): __slots__ = [ "_default_level", "_levels"] - def __init__(self, conf): + def __init__(self, conf: configparser.RawConfigParser): self._levels = ConfiguredFormatter.construct_logging_parents(conf) self._default_level = logging.INFO if conf.has_option("Logging", "default"): self._default_level = _LEVELS[conf.get("Logging", "default")] - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: """ Get the level for the deepest parent, and filter appropriately. """ @@ -64,7 +66,7 @@ class ConfiguredFormatter(logging.Formatter): # Precompile this RE; it gets used quite a few times __last_component = re.compile(r'\.[^.]+$') - def __init__(self, conf): + def __init__(self, conf: CamelCaseConfigParser) -> None: if (conf.has_option("Logging", "default") and conf.get("Logging", "default") == "debug"): fmt = "%(asctime)-15s %(levelname)s: %(pathname)s: %(message)s" @@ -73,12 +75,13 @@ def __init__(self, conf): super().__init__(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S") @staticmethod - def construct_logging_parents(conf): + def construct_logging_parents( + conf: configparser.RawConfigParser) -> Dict[str, int]: """ Create a dictionary of module names and logging levels. """ # Construct the dictionary - _levels = {} + _levels: Dict[str, int] = {} if not conf.has_section("Logging"): return _levels @@ -92,7 +95,7 @@ def construct_logging_parents(conf): return _levels @staticmethod - def deepest_parent(parents, child): + def deepest_parent(parents: KeysView[str], child: str) -> Optional[str]: """ Greediest match between child and parent. 
""" @@ -111,7 +114,8 @@ def deepest_parent(parents, child): return match @staticmethod - def level_of_deepest_parent(parents, child): + def level_of_deepest_parent( + parents: Dict[str, int], child: str) -> Optional[int]: """ The logging level of the greediest match between child and parent. """ @@ -129,14 +133,16 @@ class _BraceMessage(object): A message that converts a Python format string to a string. """ __slots__ = [ + "args", "fmt", "kwargs"] - def __init__(self, fmt, args, kwargs): + def __init__(self, fmt: object, + args: Tuple[object, ...], kwargs: Dict[str, object]) -> None: self.fmt = fmt self.args = args self.kwargs = kwargs - def __str__(self): + def __str__(self) -> str: try: return str(self.fmt).format(*self.args, **self.kwargs) except KeyError: @@ -182,7 +188,7 @@ class FormatAdapter(logging.LoggerAdapter): __log_store: Optional[LogStore] = None @classmethod - def set_kill_level(cls, level: Optional[int] = None): + def set_kill_level(cls, level: Optional[int] = None) -> None: """ Allow system to change the level at which a log is changed to an Exception. 
@@ -199,7 +205,7 @@ def set_kill_level(cls, level: Optional[int] = None): cls.__kill_level = level @classmethod - def set_log_store(cls, log_store: Optional[LogStore]): + def set_log_store(cls, log_store: Optional[LogStore]) -> None: """ Sets a Object to write the log messages to :param LogStore log_store: @@ -211,14 +217,17 @@ def set_log_store(cls, log_store: Optional[LogStore]): for timestamp, level, message in cls._pop_not_stored_messages(): cls.__log_store.store_log(level, message, timestamp) - def __init__(self, logger: logging.Logger, extra=None): + def __init__( + self, logger: logging.Logger, + extra: Optional[Mapping[str, object]] = None) -> None: if extra is None: extra = {} super().__init__(logger, extra) self.do_log = logger._log # pylint: disable=protected-access @overrides(logging.LoggerAdapter.log, extend_doc=False, adds_typing=True) - def log(self, level: int, msg: object, *args, **kwargs): + def log(self, level: int, msg: object, + *args: object, **kwargs: object) -> None: """ Delegate a log call to the underlying logger, applying appropriate transformations to allow the log message to be written using @@ -297,7 +306,8 @@ def atexit_handler(cls) -> None: print(message, file=sys.stderr) @classmethod - def _pop_not_stored_messages(cls, min_level=0): + def _pop_not_stored_messages( + cls, min_level: int = 0) -> List[Tuple[datetime, int, str]]: """ Returns the log of messages to print on exit and *clears that log*. @@ -305,7 +315,7 @@ def _pop_not_stored_messages(cls, min_level=0): .. note:: Should only be called externally from test code! 
""" - result = [] + result: List[Tuple[datetime, int, str]] = [] try: for timestamp, level, message in cls.__not_stored_messages: if level >= min_level: diff --git a/spinn_utilities/log_store.py b/spinn_utilities/log_store.py index 1d78e6ea..2ea7d056 100644 --- a/spinn_utilities/log_store.py +++ b/spinn_utilities/log_store.py @@ -24,7 +24,7 @@ class LogStore(object): @abstractmethod def store_log(self, level: int, message: str, - timestamp: Optional[datetime] = None): + timestamp: Optional[datetime] = None) -> None: """ Writes the log message for later retrieval. diff --git a/spinn_utilities/logger_utils.py b/spinn_utilities/logger_utils.py index aa99b595..e266ca40 100644 --- a/spinn_utilities/logger_utils.py +++ b/spinn_utilities/logger_utils.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from spinn_utilities.log import FormatAdapter + _already_issued = set() -def warn_once(logger, msg): +def warn_once(logger: FormatAdapter, msg: str) -> None: """ Write a warning message to the given logger where that message should only be written to the logger once. @@ -32,7 +34,7 @@ def warn_once(logger, msg): logger.warning(msg) -def error_once(logger, msg): +def error_once(logger: FormatAdapter, msg: str) -> None: """ Write an error message to the given logger where that message should only be written to the logger once. @@ -49,7 +51,7 @@ def error_once(logger, msg): logger.error(msg) -def reset(): +def reset() -> None: """ Clear the store of what messages have already been seen. 
diff --git a/spinn_utilities/make_tools/converter.py b/spinn_utilities/make_tools/converter.py index 018a0216..1ba76a0a 100644 --- a/spinn_utilities/make_tools/converter.py +++ b/spinn_utilities/make_tools/converter.py @@ -14,6 +14,7 @@ import os import sys +from typing import Optional from .file_converter import FileConverter from .log_sqllite_database import LogSqlLiteDatabase @@ -24,13 +25,13 @@ "neural_build.mk", "Makefile.neural_build"]) -def convert(src, dest, new_dict): +def convert(src: str, dest: str, new_dict: bool) -> None: """ Converts a whole directory including sub-directories. - :param str src: Full source directory - :param str dest: Full destination directory - :param bool new_dict: + :param src: Full source directory + :param dest: Full destination directory + :param new_dict: Whether we should generate a new dictionary/DB. If not, we add to the existing one. """ @@ -45,13 +46,14 @@ def convert(src, dest, new_dict): _convert_dir(src_path, dest_path) -def _convert_dir(src_path, dest_path, make_directories=False): +def _convert_dir(src_path: str, dest_path: str, + make_directories: Optional[bool] = False) -> None: """ Converts a whole directory including sub directories. 
- :param str src_path: Full source directory - :param str dest_path: Full destination directory - :param bool make_directories: Whether to do `mkdir()` first + :param src_path: Full source directory + :param dest_path: Full destination directory + :param make_directories: Whether to do `mkdir()` first """ if make_directories: _mkdir(dest_path) @@ -70,7 +72,7 @@ def _convert_dir(src_path, dest_path, make_directories=False): print(f"Unexpected file {source}") -def _mkdir(destination): +def _mkdir(destination: str) -> None: if not os.path.exists(destination): os.mkdir(destination) if not os.path.exists(destination): diff --git a/spinn_utilities/make_tools/file_converter.py b/spinn_utilities/make_tools/file_converter.py index 9a9b33cb..e313da1b 100644 --- a/spinn_utilities/make_tools/file_converter.py +++ b/spinn_utilities/make_tools/file_converter.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import enum +from io import TextIOBase import os import re +from typing import List, Optional from spinn_utilities.exceptions import UnexpectedCException from .log_sqllite_database import LogSqlLiteDatabase @@ -66,15 +68,15 @@ class FileConverter(object): "_too_many_lines" ] - def __call__(self, src, dest, log_file_id, log_database): + def __call__(self, src: str, dest: str, log_file_id: int, + log_database: LogSqlLiteDatabase) -> None: """ Creates the file_convertor to convert one file. - :param str src: Absolute path to source file - :param str dest: Absolute path to destination file - :param int log_file_id: - Id in the database for this file - :param LogSqlLiteDatabase log_database: + :param src: Absolute path to source file + :param dest: Absolute path to destination file + :param log_file_id: Id in the database for this file + :param log_database: The database which handles the mapping of id to log messages. 
""" #: Absolute path to source file @@ -92,7 +94,7 @@ def __call__(self, src, dest, log_file_id, log_database): #: Current status of state machine #: #: :type: State - self._status = None + self._status: Optional[State] = None #: Number of extra lines written to modified not yet recovered #: Extra lines are caused by the header and possibly log comment #: Extra lines are recovered by omitting blank lines @@ -110,7 +112,7 @@ def __call__(self, src, dest, log_file_id, log_database): #: The previous state #: #: :type: State - self._previous_status = None + self._previous_status: Optional[State] = None with open(src, encoding="utf-8") as src_f: with open(dest, 'w', encoding="utf-8") as dest_f: @@ -132,7 +134,7 @@ def __call__(self, src, dest, log_file_id, log_database): self._process_chars(dest_f, line_num, text) self._check_end_status() - def _check_end_status(self): + def _check_end_status(self) -> None: if self._status == State.NORMAL_CODE: return if self._status == State.IN_LOG: @@ -147,15 +149,15 @@ def _check_end_status(self): f"Unclosed block comment in {self._src}") raise NotImplementedError(f"Unexpected status {self._status}") - def _process_line(self, dest_f, line_num, text): + def _process_line( + self, dest_f: TextIOBase, line_num: int, text: str) -> bool: """ Process a single line. 
:param dest_f: Open file like Object to write modified source to - :param int line_num: Line number in the source c file - :param str text: Text of that line including whitespace + :param line_num: Line number in the source c file + :param text: Text of that line including whitespace :return: True if and only if the whole line was processed - :rtype: bool """ if self._status == State.COMMENT: return self._process_line_in_comment(dest_f, text) @@ -172,18 +174,18 @@ def _process_line(self, dest_f, line_num, text): assert self._status == State.NORMAL_CODE return self._process_line_normal_code(dest_f, line_num, text) - def _process_line_in_comment(self, dest_f, text): + def _process_line_in_comment(self, dest_f: TextIOBase, text: str) -> bool: """ Process a single line when in a multi-line comment: ``/* .. */`` :param dest_f: Open file like Object to write modified source to - :param str text: Text of that line including whitespace + :param text: Text of that line including whitespace :return: True if and only if the whole line was processed - :rtype: bool """ if "*/" in text: stripped = text.strip() match = END_COMMENT_REGEX.search(stripped) + assert match is not None if match.end(0) == len(stripped): # OK Comment until end of line dest_f.write(text) @@ -194,17 +196,17 @@ def _process_line_in_comment(self, dest_f, text): dest_f.write(text) return True - def _process_line_comment_start(self, dest_f, line_num, text): + def _process_line_comment_start( + self, dest_f: TextIOBase, line_num: int, text: str) -> bool: """ Processes a line known assumed to contain a ``/*`` but not know where. There is also the assumption that the start status is not ``COMMENT``. 
:param dest_f: Open file like Object to write modified source to - :param int line_num: Line number in the source c file - :param str text: Text of that line including whitespace + :param line_num: Line number in the source c file + :param text: Text of that line including whitespace :return: True if and only if the whole line was processed - :rtype: bool """ stripped = text.strip() if stripped.startswith("/*"): @@ -215,15 +217,15 @@ def _process_line_comment_start(self, dest_f, line_num, text): # Stuff before comment so check by char return False # More than one possible end so check by char - def _process_line_in_log(self, dest_f, line_num, text): + def _process_line_in_log( + self, dest_f: TextIOBase, line_num: int, text: str) -> bool: """ Process a line when the status is a log call has been started. :param dest_f: Open file like Object to write modified source to - :param int line_num: Line number in the source c file - :param str text: Text of that line including whitespace + :param line_num: Line number in the source c file + :param text: Text of that line including whitespace :return: True if and only if the whole line was processed - :rtype: bool """ stripped = text.strip() if stripped.startswith("//"): @@ -249,15 +251,15 @@ def _process_line_in_log(self, dest_f, line_num, text): self._status = State.NORMAL_CODE return True - def _process_line_in_log_close_bracket(self, dest_f, line_num, text): + def _process_line_in_log_close_bracket( + self, dest_f: TextIOBase, line_num: int, text: str) -> bool: """ Process where the last log line has the ``)`` but not the ``;`` :param dest_f: Open file like Object to write modified source to - :param int line_num: Line number in the source c file - :param str text: Text of that line including whitespace + :param line_num: Line number in the source c file + :param text: Text of that line including whitespace :return: True if and only if the whole line was processed - :rtype: bool """ stripped = text.strip() if 
len(stripped) == 0: @@ -283,15 +285,15 @@ def _process_line_in_log_close_bracket(self, dest_f, line_num, text): self._status = State.IN_LOG return self._process_line_in_log(dest_f, line_num, text) - def _process_line_normal_code(self, dest_f, line_num, text): + def _process_line_normal_code( + self, dest_f: TextIOBase, line_num: int, text: str) -> bool: """ Process a line where the status is normal code. :param dest_f: Open file like Object to write modified source to - :param int line_num: Line number in the source c file - :param str text: Text of that line including whitespace + :param line_num: Line number in the source c file + :param text: Text of that line including whitespace :return: True if and only if the whole line was processed - :rtype: bool """ stripped = text.strip() match = LOG_START_REGEX.search(stripped) @@ -323,31 +325,28 @@ def _process_line_normal_code(self, dest_f, line_num, text): # Now check for the end of log command return self._process_line_in_log(dest_f, line_num, text[start_len:]) - def quote_part(self, text): + def quote_part(self, text: str) -> int: """ Net count of double quotes in line. - :param str text: - :rtype: int + :param text: """ return (text.count('"') - text.count('\\"')) % 2 > 0 - def bracket_count(self, text): + def bracket_count(self, text: str) -> int: """ Net count of open brackets in line. - :param str text: - :rtype: int + :param text: """ return (text.count('(') - text.count(')')) - def split_by_comma_plus(self, main, line_num): + def split_by_comma_plus(self, main: str, line_num: int) -> List[str]: """ Split line by comma and partially parse. 
- :param str main: - :param int line_num: - :rtype: list(str) + :param main: + :param line_num: :raises UnexpectedCException: """ try: @@ -393,16 +392,16 @@ def split_by_comma_plus(self, main, line_num): raise UnexpectedCException(f"Unexpected line {self._log_full} " f"at {line_num} in {self._src}") from e - def _short_log(self, line_num): + def _short_log(self, line_num: int) -> str: """ Shortens the log string message and adds the ID. - :param int line_num: Current line number + :param line_num: Current line number :return: shorten form - :rtype: str """ try: full_match = LOG_END_REGEX.search(self._log_full) + assert full_match is not None main = self._log_full[:-len(full_match.group(0))] except Exception as e: raise UnexpectedCException( @@ -449,7 +448,8 @@ def _short_log(self, line_num): back += ");" return front + back - def _write_log_method(self, dest_f, line_num, tail=""): + def _write_log_method( + self, dest_f: TextIOBase, line_num: int, tail: str = "") -> None: """ Writes the log message and the dict value. @@ -460,8 +460,8 @@ def _write_log_method(self, dest_f, line_num, tail=""): - Old log message with full text added as comment :param dest_f: Open file like Object to write modified source to - :param int line_num: Line number in the source C file - :param str text: Text of that line including whitespace + :param line_num: Line number in the source C file + :param tail: Extra text to write after the log call """ self._log_full = self._log_full.replace('""', '') short_log = self._short_log(line_num) @@ -489,13 +489,14 @@ def _write_log_method(self, dest_f, line_num, tail=""): dest_f.write("*/") dest_f.write(end * (self._log_lines - 1)) - def _process_chars(self, dest_f, line_num, text): + def _process_chars( + self, dest_f: TextIOBase, line_num: int, text: str) -> None: """ Deals with complex lines that can not be handled in one go. 
:param dest_f: Open file like Object to write modified source to - :param int line_num: Line number in the source c file - :param str text: Text of that line including whitespace + :param line_num: Line number in the source c file + :param text: Text of that line including whitespace :raises UnexpectedCException: """ position = 0 @@ -619,13 +620,13 @@ def _process_chars(self, dest_f, line_num, text): dest_f.write(text[write_flag:]) @staticmethod - def convert(src_dir, dest_dir, file_name): + def convert(src_dir: str, dest_dir: str, file_name: str) -> None: """ Static method to create Object and do the conversion. - :param str src_dir: Source directory - :param str dest_dir: Destination directory - :param str file_name: + :param src_dir: Source directory + :param dest_dir: Destination directory + :param file_name: The name of the file to convert within the source directory; it will be made with the same name in the destination directory. """ diff --git a/spinn_utilities/make_tools/log_sqllite_database.py b/spinn_utilities/make_tools/log_sqllite_database.py index 2344178e..96b06fb6 100644 --- a/spinn_utilities/make_tools/log_sqllite_database.py +++ b/spinn_utilities/make_tools/log_sqllite_database.py @@ -24,7 +24,7 @@ DB_FILE_NAME = "logs.sqlite3" -def _timestamp(): +def _timestamp() -> int: return int(time.time() * _SECONDS_TO_MICRO_SECONDS_CONVERSION) @@ -46,7 +46,7 @@ class LogSqlLiteDatabase(AbstractContextManager): "_db", ] - def __init__(self, new_dict=False): + def __init__(self, new_dict: bool = False) -> None: """ Connects to a log dict. The location of the file can be overridden using the ``C_LOGS_DICT`` environment variable. @@ -89,7 +89,6 @@ def _database_file(self) -> str: otherwise the default path in this directory is used. 
:return: Absolute path to where the database file is or will be - :rtype: str """ if 'C_LOGS_DICT' in os.environ: return str(os.environ['C_LOGS_DICT']) @@ -104,7 +103,6 @@ def _extra_database_error_message(self) -> str: Adds a possible extra part to the error message. :return: A likely empty string - :rtype: str """ return "" @@ -112,7 +110,7 @@ def _check_database_file(self, database_file: str) -> None: """ Checks the database file exists: - :param str database_file: Absolute path to the database file + :param database_file: Absolute path to the database file :raises FileNotFoundErrorL If the file does not exists """ if os.path.exists(database_file): @@ -125,10 +123,10 @@ def _check_database_file(self, database_file: str) -> None: message += "Please rebuild the C code." raise FileNotFoundError(message) - def __del__(self): + def __del__(self) -> None: self.close() - def close(self): + def close(self) -> None: """ Finalises and closes the database. """ @@ -139,10 +137,11 @@ def close(self): pass self._db = None - def __init_db(self): + def __init_db(self) -> None: """ Set up the database if required. """ + assert self._db is not None self._db.row_factory = sqlite3.Row # Don't use memoryview / buffer as hard to deal with difference self._db.text_factory = str @@ -150,7 +149,8 @@ def __init_db(self): sql = f.read() self._db.executescript(sql) - def __clear_db(self): + def __clear_db(self) -> None: + assert self._db is not None with self._db: cursor = self._db.cursor() cursor.execute("DELETE FROM log") @@ -166,10 +166,10 @@ def get_directory_id(self, src_path: str, dest_path: str) -> int: """ gets the Ids for this directory. 
Making a new one if needed - :param str src_path: - :param str dest_path: - :rtype: int + :param src_path: + :param dest_path: """ + assert self._db is not None with self._db: cursor = self._db.cursor() # reuse the existing if it exists @@ -188,16 +188,18 @@ def get_directory_id(self, src_path: str, dest_path: str) -> int: INSERT INTO directory(src_path, dest_path) VALUES(?, ?) """, (src_path, dest_path)) - return cursor.lastrowid + directory_id = cursor.lastrowid + assert directory_id is not None + return directory_id def get_file_id(self, directory_id: int, file_name: str) -> int: """ Gets the id for this file, making a new one if needed. - :param int directory_id: - :param str file_name: - :rtype: int + :param directory_id: + :param file_name: """ + assert self._db is not None with self._db: # Make previous one as not last with self._db: @@ -214,18 +216,21 @@ def get_file_id(self, directory_id: int, file_name: str) -> int: directory_id, file_name, convert_time, last_build) VALUES(?, ?, ?, 1) """, (directory_id, file_name, _timestamp())) - return cursor.lastrowid + file_id = cursor.lastrowid + assert file_id is not None + return file_id - def set_log_info( - self, log_level: int, line_num: int, original: str, file_id: int): + def set_log_info(self, log_level: int, line_num: int, + original: str, file_id: int) -> int: """ Saves the data needed to replace a short log back to the original. - :param int log_level: - :param int line_num: - :param str original: - :param int file_id: + :param log_level: + :param line_num: + :param original: + :param file_id: """ + assert self._db is not None with self._db: cursor = self._db.cursor() # reuse the existing number if nothing has changed @@ -243,7 +248,9 @@ def set_log_info( INSERT INTO log(log_level, line_num, original, file_id) VALUES(?, ?, ?, ?) 
""", (log_level, line_num, original, file_id)) - return cursor.lastrowid + log_id = cursor.lastrowid + assert log_id is not None + return log_id else: for row in self._db.execute( """ @@ -260,9 +267,9 @@ def get_log_info(self, log_id: str) -> Optional[Tuple[int, str, int, str]]: """ Gets the data needed to replace a short log back to the original. - :param str log_id: The int id as a String - :rtype: tuple(int, str, int, str) + :param log_id: The int id as a String """ + assert self._db is not None with self._db: for row in self._db.execute( """ @@ -275,15 +282,16 @@ def get_log_info(self, log_id: str) -> Optional[Tuple[int, str, int, str]]: row["original"]) return None - def check_original(self, original: str): + def check_original(self, original: str) -> None: """ Checks that an original log line has been added to the database. Mainly used for testing - :param str original: + :param original: :raises ValueError: If the original is not in the database """ + assert self._db is not None with self._db: for row in self._db.execute( """ @@ -294,12 +302,11 @@ def check_original(self, original: str): if row["counts"] == 0: raise ValueError(f"{original} not found in database") - def get_max_log_id(self): + def get_max_log_id(self) -> Optional[int]: """ Get the max id of any log message. 
- - :rtype: int """ + assert self._db is not None with self._db: for row in self._db.execute( """ diff --git a/spinn_utilities/make_tools/replacer.py b/spinn_utilities/make_tools/replacer.py index 579672b7..43abab49 100644 --- a/spinn_utilities/make_tools/replacer.py +++ b/spinn_utilities/make_tools/replacer.py @@ -17,10 +17,15 @@ import shutil import struct import sys -from typing import Optional, Tuple +from types import TracebackType +from typing import Optional, Type, Tuple + +from typing_extensions import Literal, Self + from spinn_utilities.overrides import overrides from spinn_utilities.config_holder import get_config_str_or_none from spinn_utilities.log import FormatAdapter + from .file_converter import FORMAT_EXP from .file_converter import TOKEN from .log_sqllite_database import DB_FILE_NAME, LogSqlLiteDatabase @@ -61,11 +66,11 @@ def _extra_database_error_message(self) -> str: return (f"The cfg {extra__binaries=} " f"also does not contain a {DB_FILE_NAME}. ") - def __enter__(self): + def __enter__(self) -> Self: return self - def __exit__(self, exc_type, exc_value, exc_traceback): - # nothing yet + def __exit__(self, exc_type: Optional[Type], exc_val: Exception, + exc_tb: TracebackType) -> Literal[False]: return False _INT_FMT = struct.Struct("!I") @@ -124,11 +129,11 @@ def replace(self, short: str) -> str: (log_level, file_name, line_num, replaced) = data return f"{LEVELS[log_level]} ({file_name}: {line_num}): {replaced}" - def _hex_to_float(self, hex_str): + def _hex_to_float(self, hex_str: str) -> str: return self._FLT_FMT.unpack( self._INT_FMT.pack(int(hex_str, 16)))[0] - def _hexes_to_double(self, upper, lower): + def _hexes_to_double(self, upper: str, lower: str) -> str: return self._DBL_FMT.unpack( self._INT_FMT.pack(int(upper, 16)) + self._INT_FMT.pack(int(lower, 16)))[0] diff --git a/spinn_utilities/ordered_set.py b/spinn_utilities/ordered_set.py index 42a0c52f..3ce1afba 100644 --- a/spinn_utilities/ordered_set.py +++ 
b/spinn_utilities/ordered_set.py @@ -39,11 +39,11 @@ def __init__(self, iterable: Optional[Iterable[T]] = None): if iterable is not None: self.update(iterable) - def add(self, value: T): + def add(self, value: T) -> None: if value not in self._map: self._map[value] = None - def discard(self, value: T): + def discard(self, value: T) -> None: if value in self._map: self._map.pop(value) @@ -72,7 +72,7 @@ def __len__(self) -> int: def __contains__(self, key: Any) -> bool: return key in self._map - def update(self, iterable: Iterable[T]): + def update(self, iterable: Iterable[T]) -> None: """ Updates the set by adding each item in order diff --git a/spinn_utilities/overrides.py b/spinn_utilities/overrides.py index 33f86d7a..f46f1e6d 100644 --- a/spinn_utilities/overrides.py +++ b/spinn_utilities/overrides.py @@ -15,7 +15,7 @@ import inspect import os from types import FunctionType, MethodType -from typing import Any, Callable, Iterable, Optional, TypeVar +from typing import Any, Callable, Iterable, Optional, List, Tuple, TypeVar #: :meta private: Method = TypeVar("Method", bound=Callable[..., Any]) @@ -31,7 +31,7 @@ class overrides(object): """ # This near constant is changed by unit tests to check our code # Github actions sets TYPE_OVERRIDES as True - __CHECK_TYPES = os.getenv("TYPE_OVERRIDES") + __CHECK_TYPES: Optional[Any] = os.getenv("TYPE_OVERRIDES") __slots__ = [ # The method in the superclass that this method overrides @@ -51,7 +51,7 @@ class overrides(object): ] def __init__( - self, super_class_method, *, extend_doc: bool = True, + self, super_class_method: Callable, *, extend_doc: bool = True, additional_arguments: Optional[Iterable[str]] = None, extend_defaults: bool = False, adds_typing: bool = False,): """ @@ -86,7 +86,9 @@ def __init__( self._adds_typing = adds_typing @staticmethod - def __match_defaults(default_args, super_defaults, extend_ok): + def __match_defaults(default_args: Optional[List[Any]], + super_defaults: Optional[Tuple[Any]], + extend_ok: 
bool) -> bool: if default_args is None: return super_defaults is None elif super_defaults is None: @@ -95,7 +97,9 @@ def __match_defaults(default_args, super_defaults, extend_ok): return len(default_args) >= len(super_defaults) return len(default_args) == len(super_defaults) - def _verify_types(self, method_args, super_args, all_args): + def _verify_types(self, method_args: inspect.FullArgSpec, + super_args: inspect.FullArgSpec, + all_args: List[str]) -> None: """ Check that the arguments match. """ @@ -142,7 +146,7 @@ def _verify_types(self, method_args, super_args, all_args): f"Super Method {self._superclass_method.__name__} " f"has no return type, while this does") - def __verify_method_arguments(self, method: Method): + def __verify_method_arguments(self, method: Method) -> None: """ Check that the arguments match. """ @@ -203,7 +207,7 @@ def __call__(self, method: Method) -> Method: return method @classmethod - def check_types(cls): + def check_types(cls) -> None: """ If called will trigger check that all parameters are checked. diff --git a/spinn_utilities/package_loader.py b/spinn_utilities/package_loader.py index 9d6e6120..befa0483 100644 --- a/spinn_utilities/package_loader.py +++ b/spinn_utilities/package_loader.py @@ -15,10 +15,13 @@ import os import sys import traceback +from typing import List, Optional, Set + from spinn_utilities.overrides import overrides -def all_modules(directory, prefix, remove_pyc_files=False): +def all_modules(directory: str, prefix: str, + remove_pyc_files: bool = False) -> Set[str]: """ List all the python files found in this directory giving then the prefix. 
@@ -57,8 +60,9 @@ def all_modules(directory, prefix, remove_pyc_files=False): def load_modules( - directory, prefix, remove_pyc_files=False, exclusions=None, - gather_errors=True): + directory: str, prefix: str, remove_pyc_files: bool = False, + exclusions: Optional[List[str]] = None, + gather_errors: bool = True) -> None: """ Loads all the python files found in this directory, giving them the specified prefix. @@ -66,13 +70,12 @@ def load_modules( Any file that ends in either ``.py`` or ``.pyc`` is assume a python module and added to the result set. - :param str directory: path to check for python files - :param str prefix: package prefix top add to the file name - :param bool remove_pyc_files: True if ``.pyc`` files should be deleted - :param list(str) exclusions: a list of modules to exclude - :param bool gather_errors: + :param directory: path to check for python files + :param prefix: package prefix to add to the file name + :param remove_pyc_files: True if ``.pyc`` files should be deleted + :param exclusions: a list of modules to exclude + :param gather_errors: True if errors should be gathered, False to report on first error - :return: None """ if exclusions is None: exclusions = [] @@ -103,23 +106,26 @@ def load_modules( def load_module( - name, remove_pyc_files=False, exclusions=None, gather_errors=True): + name: str, remove_pyc_files: bool = False, + exclusions: Optional[List[str]] = None, + gather_errors: bool = True) -> None: """ Loads this modules and all its children. 
- :param str name: name of the modules - :param bool remove_pyc_files: True if ``.pyc`` files should be deleted - :param list(str) exclusions: a list of modules to exclude - :param bool gather_errors: + :param name: name of the modules + :param remove_pyc_files: True if ``.pyc`` files should be deleted + :param exclusions: a list of modules to exclude + :param gather_errors: True if errors should be gathered, False to report on first error - :return: None """ overrides.check_types() if exclusions is None: exclusions = [] module = __import__(name) path = module.__file__ + assert path is not None directory = os.path.dirname(path) + assert directory is not None load_modules(directory, name, remove_pyc_files, exclusions, gather_errors) diff --git a/spinn_utilities/ping.py b/spinn_utilities/ping.py index 01909088..5de9df67 100644 --- a/spinn_utilities/ping.py +++ b/spinn_utilities/ping.py @@ -27,17 +27,16 @@ class Ping(object): unreachable: Set[str] = set() @staticmethod - def ping(ip_address): + def ping(ip_address: str) -> int: """ Send a ping (ICMP ECHO request) to the given host. SpiNNaker boards support ICMP ECHO when booted. - :param str ip_address: + :param ip_address: The IP address to ping. Hostnames can be used, but are not recommended. :return: return code of subprocess; 0 for success, anything else for failure - :rtype: int """ if platform.platform().lower().startswith("windows"): cmd = "ping -n 1 -w 1 " @@ -46,12 +45,14 @@ def ping(ip_address): process = subprocess.Popen( cmd + ip_address, shell=True, stdout=subprocess.PIPE) time.sleep(1.2) - process.stdout.close() + _stdout = process.stdout + assert _stdout is not None + _stdout.close() process.wait() return process.returncode @staticmethod - def host_is_reachable(ip_address): + def host_is_reachable(ip_address: str) -> bool: """ Test if a host is unreachable via ICMP ECHO. @@ -59,10 +60,9 @@ def host_is_reachable(ip_address): This information may be cached in various ways. 
Transient failures are not necessarily detected or recovered from. - :param str ip_address: + :param ip_address: The IP address to ping. Hostnames can be used, but are not recommended. - :rtype: bool """ if ip_address in Ping.unreachable: return False diff --git a/spinn_utilities/progress_bar.py b/spinn_utilities/progress_bar.py index 3364029e..c3beea29 100644 --- a/spinn_utilities/progress_bar.py +++ b/spinn_utilities/progress_bar.py @@ -19,7 +19,12 @@ import math import os import sys -from typing import Dict, Iterable, List, TypeVar, Union +from types import TracebackType +from typing import (Dict, Iterable, List, Optional, Tuple, Type, + TypeVar, Union) + +from typing_extensions import Literal, Self + from spinn_utilities.config_holder import get_config_bool from spinn_utilities.log import FormatAdapter from spinn_utilities.overrides import overrides @@ -48,14 +53,14 @@ class ProgressBar(object): ) def __init__(self, total_number_of_things_to_do: Union[int, Sized], - string_describing_what_being_progressed, - step_character="=", end_character="|"): + string_describing_what_being_progressed: str, + step_character: str = "=", end_character: str = "|"): if isinstance(total_number_of_things_to_do, Sized): self._number_of_things = len(total_number_of_things_to_do) else: self._number_of_things = int(total_number_of_things_to_do) self._currently_completed = 0 - self._chars_per_thing = None + self._chars_per_thing = 1.0 self._chars_done = 0 self._string = string_describing_what_being_progressed self._destination = sys.stderr @@ -69,7 +74,7 @@ def __init__(self, total_number_of_things_to_do: Union[int, Sized], self._create_initial_progress_bar( string_describing_what_being_progressed) - def update(self, amount_to_add=1): + def update(self, amount_to_add: int = 1) -> None: """ Update the progress bar by a given amount. 
@@ -81,10 +86,10 @@ def update(self, amount_to_add=1): self._currently_completed += amount_to_add self._check_differences() - def _print_overwritten_line(self, string: str): + def _print_overwritten_line(self, string: str) -> None: print("\r" + string, end="", file=self._destination) - def _print_distance_indicator(self, description: str): + def _print_distance_indicator(self, description: str) -> None: if description is not None: print(description, file=self._destination) @@ -105,12 +110,13 @@ def _print_distance_indicator(self, description: str): print("", file=self._destination) print(" ", end="", file=self._destination) - def _print_distance_line(self, first_space, second_space): + def _print_distance_line( + self, first_space: int, second_space: int) -> None: line = f"{self._end_character}0%{' ' * first_space}50%" \ f"{' ' * second_space}100%{self._end_character}" print(line, end="", file=self._destination) - def _print_progress(self, length: int): + def _print_progress(self, length: int) -> None: chars_to_print = length if not self._in_bad_terminal: self._print_overwritten_line(self._end_character) @@ -121,7 +127,7 @@ def _print_progress(self, length: int): self._print_progress_unit(chars_to_print) self._destination.flush() - def _print_progress_unit(self, chars_to_print): + def _print_progress_unit(self, chars_to_print: int) -> None: # pylint: disable=unused-argument print(self._step_character, end='', file=self._destination) @@ -132,7 +138,7 @@ def _print_progress_done(self) -> None: else: print("", file=self._destination) - def _create_initial_progress_bar(self, description): + def _create_initial_progress_bar(self, description: str) -> None: if self._number_of_things == 0: self._chars_per_thing = ProgressBar.MAX_LENGTH_IN_CHARS else: @@ -142,7 +148,7 @@ def _create_initial_progress_bar(self, description): self._print_progress(0) self._check_differences() - def _check_differences(self): + def _check_differences(self) -> None: expected_chars_done = 
int(math.floor( self._currently_completed * self._chars_per_thing)) if self._currently_completed == self._number_of_things: @@ -150,7 +156,7 @@ def _check_differences(self): self._print_progress(expected_chars_done) self._chars_done = expected_chars_done - def end(self): + def end(self) -> None: """ Close the progress bar, updating whatever is left if needed. """ @@ -159,10 +165,10 @@ def end(self): self._check_differences() self._print_progress_done() - def __repr__(self): + def __repr__(self) -> str: return f"" - def __enter__(self): + def __enter__(self) -> Self: """ Support method to use the progress bar as a context manager:: @@ -184,7 +190,8 @@ def __enter__(self): """ return self - def __exit__(self, exty, exval, traceback): + def __exit__(self, exc_type: Optional[Type], exc_val: Exception, + exc_tb: TracebackType) -> Literal[False]: self.end() return False @@ -211,7 +218,7 @@ def over(self, collection: Iterable[T], if finish_at_end: self.end() - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Tuple[int, str], **kwargs: Dict) -> "ProgressBar": # pylint: disable=unused-argument c = cls if _EnhancedProgressBar._enabled: @@ -228,12 +235,12 @@ class _EnhancedProgressBar(ProgressBar): """ _line_no = 0 - _seq_id = 0 - _step_characters: Dict[int, List[str]] = defaultdict(list) + _seq_id = "Unset" + _step_characters: Dict[str, List[str]] = defaultdict(list) _enabled = False _DATA_FILE = "progress_bar.txt" - def _print_progress_unit(self, chars_to_print): + def _print_progress_unit(self, chars_to_print: int) -> None: song_line = self.__line if not self._in_bad_terminal: print(song_line[0:self._chars_done + chars_to_print], @@ -243,7 +250,7 @@ def _print_progress_unit(self, chars_to_print): end='', file=self._destination) self._chars_done += 1 - def _print_progress_done(self): + def _print_progress_done(self) -> None: self._print_progress(ProgressBar.MAX_LENGTH_IN_CHARS) if not self._in_bad_terminal: self._print_overwritten_line(self._end_character) @@ 
-255,19 +262,19 @@ def _print_progress_done(self): self.__next_line() @property - def __line(self): + def __line(self) -> str: return _EnhancedProgressBar._step_characters[ _EnhancedProgressBar._seq_id][_EnhancedProgressBar._line_no] @classmethod - def __next_line(cls): + def __next_line(cls) -> None: if cls._line_no + 1 >= len(cls._step_characters[cls._seq_id]): cls._line_no = 0 else: cls._line_no += 1 @classmethod - def init_once(cls): + def init_once(cls) -> None: """ At startup reads progress bar data from file to be used every time """ @@ -301,7 +308,7 @@ def init_once(cls): cls._enabled = ( date.today().strftime("%m%d") in cls._step_characters) except IOError: - cls._seq_id = 0 + cls._seq_id = "error" finally: cls._line_no = 0 if cls._enabled: @@ -322,22 +329,22 @@ class DummyProgressBar(ProgressBar): fails in exactly the same way. """ @overrides(ProgressBar._print_overwritten_line) - def _print_overwritten_line(self, string: str): + def _print_overwritten_line(self, string: str) -> None: pass @overrides(ProgressBar._print_distance_indicator) - def _print_distance_indicator(self, description: str): + def _print_distance_indicator(self, description: str) -> None: pass @overrides(ProgressBar._print_progress) - def _print_progress(self, length: int): + def _print_progress(self, length: int) -> None: pass @overrides(ProgressBar._print_progress_done) def _print_progress_done(self) -> None: pass - def __repr__(self): + def __repr__(self) -> str: return f"" diff --git a/spinn_utilities/ranged/abstract_dict.py b/spinn_utilities/ranged/abstract_dict.py index cb227d2a..562732a5 100644 --- a/spinn_utilities/ranged/abstract_dict.py +++ b/spinn_utilities/ranged/abstract_dict.py @@ -67,7 +67,7 @@ def keys(self) -> Iterable[str]: @abstractmethod def set_value( - self, key: str, value: T, use_list_as_value: bool = False): + self, key: str, value: T, use_list_as_value: bool = False) -> None: """ Sets a already existing key to the new value. 
All IDs in the whole range or view will have this key set. @@ -108,7 +108,8 @@ def ids(self) -> Sequence[int]: raise NotImplementedError @overload - def iter_all_values(self, key: str, update_safe=False) -> Iterator[T]: + def iter_all_values( + self, key: str, update_safe: bool = False) -> Iterator[T]: ... @overload @@ -117,7 +118,9 @@ def iter_all_values(self, key: Optional[_StrSeq], ... @abstractmethod - def iter_all_values(self, key: _Keys, update_safe: bool = False): + def iter_all_values( + self, key: _Keys, update_safe: bool = False + ) -> Union[Iterator[T], Iterator[Dict[str, T]]]: """ Iterates over the value(s) for all IDs covered by this view. There will be one yield for each ID even if values are repeated. @@ -329,7 +332,7 @@ def has_key(self, key: str) -> bool: """ return key in self.keys() - def reset(self, key: str): + def reset(self, key: str) -> None: """ Sets the value(s) for a single key back to the default value. diff --git a/spinn_utilities/ranged/abstract_list.py b/spinn_utilities/ranged/abstract_list.py index 7ac22124..b7bbe442 100644 --- a/spinn_utilities/ranged/abstract_list.py +++ b/spinn_utilities/ranged/abstract_list.py @@ -90,7 +90,7 @@ class AbstractList(AbstractSized, Generic[T], metaclass=AbstractBase): """ __slots__ = ("_key", ) - def __init__(self, size: int, key=None): + def __init__(self, size: int, key: Optional[str] = None) -> None: """ :param int size: Fixed length of the list :param key: The dict key this list covers. diff --git a/spinn_utilities/ranged/abstract_sized.py b/spinn_utilities/ranged/abstract_sized.py index 39dd68af..ed66864b 100644 --- a/spinn_utilities/ranged/abstract_sized.py +++ b/spinn_utilities/ranged/abstract_sized.py @@ -164,7 +164,8 @@ def _check_mask_size(self, selector: Sized) -> None: "but the length was only %d. 
All the missing entries will be " "ignored!", self._size, len(selector)) - def selector_to_ids(self, selector: Selector, warn=False) -> Sequence[int]: + def selector_to_ids( + self, selector: Selector, warn: bool = False) -> Sequence[int]: """ Gets the list of IDs covered by this selector. The types of selector currently supported are: diff --git a/spinn_utilities/ranged/abstract_view.py b/spinn_utilities/ranged/abstract_view.py index 7f626254..8fd86229 100644 --- a/spinn_utilities/ranged/abstract_view.py +++ b/spinn_utilities/ranged/abstract_view.py @@ -37,7 +37,8 @@ def __init__(self, range_dict: RangeDictionary[T]): """ self._range_dict = range_dict - def __getitem__(self, key: Union[int, slice, Iterable[int]]): + def __getitem__(self, key: Union[int, slice, Iterable[int]] + ) -> AbstractView[T]: """ Support for the view[x] based the type of the key @@ -64,7 +65,7 @@ def __getitem__(self, key: Union[int, slice, Iterable[int]]): return self._range_dict.view_factory(ids[key]) return self._range_dict.view_factory([ids[i] for i in key]) - def __setitem__(self, key: str, value: T): + def __setitem__(self, key: str, value: T) -> None: """ See :py:meth:`AbstractDict.set_value` diff --git a/spinn_utilities/ranged/ids_view.py b/spinn_utilities/ranged/ids_view.py index 3a177603..4b6de1b3 100644 --- a/spinn_utilities/ranged/ids_view.py +++ b/spinn_utilities/ranged/ids_view.py @@ -66,12 +66,12 @@ def get_value(self, key: _Keys) -> Union[T, Dict[str, T]]: @overrides(AbstractDict.set_value) def set_value( - self, key: str, value: T, use_list_as_value: bool = False): + self, key: str, value: T, use_list_as_value: bool = False) -> None: ranged_list = self._range_dict.get_list(key) for _id in self._ids: ranged_list.set_value_by_id(the_id=_id, value=value) - def set_value_by_ids(self, key: str, ids: Iterable[int], value: T): + def set_value_by_ids(self, key: str, ids: Iterable[int], value: T) -> None: """ Sets a already existing key to the new value. For the view specified. 
@@ -84,7 +84,8 @@ def set_value_by_ids(self, key: str, ids: Iterable[int], value: T): rl.set_value_by_id(the_id=_id, value=value) @overload - def iter_all_values(self, key: str, update_safe=False) -> Iterator[T]: + def iter_all_values( + self, key: str, update_safe: bool = False) -> Iterator[T]: ... @overload @@ -93,7 +94,9 @@ def iter_all_values(self, key: Optional[_StrSeq], ... @overrides(AbstractDict.iter_all_values) - def iter_all_values(self, key: _Keys, update_safe: bool = False): + def iter_all_values( + self, key: _Keys, update_safe: bool = False + ) -> Union[Iterator[T], Iterator[Dict[str, T]]]: if isinstance(key, str): yield from self._range_dict.iter_values_by_ids( ids=self._ids, key=key, update_safe=update_safe) diff --git a/spinn_utilities/ranged/multiple_values_exception.py b/spinn_utilities/ranged/multiple_values_exception.py index dedb5aa7..86149891 100644 --- a/spinn_utilities/ranged/multiple_values_exception.py +++ b/spinn_utilities/ranged/multiple_values_exception.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Optional + class MultipleValuesException(Exception): """ Raised when there more than one value found unexpectedly. """ - def __init__(self, key, value1, value2): + def __init__(self, key: Optional[str], value1: Any, value2: Any): if key is None: msg = "Multiple values found" else: diff --git a/spinn_utilities/ranged/range_dictionary.py b/spinn_utilities/ranged/range_dictionary.py index b82f80ba..12de84c2 100644 --- a/spinn_utilities/ranged/range_dictionary.py +++ b/spinn_utilities/ranged/range_dictionary.py @@ -146,7 +146,8 @@ def __getitem__(self, key: str) -> RangedList[T]: ... @overload def __getitem__(self, key: _KeyType) -> AbstractView: ... 
- def __getitem__(self, key): + def __getitem__(self, key: Union[str, _KeyType] + ) -> Union[RangedList[T], AbstractView]: """ Support for the view[x] based the type of the key @@ -189,7 +190,9 @@ def get_values_by_id(self, key: str, the_id: int) -> T: ... def get_values_by_id( self, key: Optional[_StrSeq], the_id: int) -> Dict[str, T]: ... - def get_values_by_id(self, key, the_id) -> Union[T, Dict[str, T]]: + def get_values_by_id( + self, key: Union[str, _StrSeq, None], + the_id: int) -> Union[T, Dict[str, T]]: """ Same as :py:meth:`get_value` but limited to a single ID. @@ -250,7 +253,8 @@ def iter_all_values(self, key: Optional[_StrSeq], ... @overrides(AbstractDict.iter_all_values, extend_defaults=True) - def iter_all_values(self, key: _Keys, update_safe: bool = False): + def iter_all_values(self, key: _Keys, update_safe: bool = False + ) -> Union[Iterator[T], Iterator[Dict[str, T]]]: if isinstance(key, str): if update_safe: return self._value_lists[key].iter() @@ -275,8 +279,10 @@ def iter_values_by_slice( ... def iter_values_by_slice( - self, slice_start: int, slice_stop: int, key=None, - update_safe=False): + self, slice_start: int, slice_stop: int, + key: Union[str, _StrSeq, None] = None, + update_safe: bool = False) -> Union[ + Iterator[T], Iterator[Dict[str, T]]]: """ Same as :py:meth:`iter_all_values` but limited to a simple slice. """ @@ -291,17 +297,21 @@ def iter_values_by_slice( slice_start=slice_start, slice_stop=slice_stop, key=key)) @overload - def iter_values_by_ids(self, ids: IdsType, key: str, - update_safe: bool = False) -> Iterator[T]: + def iter_values_by_ids( + self, ids: IdsType, key: str, + update_safe: bool = False) -> Generator[T, None, None]: ... @overload def iter_values_by_ids( self, ids: IdsType, key: Optional[_StrSeq] = None, - update_safe: bool = False) -> Iterator[Dict[str, T]]: + update_safe: bool = False) -> Generator[Dict[str, T], None, None]: ... 
- def iter_values_by_ids(self, ids: IdsType, key=None, update_safe=False): + def iter_values_by_ids( + self, ids: IdsType, key: Union[str, _StrSeq, None] = None, + update_safe: bool = False) -> Union[ + Generator[T, None, None], Generator[Dict[str, T], None, None]]: """ Same as :py:meth:`iter_all_values` but limited to a simple slice. """ @@ -311,18 +321,19 @@ def iter_values_by_ids(self, ids: IdsType, key=None, update_safe=False): key=key, ids=ids)) @staticmethod - def _values_from_ranges(ranges: _SimpleRangeIter) -> Iterable[T]: + def _values_from_ranges( + ranges: _SimpleRangeIter) -> Generator[T, None, None]: for (start, stop, value) in ranges: for _ in range(start, stop): yield value @overrides(AbstractDict.set_value) def set_value( - self, key: str, value: T, use_list_as_value: bool = False): + self, key: str, value: T, use_list_as_value: bool = False) -> None: self._value_lists[key].set_value( value, use_list_as_value=use_list_as_value) - def __setitem__(self, key: str, value: Union[T, RangedList[T]]): + def __setitem__(self, key: str, value: Union[T, RangedList[T]]) -> None: """ Wrapper around set_value to support ``range["key"] =`` @@ -413,9 +424,9 @@ def iter_ranges(self, key: Optional[_StrSeq]) -> Iterator[Tuple[ ... @overrides(AbstractDict.iter_ranges) - def iter_ranges(self, key: _Keys = None) -> \ - Union[Iterator[Tuple[int, int, T]], - Iterator[Tuple[int, int, Dict[str, T]]]]: + def iter_ranges(self, key: _Keys = None) -> Union[ + Iterator[Tuple[int, int, T]], + Iterator[Tuple[int, int, Dict[str, T]]]]: if isinstance(key, str): return self._value_lists[key].iter_ranges() if key is None: @@ -467,7 +478,9 @@ def iter_ranges_by_slice( slice_stop: int) -> _CompoundRangeIter: ... 
- def iter_ranges_by_slice(self, key, slice_start: int, slice_stop: int): + def iter_ranges_by_slice( + self, key: Union[str, _StrSeq, None], slice_start: int, + slice_stop: int) -> Union[_SimpleRangeIter, _CompoundRangeIter]: """ Same as :py:meth:`iter_ranges` but limited to a simple slice. @@ -501,7 +514,10 @@ def iter_ranges_by_ids( key: Optional[_StrSeq] = None) -> _CompoundRangeIter: ... - def iter_ranges_by_ids(self, ids: IdsType, key=None): + def iter_ranges_by_ids( + self, ids: IdsType, + key: Union[str, _StrSeq, None] = None) -> Union[ + _SimpleRangeIter, _CompoundRangeIter]: """ Same as :py:meth:`iter_ranges` but limited to a collection of IDs. @@ -519,7 +535,7 @@ def iter_ranges_by_ids(self, ids: IdsType, key=None): a_key: self._value_lists[a_key].iter_ranges_by_ids(ids=ids) for a_key in key}) - def set_default(self, key: str, default: T): + def set_default(self, key: str, default: T) -> None: """ Sets the default value for a single key. @@ -541,7 +557,7 @@ def set_default(self, key: str, default: T): def get_default(self, key: str) -> Optional[T]: return self._value_lists[key].get_default() - def copy_into(self, other: RangeDictionary[T]): + def copy_into(self, other: RangeDictionary[T]) -> None: """ Turns this dict into a copy of the other dict but keep its id. 
diff --git a/spinn_utilities/ranged/ranged_list.py b/spinn_utilities/ranged/ranged_list.py index 8bb87b7b..62eb3435 100644 --- a/spinn_utilities/ranged/ranged_list.py +++ b/spinn_utilities/ranged/ranged_list.py @@ -66,7 +66,8 @@ class RangedList(AbstractList[T], Generic[T]): def __init__( self, size: Optional[int] = None, value: _ValueType = None, - key=None, use_list_as_value=False): + key: Optional[str] = None, use_list_as_value: bool = False + ) -> None: """ :param size: Fixed length of the list; @@ -339,7 +340,8 @@ def as_list( f"does not equal the size:{size}") return values - def set_value(self, value: _ValueType, use_list_as_value=False): + def set_value( + self, value: _ValueType, use_list_as_value: bool = False) -> None: """ Sets *all* elements in the list to this value. @@ -361,7 +363,7 @@ def set_value(self, value: _ValueType, use_list_as_value=False): self._ranges = [(0, self._size, value)] self._ranged_based = True - def set_value_by_id(self, the_id: int, value: T): + def set_value_by_id(self, the_id: int, value: T) -> None: """ Sets the value for a single ID to the new value. @@ -422,7 +424,7 @@ def set_value_by_id(self, the_id: int, value: T): def set_value_by_slice( self, slice_start: int, slice_stop: int, value: _ValueType, - use_list_as_value=False): + use_list_as_value: bool = False) -> None: """ Sets the value for a single range to the new value. 
@@ -504,13 +506,14 @@ def set_value_by_slice( # set the value in case missed elsewhere ranges[index] = (ranges[index][0], ranges[index][1], value) - def _set_values_list(self, ids: IdsType, value: _ListType): + def _set_values_list(self, ids: IdsType, value: _ListType) -> None: values = self.as_list(value=value, size=len(ids), ids=ids) for id_value, val in zip(ids, values): self.set_value_by_id(id_value, val) def set_value_by_ids( - self, ids: IdsType, value: _ValueType, use_list_as_value=False): + self, ids: IdsType, value: _ValueType, + use_list_as_value: bool = False) -> None: """ Sets a already existing key to the new value. For the ids specified. @@ -526,7 +529,7 @@ def set_value_by_ids( def set_value_by_selector( self, selector: Selector, value: _ValueType, - use_list_as_value=False): + use_list_as_value: bool = False) -> None: """ Support for the ``list[x] =`` format. @@ -564,7 +567,7 @@ def get_ranges(self) -> List[_RangeType]: return list(self.__the_ranges) return list(self.iter_ranges()) - def set_default(self, default: Optional[T]): + def set_default(self, default: Optional[T]) -> None: """ Sets the default value. @@ -588,7 +591,7 @@ def get_default(self) -> Optional[T]: except AttributeError as e: raise AttributeError("Default value not set.") from e - def copy_into(self, other: RangedList[T]): + def copy_into(self, other: RangedList[T]) -> None: """ Turns this List into a of the other list but keep its ID. diff --git a/spinn_utilities/ranged/single_view.py b/spinn_utilities/ranged/single_view.py index 7cc699ff..b2412a85 100644 --- a/spinn_utilities/ranged/single_view.py +++ b/spinn_utilities/ranged/single_view.py @@ -64,17 +64,20 @@ def get_value(self, key: _Keys) -> Union[T, Dict[str, T]]: for k in key} @overload - def iter_all_values(self, key: str, update_safe=False) -> Iterator[T]: + def iter_all_values(self, key: str, update_safe: bool = False + ) -> Iterator[T]: ... 
@overload def iter_all_values( - self, key: Optional[_StrSeq], update_safe=False) -> Iterator[ - Dict[str, T]]: + self, key: Optional[_StrSeq], update_safe: bool = False + ) -> Iterator[Dict[str, T]]: ... @overrides(AbstractDict.iter_all_values) - def iter_all_values(self, key: _Keys, update_safe: bool = False): + def iter_all_values( + self, key: _Keys, update_safe: bool = False + ) -> Union[Iterator[T], Iterator[Dict[str, T]]]: if isinstance(key, str): yield self._range_dict.get_list(key).get_value_by_id( the_id=self._id) @@ -82,7 +85,8 @@ def iter_all_values(self, key: _Keys, update_safe: bool = False): yield self._range_dict.get_values_by_id(key=key, the_id=self._id) @overrides(AbstractDict.set_value) - def set_value(self, key: str, value: T, use_list_as_value: bool = False): + def set_value(self, key: str, value: T, use_list_as_value: bool = False + ) -> None: return self._range_dict.get_list(key).set_value_by_id( value=value, the_id=self._id) diff --git a/spinn_utilities/ranged/slice_view.py b/spinn_utilities/ranged/slice_view.py index 924c322e..42afdb77 100644 --- a/spinn_utilities/ranged/slice_view.py +++ b/spinn_utilities/ranged/slice_view.py @@ -13,7 +13,7 @@ # limitations under the License. 
from __future__ import annotations from typing import ( - Dict, Generic, Iterable, Iterator, Optional, Sequence, Tuple, overload, + Dict, Generic, Iterator, Optional, Sequence, Tuple, overload, TYPE_CHECKING, Union) from spinn_utilities.overrides import overrides from .abstract_dict import AbstractDict, T, _StrSeq, _Keys @@ -65,7 +65,7 @@ def get_value(self, key: _Keys) -> Union[T, Dict[str, T]]: slice_start=self._start, slice_stop=self._stop) for k in key} - def update_safe_iter_all_values(self, key: str) -> Iterable[T]: + def update_safe_iter_all_values(self, key: str) -> Iterator[T]: """ Iterate over the Values in a way that will work even between updates @@ -78,17 +78,18 @@ def update_safe_iter_all_values(self, key: str) -> Iterable[T]: @overload def iter_all_values( - self, key: str, update_safe=False) -> Iterator[T]: + self, key: str, update_safe: bool = False) -> Iterator[T]: ... @overload def iter_all_values( self, key: Optional[_StrSeq] = None, - update_safe=False) -> Iterator[Dict[str, T]]: + update_safe: bool = False) -> Iterator[Dict[str, T]]: ... 
@overrides(AbstractDict.iter_all_values, extend_defaults=True) - def iter_all_values(self, key: _Keys = None, update_safe: bool = False): + def iter_all_values(self, key: _Keys = None, update_safe: bool = False + ) -> Union[Iterator[T], Iterator[Dict[str, T]]]: if isinstance(key, str): if update_safe: return self.update_safe_iter_all_values(key) @@ -100,7 +101,7 @@ def iter_all_values(self, key: _Keys = None, update_safe: bool = False): @overrides(AbstractDict.set_value) def set_value(self, key: str, value: _ValueType, - use_list_as_value: bool = False): + use_list_as_value: bool = False) -> None: self._range_dict.get_list(key).set_value_by_slice( slice_start=self._start, slice_stop=self._stop, value=value, use_list_as_value=use_list_as_value) diff --git a/spinn_utilities/require_subclass.py b/spinn_utilities/require_subclass.py index 89f5e14f..64ce3707 100644 --- a/spinn_utilities/require_subclass.py +++ b/spinn_utilities/require_subclass.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Callable, Dict, Type + class _RequiresSubclassTypeError(TypeError): """ @@ -22,7 +24,7 @@ class _RequiresSubclassTypeError(TypeError): """ -def require_subclass(required_class): +def require_subclass(required_class: Type) -> Callable[[Type], Type]: """ Decorator that arranges for subclasses of the decorated class to require that they are also subclasses of the given class. @@ -52,20 +54,22 @@ class AbstractVirtual(object): # without it, some very weird interactions with meta classes happen and I # really don't want to debug that stuff. 
- def decorate(target_class): + def decorate(target_class: Type) -> Type: # pylint: disable=unused-variable __class__ = target_class # @ReservedAssignment # noqa: F841 - def __init_subclass__(cls, allow_derivation=False, **kwargs): + def __init_subclass__( + cls: Type, allow_derivation: bool = False, + **kwargs: Dict[str, Any]) -> None: if not issubclass(cls, required_class) and not allow_derivation: raise _RequiresSubclassTypeError( f"{cls.__name__} must be a subclass " f"of {required_class.__name__} and the derivation was not " "explicitly allowed with allow_derivation=True") try: - super().__init_subclass__(**kwargs) + super().__init_subclass__(**kwargs) # type: ignore[misc] except _RequiresSubclassTypeError: - super().__init_subclass__( + super().__init_subclass__( # type: ignore[misc] allow_derivation=allow_derivation, **kwargs) setattr(target_class, '__init_subclass__', diff --git a/spinn_utilities/safe_eval.py b/spinn_utilities/safe_eval.py index 13c5f8f2..fce80c44 100644 --- a/spinn_utilities/safe_eval.py +++ b/spinn_utilities/safe_eval.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from types import ModuleType +from typing import Any, Callable, Dict, Union + class SafeEval(object): """ @@ -42,7 +45,8 @@ class SafeEval(object): """ __slots__ = ["_environment"] - def __init__(self, *args, **kwargs): + def __init__(self, *args: Union[Callable, ModuleType], + **kwargs: Any) -> None: """ :param args: The symbols to use to populate the global reference table. @@ -58,13 +62,13 @@ def __init__(self, *args, **kwargs): symbols (e.g., constants in numpy) do not have names that we can otherwise look up easily. """ - env = {} + env: Dict[Any, Any] = {} for item in args: env[item.__name__] = item env.update(kwargs) self._environment = env - def eval(self, expression, **kwargs): + def eval(self, expression: str, **kwargs: Any) -> Any: """ Evaluate an expression and return the result. 
diff --git a/spinn_utilities/socket_address.py b/spinn_utilities/socket_address.py index dde82a27..43d30c6e 100644 --- a/spinn_utilities/socket_address.py +++ b/spinn_utilities/socket_address.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Optional from spinn_utilities.config_holder import ( get_config_int, get_config_int_or_none, get_config_str) @@ -82,22 +82,22 @@ def listen_port(self) -> Optional[int]: """ return self._listen_port - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, SocketAddress): return False return (self._notify_host_name == other.notify_host_name and self._notify_port_no == other.notify_port_no and self._listen_port == other.listen_port) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: if self.__hash is None: self.__hash = hash((self._listen_port, self._notify_host_name, self._notify_port_no)) return self.__hash - def __repr__(self): + def __repr__(self) -> str: return (f"SocketAddress({repr(self._notify_host_name)}, " f"{self._notify_port_no}, {self._listen_port})") diff --git a/spinn_utilities/testing/log_checker.py b/spinn_utilities/testing/log_checker.py index 063bf6b6..f4c8532b 100644 --- a/spinn_utilities/testing/log_checker.py +++ b/spinn_utilities/testing/log_checker.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from logging import LogRecord +from typing import List + _WRITE_LOGS_TO_STDOUT = True -def _assert_logs_contains(level, log_records, submessage): +def _assert_logs_contains( + level: str, log_records: List[LogRecord], submessage: str) -> None: for record in log_records: if record.levelname == level and submessage in record.getMessage(): return @@ -25,7 +29,8 @@ def _assert_logs_contains(level, log_records, submessage): raise AssertionError(f"\"{submessage}\" not found in any {level} logs") -def _assert_logs_not_contains(level, log_records, submessage): +def _assert_logs_not_contains( + level: str, log_records: List[LogRecord], submessage: str) -> None: for record in log_records: if _WRITE_LOGS_TO_STDOUT: # pragma: no cover print(record) @@ -34,7 +39,8 @@ def _assert_logs_not_contains(level, log_records, submessage): f"\"{submessage}\" found in any {level} logs") -def assert_logs_contains_once(level, log_records, message): +def assert_logs_contains_once( + level: int, log_records: List[LogRecord], message: str) -> None: """ Checks if the log records contain exactly one record at the given level with the given sub-message. 
@@ -65,7 +71,8 @@ def assert_logs_contains_once(level, log_records, message): raise AssertionError(f"\"{message}\" not found in any {level} logs") -def assert_logs_error_contains(log_records, submessage): +def assert_logs_error_contains( + log_records: List[LogRecord], submessage: str) -> None: """ Checks it the log records contain an ERROR log with this sub-message @@ -80,7 +87,8 @@ def assert_logs_error_contains(log_records, submessage): _assert_logs_contains('ERROR', log_records, submessage) -def assert_logs_warning_contains(log_records, submessage): +def assert_logs_warning_contains( + log_records: List[LogRecord], submessage: str) -> None: """ Checks it the log records contain an WARNING log with this sub-message @@ -95,7 +103,8 @@ def assert_logs_warning_contains(log_records, submessage): _assert_logs_contains('WARNING', log_records, submessage) -def assert_logs_info_contains(log_records, sub_message): +def assert_logs_info_contains( + log_records: List[LogRecord], sub_message: str) -> None: """ Checks it the log records contain an INFO log with this sub-message @@ -110,7 +119,8 @@ def assert_logs_info_contains(log_records, sub_message): _assert_logs_contains('INFO', log_records, sub_message) -def assert_logs_error_not_contains(log_records, submessage): +def assert_logs_error_not_contains( + log_records: List[LogRecord], submessage: str) -> None: """ Checks it the log records do not contain an ERROR log with this sub-message. @@ -126,7 +136,8 @@ def assert_logs_error_not_contains(log_records, submessage): _assert_logs_not_contains('ERROR', log_records, submessage) -def assert_logs_info_not_contains(log_records, submessage): +def assert_logs_info_not_contains( + log_records: List[LogRecord], submessage: str) -> None: """ Checks it the log records do not contain an INFO log with this sub-message. 
diff --git a/spinn_utilities/timer.py b/spinn_utilities/timer.py index 75d4763e..6cd9809f 100644 --- a/spinn_utilities/timer.py +++ b/spinn_utilities/timer.py @@ -13,7 +13,7 @@ # limitations under the License. from datetime import timedelta from time import perf_counter_ns -from typing import Optional +from typing import Any, Optional, Tuple from typing_extensions import Literal # conversion factor @@ -75,7 +75,7 @@ def __enter__(self) -> 'Timer': self.start_timing() return self - def __exit__(self, *_args) -> Literal[False]: + def __exit__(self, *_args: Tuple[Any, ...]) -> Literal[False]: self._measured_section_interval = self.take_sample() return False diff --git a/unittests/test_log.py b/unittests/test_log.py index 56e20333..c0a3c093 100644 --- a/unittests/test_log.py +++ b/unittests/test_log.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import configparser from datetime import datetime import logging from typing import List, Optional @@ -136,7 +137,7 @@ class Exn(Exception): assert len(logger._pop_not_stored_messages()) == 1 -class MockConfig1(object): +class MockConfig1(configparser.RawConfigParser): def get(self, section, option): return "debug"