Skip to content

Commit

Permalink
Merge pull request #121 from makoeppel/resolve_datatype/add_typeCheck…
Browse files Browse the repository at this point in the history
…_for_input

add type check for input name of resolve_datatype()
  • Loading branch information
maltanar authored Dec 20, 2024
2 parents 364f995 + 7e747d7 commit 51965ab
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/qonnx/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,9 @@ def get_canonical_name(self):


def resolve_datatype(name):
if not isinstance(name, str):
raise TypeError(f"Input 'name' must be of type 'str', but got type '{type(name).__name__}'")

_special_types = {
"BINARY": IntType(1, False),
"BIPOLAR": BipolarType(),
Expand Down
51 changes: 50 additions & 1 deletion tests/core/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import numpy as np

from qonnx.core.datatype import DataType
from qonnx.core.datatype import DataType, resolve_datatype


def test_datatypes():
Expand Down Expand Up @@ -97,3 +97,52 @@ def test_smallest_possible():
assert DataType.get_smallest_possible(-1) == DataType["BIPOLAR"]
assert DataType.get_smallest_possible(-3) == DataType["INT3"]
assert DataType.get_smallest_possible(-3.2) == DataType["FLOAT32"]


def test_resolve_datatype():
assert resolve_datatype("BIPOLAR")
assert resolve_datatype("BINARY")
assert resolve_datatype("TERNARY")
assert resolve_datatype("UINT2")
assert resolve_datatype("UINT3")
assert resolve_datatype("UINT4")
assert resolve_datatype("UINT8")
assert resolve_datatype("UINT16")
assert resolve_datatype("UINT32")
assert resolve_datatype("INT2")
assert resolve_datatype("INT3")
assert resolve_datatype("INT4")
assert resolve_datatype("INT8")
assert resolve_datatype("INT16")
assert resolve_datatype("INT32")
assert resolve_datatype("FLOAT32")


def test_input_type_error():
def test_resolve_datatype(input):
# test with invalid input to check if the TypeError works
try:
resolve_datatype(input) # This should raise a TypeError
except TypeError:
pass
else:
assert False, "Test with invalid input failed: No TypeError was raised."

test_resolve_datatype(123)
test_resolve_datatype(1.23)
test_resolve_datatype(DataType["BIPOLAR"])
test_resolve_datatype(DataType["BINARY"])
test_resolve_datatype(DataType["TERNARY"])
test_resolve_datatype(DataType["UINT2"])
test_resolve_datatype(DataType["UINT3"])
test_resolve_datatype(DataType["UINT4"])
test_resolve_datatype(DataType["UINT8"])
test_resolve_datatype(DataType["UINT16"])
test_resolve_datatype(DataType["UINT32"])
test_resolve_datatype(DataType["INT2"])
test_resolve_datatype(DataType["INT3"])
test_resolve_datatype(DataType["INT4"])
test_resolve_datatype(DataType["INT8"])
test_resolve_datatype(DataType["INT16"])
test_resolve_datatype(DataType["INT32"])
test_resolve_datatype(DataType["FLOAT32"])

0 comments on commit 51965ab

Please sign in to comment.