Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug on gettype on empty array, add validator tests #319

Draft
wants to merge 1 commit into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/hdmf/validate/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
"Error",
"DtypeError",
"MissingError",
"ExpectedScalarError",
"ExpectedArrayError",
"ShapeError",
"MissingDataType",
"IllegalLinkError",
"IncorrectDataType"
"IncorrectDataType",
"EmptyDataNoTypeWarning"
]


Expand Down Expand Up @@ -96,6 +98,19 @@ def data_type(self):
return self.__data_type


class ExpectedScalarError(Error):
    """
    An error for indicating that a scalar value was expected but array-shaped data was received.
    """

    @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'},
            {'name': 'received', 'type': (tuple, list), 'doc': 'the received data'},
            {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None})
    def __init__(self, **kwargs):
        component = getargs('name', kwargs)
        where = getargs('location', kwargs)
        found = getargs('received', kwargs)
        # 'received' is reported as a shape in the message. It is wrapped in a
        # one-element tuple so that a tuple value is rendered as a whole by %s
        # instead of being unpacked as multiple format arguments.
        message = "incorrect shape - expected a scalar, got array with shape '%s'" % (found,)
        super().__init__(component, message, location=where)


class ExpectedArrayError(Error):

@docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'},
Expand Down Expand Up @@ -157,3 +172,24 @@ def __init__(self, **kwargs):
reason = "incorrect data_type - expected '%s', got '%s'" % (expected, received)
loc = getargs('location', kwargs)
super().__init__(name, reason, location=loc)


class ValidatorWarning(UserWarning):
    """
    Base category for warnings produced during validation.

    It adds no behavior of its own; it exists so that callers can tell
    validator warnings apart from validator errors with ``isinstance``.
    """


class EmptyDataNoTypeWarning(Error, ValidatorWarning):
    """
    A warning for indicating that a value is empty and has no data type (e.g., an empty list).

    The class derives from both Error and ValidatorWarning: it is constructed
    like the other validation errors, while ``isinstance(x, ValidatorWarning)``
    lets callers single it out as a warning.
    """

    @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'},
            {'name': 'data_type', 'type': type, 'doc': 'the type of the data'},
            {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None})
    def __init__(self, **kwargs):
        component = getargs('name', kwargs)
        where = getargs('location', kwargs)
        container_type = getargs('data_type', kwargs)
        # The message names the Python type of the empty value (e.g. list).
        message = "could not determine data type for empty data %s" % container_type
        super().__init__(component, message, location=where)
129 changes: 80 additions & 49 deletions src/hdmf/validate/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from copy import copy
import re
from itertools import chain
from warnings import warn

from ..utils import docval, getargs, call_docval_func, pystr, get_data_shape

Expand All @@ -14,7 +15,7 @@
from ..build.builders import BaseBuilder

from .errors import Error, DtypeError, MissingError, MissingDataType, ShapeError, IllegalLinkError, IncorrectDataType
from .errors import ExpectedArrayError
from .errors import ExpectedArrayError, ExpectedScalarError, EmptyDataNoTypeWarning, ValidatorWarning

__synonyms = DtypeHelper.primary_dtype_synonyms

Expand Down Expand Up @@ -104,42 +105,41 @@ def get_type(data):
return 'region'
elif isinstance(data, ReferenceBuilder):
return 'object'
elif isinstance(data, np.ndarray):
return get_type(data[0])
if not hasattr(data, '__len__'):
elif not hasattr(data, '__len__'):
return type(data).__name__
# for conditions below, len(data) works
elif hasattr(data, 'dtype'):
if data.dtype.metadata is not None and data.dtype.metadata.get('vlen') is not None:
return get_type(data[0])
return data.dtype
elif len(data) == 0:
# raise ValueError('cannot determine type for empty data without dtype attribute')
return None
else:
if hasattr(data, 'dtype'):
if data.dtype.metadata is not None and data.dtype.metadata.get('vlen') is not None:
return get_type(data[0])
return data.dtype
if len(data) == 0:
raise ValueError('cannot determine type for empty data')
return get_type(data[0])


def check_shape(expected, received):
if expected is None: # scalar
return not isinstance(received, (list, tuple, np.ndarray))
ret = False
if expected is None:
ret = True
else:
if isinstance(expected, (list, tuple)):
if isinstance(expected[0], (list, tuple)):
for sub in expected:
if check_shape(sub, received):
ret = True
break
else:
if len(expected) > 0 and received is None:
ret = False
elif len(expected) == len(received):
if isinstance(expected, (list, tuple)):
if isinstance(expected[0], (list, tuple)):
for sub in expected:
if check_shape(sub, received):
ret = True
for e, r in zip(expected, received):
if not check_shape(e, r):
ret = False
break
elif isinstance(expected, int):
ret = expected == received
break
else:
if len(expected) > 0 and received is None:
ret = False
elif len(expected) == len(received):
ret = True
for e, r in zip(expected, received):
if not check_shape(e, r):
ret = False
break
elif isinstance(expected, int):
ret = expected == received
return ret


Expand Down Expand Up @@ -282,6 +282,30 @@ def get_builder_loc(cls, builder):
tmp = tmp.parent
return "/".join(reversed(stack))

@classmethod
def check_data_type(cls, value, spec, spec_loc, builder_location=None):
    """
    Compare the data type of ``value`` with the dtype required by ``spec``.

    Returns:
        - an EmptyDataNoTypeWarning if get_type returns None for ``value``
        - a DtypeError if check_type rejects the determined type against spec.dtype
        - None if the type is acceptable

    ``spec_loc`` is passed as the name of the erroneous component and
    ``builder_location`` as the location of the returned object.
    """
    found = get_type(value)
    if found is None:
        # No type could be determined; report the Python type of the value instead.
        return EmptyDataNoTypeWarning(spec_loc, type(value), location=builder_location)
    if check_type(spec.dtype, found):
        return None
    return DtypeError(spec_loc, spec.dtype, found, location=builder_location)

@classmethod
def check_data_shape(cls, value, spec, spec_loc, builder_location=None):
    """
    Compare the shape of ``value`` with the shape required by ``spec``.

    A shape of ``()`` from get_data_shape is handled the same way as None.

    Returns:
        - an ExpectedArrayError if the value has no shape but spec.shape is set
        - an ExpectedScalarError if the value has a shape but spec.shape is None
        - a ShapeError if check_shape rejects the shape against spec.shape
        - None if the shape is acceptable

    ``spec_loc`` is passed as the name of the erroneous component and
    ``builder_location`` as the location of the returned error.
    """
    found = get_data_shape(value)
    if found == ():
        found = None
    expected = spec.shape

    if found is None:
        if expected is not None:
            return ExpectedArrayError(spec_loc, expected, str(value), location=builder_location)
    elif expected is None:
        return ExpectedScalarError(spec_loc, found, location=builder_location)

    # Reached when both shapes are set, or when both are None.
    if not check_shape(expected, found):
        return ShapeError(spec_loc, expected, found, location=builder_location)
    return None


class AttributeValidator(Validator):
'''A class for validating values against AttributeSpecs'''
Expand Down Expand Up @@ -313,12 +337,13 @@ def validate(self, **kwargs):
if spec.dtype.target_type not in hierarchy:
ret.append(IncorrectDataType(self.get_spec_loc(spec), spec.dtype.target_type, data_type))
else:
dtype = get_type(value)
if not check_type(spec.dtype, dtype):
ret.append(DtypeError(self.get_spec_loc(spec), spec.dtype, dtype))
shape = get_data_shape(value)
if not check_shape(spec.shape, shape):
ret.append(ShapeError(self.get_spec_loc(spec), spec.shape, shape))
dtype_err = self.check_data_type(value, spec, self.get_spec_loc(spec))
if dtype_err:
ret.append(dtype_err)

shape_err = self.check_data_shape(value, spec, self.get_spec_loc(spec))
if shape_err:
ret.append(shape_err)
return ret


Expand Down Expand Up @@ -349,7 +374,11 @@ def validate(self, **kwargs):
errors = validator.validate(attr_val)
for err in errors:
err.location = self.get_builder_loc(builder) + ".%s" % validator.spec.name
ret.extend(errors)
if isinstance(err, ValidatorWarning):
warn(err, type(err))
else:
ret.append(err) # return only errors, not warnings

return ret


Expand All @@ -367,19 +396,21 @@ def validate(self, **kwargs):
builder = getargs('builder', kwargs)
ret = super().validate(builder)
data = builder.data
if self.spec.dtype is not None:
dtype = get_type(data)
if not check_type(self.spec.dtype, dtype):
ret.append(DtypeError(self.get_spec_loc(self.spec), self.spec.dtype, dtype,
location=self.get_builder_loc(builder)))
shape = get_data_shape(data)
if not check_shape(self.spec.shape, shape):
if shape is None:
ret.append(ExpectedArrayError(self.get_spec_loc(self.spec), self.spec.shape, str(data),
location=self.get_builder_loc(builder)))
else:
ret.append(ShapeError(self.get_spec_loc(self.spec), self.spec.shape, shape,
location=self.get_builder_loc(builder)))
spec = self.spec

if spec.dtype is not None:
dtype_err = self.check_data_type(data, spec, self.get_spec_loc(spec),
builder_location=self.get_builder_loc(builder))
if dtype_err:
if isinstance(dtype_err, ValidatorWarning):
warn(dtype_err, type(dtype_err))
else:
ret.append(dtype_err) # return only errors, not warnings

shape_err = self.check_data_shape(data, spec, self.get_spec_loc(spec),
builder_location=self.get_builder_loc(builder))
if shape_err:
ret.append(shape_err)
return ret


Expand Down
Loading