diff --git a/src/hdmf/spec/spec.py b/src/hdmf/spec/spec.py index 359a01736..fa5e76afd 100644 --- a/src/hdmf/spec/spec.py +++ b/src/hdmf/spec/spec.py @@ -376,6 +376,24 @@ def resolve_spec(self, **kwargs): for attribute in inc_spec.attributes: self.__new_attributes.discard(attribute) if attribute.name in self.__attributes: + # copy over the fields 'shape' and 'dims' of the parent attribute to this spec's attribute + # NOTE: because the default value for the 'required' field is True, it is not possible to know whether + # the 'required' field was explicitly set. thus, 'required' defaults to True no matter what the parent + # specification defines. + # NOTE: the value and default value of a parent attribute are also not copied to the child spec + my_attribute = self.get_attribute(attribute.name) + if attribute.shape is not None: + if my_attribute.shape is None: + my_attribute['shape'] = attribute.shape + else: + # TODO: test whether child shape is compatible with parent shape + pass + if attribute.dims is not None: + if my_attribute.dims is None: + my_attribute['dims'] = attribute.dims + else: + # TODO: test whether child dims is compatible with parent dims + pass self.__overridden_attributes.add(attribute.name) continue self.set_attribute(attribute) @@ -695,6 +713,30 @@ def __is_sub_dtype(cls, orig, new): @docval({'name': 'inc_spec', 'type': 'DatasetSpec', 'doc': 'the data type this specification represents'}) def resolve_spec(self, **kwargs): inc_spec = getargs('inc_spec', kwargs) + # copy over fields of the parent dataset to this spec + # NOTE: because the default value for the 'quantity' field is 1, it is not possible to know whether + # the 'quantity' field was explicitly set. thus, 'quantity' defaults to 1 no matter what the parent + # specification defines. 
+        # NOTE: the default value of a parent dataset is also not copied to the child spec
+        if inc_spec.dtype is not None:
+            if self.dtype is None:
+                self['dtype'] = inc_spec.dtype
+            else:
+                # TODO: test whether child dtype is compatible with parent dtype
+                # e.g., if parent dtype is int, child dtype cannot be text
+                pass
+        if inc_spec.shape is not None:
+            if self.shape is None:
+                self['shape'] = inc_spec.shape
+            else:
+                # TODO: test whether child shape is compatible with parent shape
+                pass
+        if inc_spec.dims is not None:
+            if self.dims is None:
+                self['dims'] = inc_spec.dims
+            else:
+                # TODO: test whether child dims is compatible with parent dims
+                pass
         if isinstance(self.dtype, list):
             # merge the new types
             inc_dtype = inc_spec.dtype
diff --git a/tests/unit/spec_tests/test_load_namespace.py b/tests/unit/spec_tests/test_load_namespace.py
index 25c09f1a6..9dc24a808 100644
--- a/tests/unit/spec_tests/test_load_namespace.py
+++ b/tests/unit/spec_tests/test_load_namespace.py
@@ -223,3 +223,192 @@ def test_get_namespace_missing_version(self):
         namespace['version'] = None # work around lack of setter to remove version key
         self.assertEqual(namespace.version, SpecNamespace.UNVERSIONED)
+
+
+class TestLoadSpecInheritProperties(TestCase):
+    NS_NAME = 'test_ns'
+
+    def setUp(self):
+        self.specs_path = 'test_load_spec_inherit.specs.yaml'
+        self.namespace_path = 'test_load_spec_inherit.namespace.yaml'
+
+        ns_dict = {
+            'doc': 'a test namespace',
+            'name': self.NS_NAME,
+            'schema': [
+                {'source': self.specs_path}
+            ],
+            'version': '0.1.0'
+        }
+        self.namespace = SpecNamespace.build_namespace(**ns_dict)
+        to_dump = {'namespaces': [self.namespace]}
+        with open(self.namespace_path, 'w') as tmp:
+            yaml.safe_dump(json.loads(json.dumps(to_dump)), tmp, default_flow_style=False)
+        self.ns_catalog = NamespaceCatalog()
+
+    def tearDown(self):
+        if os.path.exists(self.namespace_path):
+            os.remove(self.namespace_path)
+        if os.path.exists(self.specs_path):
+            os.remove(self.specs_path)
+
+    def test_inherited_shape(self):
+        """Test that an extended dataset/attribute inherits the dims and shape of the original dataset/attribute"""
+        attributes = [
+            AttributeSpec(
+                'attribute1',
+                'my first attribute',
+                'text',
+                dims=['my_dims'],
+                shape=[None]
+            )
+        ]
+        datasets = [
+            DatasetSpec(
+                'my first dataset',
+                'int',
+                name='dataset1',
+                dims=['my_dims'],
+                shape=[None]
+            )
+        ]
+        group_spec = GroupSpec(
+            'A test group',
+            name='root_constructor_datatype',
+            datasets=datasets,
+            attributes=attributes,
+            data_type_def='MyGroup'
+        )
+
+        # same name, same dtype, change docstring, omit dims and shape specifications
+        ext_attributes = [
+            AttributeSpec(
+                'attribute1',
+                'my first attribute extension',
+                'text',
+            )
+        ]
+        # same name, same dtype, change docstring, omit dims and shape specifications
+        ext_datasets = [
+            DatasetSpec(
+                'my first dataset extension',
+                'int',
+                name='dataset1'
+            )
+        ]
+        ext_group_spec = GroupSpec(
+            'A test group extension',
+            name='root_constructor_datatype',
+            datasets=ext_datasets,
+            attributes=ext_attributes,
+            data_type_inc='MyGroup',
+            data_type_def='MyExtGroup'
+        )
+        to_dump = {'groups': [group_spec, ext_group_spec]}
+        with open(self.specs_path, 'w') as tmp:
+            yaml.safe_dump(json.loads(json.dumps(to_dump)), tmp, default_flow_style=False)
+
+        self.ns_catalog.load_namespaces(self.namespace_path)
+
+        # first test that shape of the original dataset/attribute is set correctly
+        read_group_spec = self.ns_catalog.get_spec(self.NS_NAME, 'MyGroup')
+        read_dset_spec = read_group_spec.datasets[0]
+        read_attr_spec = read_group_spec.attributes[0]
+        self.assertEqual(read_dset_spec.dims, ['my_dims'])
+        self.assertEqual(read_attr_spec.dims, ['my_dims'])
+        self.assertEqual(read_dset_spec.shape, [None])
+        self.assertEqual(read_attr_spec.shape, [None])
+
+        # then test that shape of the extended dataset/attribute is set correctly
+        read_group_spec = self.ns_catalog.get_spec(self.NS_NAME, 'MyExtGroup')
+        read_dset_spec = read_group_spec.datasets[0]
+        read_attr_spec = read_group_spec.attributes[0]
+        self.assertEqual(read_dset_spec.dims, ['my_dims'])
+        self.assertEqual(read_attr_spec.dims, ['my_dims'])
+        self.assertEqual(read_dset_spec.shape, [None])
+        self.assertEqual(read_attr_spec.shape, [None])
+
+    def test_inherited_dtype(self):
+        """Test that an extended dataset inherits the dtype of the original dataset"""
+        # NOTE: we do not need to test whether dtype of attributes are extended because attribute dtype is required
+        datasets = [
+            DatasetSpec(
+                'my first dataset',
+                'int',
+                name='dataset1',
+            )
+        ]
+        group_spec = GroupSpec(
+            'A test group',
+            name='root_constructor_datatype',
+            datasets=datasets,
+            data_type_def='MyGroup'
+        )
+
+        # same name, omit dtype, change docstring
+        ext_datasets = [
+            DatasetSpec(
+                'my first dataset extension',
+                name='dataset1'
+            )
+        ]
+        ext_group_spec = GroupSpec(
+            'A test group extension',
+            name='root_constructor_datatype',
+            datasets=ext_datasets,
+            data_type_inc='MyGroup',
+            data_type_def='MyExtGroup'
+        )
+        to_dump = {'groups': [group_spec, ext_group_spec]}
+        with open(self.specs_path, 'w') as tmp:
+            yaml.safe_dump(json.loads(json.dumps(to_dump)), tmp, default_flow_style=False)
+
+        self.ns_catalog.load_namespaces(self.namespace_path)
+
+        # first test that dtype of the original dataset is set correctly
+        read_group_spec = self.ns_catalog.get_spec(self.NS_NAME, 'MyGroup')
+        read_dset_spec = read_group_spec.datasets[0]
+        self.assertEqual(read_dset_spec.dtype, 'int')
+
+        # then test that dtype of the extended dataset is set correctly
+        read_group_spec = self.ns_catalog.get_spec(self.NS_NAME, 'MyExtGroup')
+        read_dset_spec = read_group_spec.datasets[0]
+        self.assertEqual(read_dset_spec.dtype, 'int')
+
+    def test_inherited_incompatible_dtype(self):
+        """Test that an extended dataset cannot change the dtype of the original dataset to an incompatible dtype"""
+        datasets = [
+            DatasetSpec(
+                'my first dataset',
+                'int',
+                name='dataset1',
+            )
+        ]
+        group_spec = GroupSpec(
+            'A test group',
+            name='root_constructor_datatype',
+            datasets=datasets,
+            data_type_def='MyGroup'
+        )
+
+        # same name, change dtype, change docstring
+        ext_datasets = [
+            DatasetSpec(
+                'my first dataset extension',
+                'text',
+                name='dataset1'
+            )
+        ]
+        ext_group_spec = GroupSpec(
+            'A test group extension',
+            name='root_constructor_datatype',
+            datasets=ext_datasets,
+            data_type_inc='MyGroup',
+            data_type_def='MyExtGroup'
+        )
+        to_dump = {'groups': [group_spec, ext_group_spec]}
+        with open(self.specs_path, 'w') as tmp:
+            yaml.safe_dump(json.loads(json.dumps(to_dump)), tmp, default_flow_style=False)
+
+        # TODO this test should fail but currently succeeds
+        self.ns_catalog.load_namespaces(self.namespace_path)