diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c9a3e6bd..2079969be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ - Fix CI testing on Python 3.9. @rly (#523) - Fix certain edge cases where `GroupValidator` would not validate all of the child groups or datasets attached to a `GroupBuilder`. @dsleiter (#526) +- Various fixes for dynamic class generation. @rly (#561) + - Fix generation of classes that extends both `MultiContainerInterface` and another class that extends `MultiContainerInterface`. @rly (#567) diff --git a/src/hdmf/build/classgenerator.py b/src/hdmf/build/classgenerator.py index cd4deb580..d775722f0 100644 --- a/src/hdmf/build/classgenerator.py +++ b/src/hdmf/build/classgenerator.py @@ -4,7 +4,7 @@ import numpy as np from ..container import Container, Data, DataRegion, MultiContainerInterface -from ..spec import AttributeSpec, LinkSpec, RefSpec +from ..spec import AttributeSpec, LinkSpec, RefSpec, GroupSpec from ..spec.spec import BaseStorageSpec, ZERO_OR_MANY, ONE_OR_MANY from ..utils import docval, getargs, ExtenderMeta, get_docval, fmt_docval_args @@ -50,6 +50,8 @@ def generate_class(self, **kwargs): if k == 'help': # pragma: no cover # (legacy) do not add field named 'help' to any part of class object continue + if isinstance(field_spec, GroupSpec) and field_spec.data_type is None: # skip named, untyped groups + continue if not spec.is_inherited_spec(field_spec): not_inherited_fields[k] = field_spec try: @@ -61,7 +63,7 @@ def generate_class(self, **kwargs): # each generator can update classdict and docval_args if class_generator.apply_generator_to_field(field_spec, bases, type_map): class_generator.process_field_spec(classdict, docval_args, parent_cls, attr_name, - not_inherited_fields, type_map) + not_inherited_fields, type_map, spec) break # each field_spec should be processed by only one generator for class_generator in self.__custom_generators: @@ -204,14 +206,15 @@ def apply_generator_to_field(cls, field_spec, bases, type_map): return True @classmethod - def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map): + def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec): """Add __fields__ to the classdict and update the docval args for the field spec with the given attribute name. :param classdict: The dict to update with __fields__. :param docval_args: The list of docval arguments. :param parent_cls: The parent class. :param attr_name: The attribute name of the field spec for the container class to generate. - :param spec: The spec for the container class to generate. + :param not_inherited_fields: Dictionary of fields not inherited from the parent class. :param type_map: The type map to use. + :param spec: The spec for the container class to generate. """ field_spec = not_inherited_fields[attr_name] dtype = cls._get_type(field_spec, type_map) @@ -231,10 +234,20 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i shape = getattr(field_spec, 'shape', None) if shape is not None: docval_arg['shape'] = shape - if not field_spec.required: + if cls._check_spec_optional(field_spec, spec): docval_arg['default'] = getattr(field_spec, 'default_value', None) cls._add_to_docval_args(docval_args, docval_arg) + @classmethod + def _check_spec_optional(cls, field_spec, spec): + """Returns True if the spec or any of its parents (up to the parent type spec) are optional.""" + if not field_spec.required: + return True + if field_spec == spec: + return False + if field_spec.parent is not None: + return cls._check_spec_optional(field_spec.parent, spec) + @classmethod def _add_to_docval_args(cls, docval_args, arg): """Add the docval arg to the list if not present. If present, overwrite it in place.""" @@ -301,14 +314,15 @@ def apply_generator_to_field(cls, field_spec, bases, type_map): return getattr(field_spec, 'quantity', None) in (ZERO_OR_MANY, ONE_OR_MANY) @classmethod - def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map): + def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec): """Add __clsconf__ to the classdict and update the docval args for the field spec with the given attribute name. :param classdict: The dict to update with __clsconf__. :param docval_args: The list of docval arguments. :param parent_cls: The parent class. :param attr_name: The attribute name of the field spec for the container class to generate. - :param spec: The spec for the container class to generate. + :param not_inherited_fields: Dictionary of fields not inherited from the parent class. :param type_map: The type map to use. + :param spec: The spec for the container class to generate. """ field_spec = not_inherited_fields[attr_name] field_clsconf = dict( @@ -326,7 +340,7 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i doc=field_spec.doc, type=(list, tuple, dict, cls._get_type(field_spec, type_map)) ) - if not field_spec.required: + if cls._check_spec_optional(field_spec, spec): docval_arg['default'] = getattr(field_spec, 'default_value', None) cls._add_to_docval_args(docval_args, docval_arg) diff --git a/src/hdmf/build/manager.py b/src/hdmf/build/manager.py index 44317b807..b60d848b9 100644 --- a/src/hdmf/build/manager.py +++ b/src/hdmf/build/manager.py @@ -5,7 +5,7 @@ from .builders import DatasetBuilder, GroupBuilder, LinkBuilder, Builder, BaseBuilder from .classgenerator import ClassGenerator, CustomClassGenerator, MCIClassGenerator from ..container import AbstractContainer, Container, Data -from ..spec import DatasetSpec, GroupSpec, NamespaceCatalog, SpecReader +from ..spec import DatasetSpec, GroupSpec, LinkSpec, NamespaceCatalog, SpecReader from ..spec.spec import BaseStorageSpec from ..utils import docval, getargs, call_docval_func, ExtenderMeta @@ -448,7 +448,8 @@ def merge(self, type_map, ns_catalog=False): self.register_container_type(namespace, data_type, container_cls) for container_cls in type_map.__mapper_cls: self.register_map(container_cls, type_map.__mapper_cls[container_cls]) - for custom_generators in type_map.__class_generator.custom_generators: + for custom_generators in reversed(type_map.__class_generator.custom_generators): + # iterate in reverse order because generators are stored internally as a stack self.register_generator(custom_generators) @docval({"name": "generator", "type": type, "doc": "the CustomClassGenerator class to register"}) @@ -511,19 +512,26 @@ def get_container_cls(self, **kwargs): return cls def __check_dependent_types(self, spec, namespace): - """Ensure that classes for all types used by this type exist and generate them if not. + """Ensure that classes for all types used by this type exist in this namespace and generate them if not. """ + def __check_dependent_types_helper(spec, namespace): + if isinstance(spec, (GroupSpec, DatasetSpec)): + if spec.data_type_inc is not None: + self.get_container_cls(spec.data_type_inc, namespace) # TODO handle recursive definitions + if spec.data_type_def is not None: + self.get_container_cls(spec.data_type_def, namespace) + elif isinstance(spec, LinkSpec): + if spec.target_type is not None: + self.get_container_cls(spec.target_type, namespace) + if isinstance(spec, GroupSpec): + for child_spec in (spec.groups + spec.datasets + spec.links): + __check_dependent_types_helper(child_spec, namespace) + if spec.data_type_inc is not None: self.get_container_cls(spec.data_type_inc, namespace) if isinstance(spec, GroupSpec): - for child_spec in (spec.groups + spec.datasets): - if child_spec.data_type_inc is not None: - self.get_container_cls(child_spec.data_type_inc, namespace) - if child_spec.data_type_def is not None: - self.get_container_cls(child_spec.data_type_def, namespace) - for child_spec in spec.links: - if child_spec.target_type is not None: - self.get_container_cls(child_spec.target_type, namespace) + for child_spec in (spec.groups + spec.datasets + spec.links): + __check_dependent_types_helper(child_spec, namespace) def __get_parent_cls(self, namespace, data_type, spec): dt_hier = self.__ns_catalog.get_hierarchy(namespace, data_type) diff --git a/src/hdmf/common/io/table.py b/src/hdmf/common/io/table.py index 4991c47f2..4259f66dc 100644 --- a/src/hdmf/common/io/table.py +++ b/src/hdmf/common/io/table.py @@ -50,18 +50,24 @@ class DynamicTableGenerator(CustomClassGenerator): @classmethod def apply_generator_to_field(cls, field_spec, bases, type_map): """Return True if this is a DynamicTable and the field spec is a column.""" + for b in bases: + if issubclass(b, DynamicTable): + break + else: # return False if no base is a subclass of DynamicTable + return False dtype = cls._get_type(field_spec, type_map) - return DynamicTable in bases and issubclass(dtype, VectorData) + return isinstance(dtype, type) and issubclass(dtype, VectorData) @classmethod - def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map): + def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec): """Add __columns__ to the classdict and update the docval args for the field spec with the given attribute name. :param classdict: The dict to update with __columns__. :param docval_args: The list of docval arguments. :param parent_cls: The parent class. :param attr_name: The attribute name of the field spec for the container class to generate. - :param spec: The spec for the container class to generate. + :param not_inherited_fields: Dictionary of fields not inherited from the parent class. :param type_map: The type map to use. + :param spec: The spec for the container class to generate. """ if attr_name.endswith('_index'): # do not add index columns to __columns__ return @@ -116,14 +122,16 @@ def post_process(cls, classdict, bases, docval_args, spec): @classmethod def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): - base_init = classdict['__init__'] + base_init = classdict.get('__init__') + if base_init is None: # pragma: no cover + raise ValueError("Generated class dictionary is missing base __init__ method.") @docval(*docval_args) def __init__(self, **kwargs): base_init(self, **kwargs) # set target attribute on DTR - target_tables = kwargs['target_tables'] + target_tables = kwargs.get('target_tables') if target_tables: for colname, table in target_tables.items(): if colname not in self: # column has not yet been added (it is optional) diff --git a/src/hdmf/spec/catalog.py b/src/hdmf/spec/catalog.py index 9f312cb91..e623aae51 100644 --- a/src/hdmf/spec/catalog.py +++ b/src/hdmf/spec/catalog.py @@ -42,7 +42,8 @@ def register_spec(self, **kwargs): self.__parent_types[ndt_def] = ndt type_name = ndt_def if ndt_def is not None else ndt if type_name in self.__specs: - raise ValueError("'%s' - cannot overwrite existing specification" % type_name) + if self.__specs[type_name] != spec or self.__spec_source_files[type_name] != source_file: + raise ValueError("'%s' - cannot overwrite existing specification" % type_name) self.__specs[type_name] = spec self.__spec_source_files[type_name] = source_file diff --git a/tests/unit/build_tests/test_classgenerator.py b/tests/unit/build_tests/test_classgenerator.py index d25f3b157..4ba8c6af5 100644 --- a/tests/unit/build_tests/test_classgenerator.py +++ b/tests/unit/build_tests/test_classgenerator.py @@ -26,7 +26,8 @@ def apply_generator_to_field(cls, field_spec, bases, type_map): return True @classmethod - def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map): + def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, + spec): # append attr_name to classdict['__custom_fields__'] list classdict.setdefault('process_field_spec', list()).append(attr_name) @@ -548,7 +549,8 @@ def test_update_docval(self): parent_cls=EmptyBar, # <-- arbitrary class attr_name=attr_name, not_inherited_fields=not_inherited_fields, - type_map=self.type_map + type_map=self.type_map, + spec=spec ) self.assertListEqual(docval_args, expected[:(i+1)]) # compare with the first i elements of expected @@ -570,7 +572,8 @@ def test_update_docval_attr_shape(self): parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr1', not_inherited_fields=not_inherited_fields, - type_map=TypeMap() + type_map=TypeMap(), + spec=spec ) expected = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [None]}] @@ -594,7 +597,8 @@ def test_update_docval_dset_shape(self): parent_cls=EmptyBar, # <-- arbitrary class attr_name='dset1', not_inherited_fields=not_inherited_fields, - type_map=TypeMap() + type_map=TypeMap(), + spec=spec ) expected = [{'name': 'dset1', 'type': ('array_data', 'data'), 'doc': 'a string dataset', 'shape': [None]}] @@ -619,7 +623,8 @@ def test_update_docval_default_value(self): parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr1', not_inherited_fields=not_inherited_fields, - type_map=TypeMap() + type_map=TypeMap(), + spec=spec ) expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': 'value'}] @@ -643,7 +648,71 @@ def test_update_docval_default_value_none(self): parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr1', not_inherited_fields=not_inherited_fields, - type_map=TypeMap() + type_map=TypeMap(), + spec=spec + ) + + expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': None}] + self.assertListEqual(docval_args, expected) + + def test_update_docval_default_value_none_required_parent(self): + """Test that update_docval_args for an optional field with a required parent sets default: None.""" + spec = GroupSpec( + doc='A test group specification with a data type', + data_type_def='Baz', + groups=[ + GroupSpec( + name='group1', + doc='required untyped group', + attributes=[ + AttributeSpec(name='attr1', doc='a string attribute', dtype='text', required=False) + ] + ) + ] + ) + not_inherited_fields = {'attr1': spec.get_group('group1').get_attribute('attr1')} + + docval_args = list() + CustomClassGenerator.process_field_spec( + classdict={}, + docval_args=docval_args, + parent_cls=EmptyBar, # <-- arbitrary class + attr_name='attr1', + not_inherited_fields=not_inherited_fields, + type_map=TypeMap(), + spec=spec + ) + + expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': None}] + self.assertListEqual(docval_args, expected) + + def test_update_docval_required_field_optional_parent(self): + """Test that update_docval_args for a required field with an optional parent sets default: None.""" + spec = GroupSpec( + doc='A test group specification with a data type', + data_type_def='Baz', + groups=[ + GroupSpec( + name='group1', + doc='required untyped group', + attributes=[ + AttributeSpec(name='attr1', doc='a string attribute', dtype='text') + ], + quantity='?' + ) + ] + ) + not_inherited_fields = {'attr1': spec.get_group('group1').get_attribute('attr1')} + + docval_args = list() + CustomClassGenerator.process_field_spec( + classdict={}, + docval_args=docval_args, + parent_cls=EmptyBar, # <-- arbitrary class + attr_name='attr1', + not_inherited_fields=not_inherited_fields, + type_map=TypeMap(), + spec=spec ) expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': None}] @@ -670,7 +739,8 @@ def test_process_field_spec_overwrite(self): parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr1', not_inherited_fields=not_inherited_fields, - type_map=TypeMap() + type_map=TypeMap(), + spec=spec ) expected = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', @@ -689,7 +759,8 @@ def test_process_field_spec_link(self): parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr3', not_inherited_fields=not_inherited_fields, - type_map=self.type_map + type_map=self.type_map, + spec=GroupSpec('dummy', 'doc') ) expected = {'__fields__': [{'name': 'attr3', 'doc': 'a link'}]} @@ -773,7 +844,8 @@ def test_update_docval(self): parent_cls=Container, attr_name='empty_bars', not_inherited_fields=not_inherited_fields, - type_map=self.type_map + type_map=self.type_map, + spec=spec ) expected = [ @@ -798,7 +870,8 @@ def test_update_init_zero_or_more(self): parent_cls=Container, attr_name='empty_bars', not_inherited_fields=not_inherited_fields, - type_map=self.type_map + type_map=self.type_map, + spec=spec ) expected = [{'name': 'empty_bars', 'type': (list, tuple, dict, EmptyBar), 'doc': 'test multi', 'default': None}] @@ -815,7 +888,8 @@ def test_update_init_one_or_more(self): parent_cls=Container, attr_name='empty_bars', not_inherited_fields=not_inherited_fields, - type_map=self.type_map + type_map=self.type_map, + spec=spec ) expected = [{'name': 'empty_bars', 'type': (list, tuple, dict, EmptyBar), 'doc': 'test multi'}] diff --git a/tests/unit/spec_tests/test_spec_catalog.py b/tests/unit/spec_tests/test_spec_catalog.py index 7436fe060..d12f352e8 100644 --- a/tests/unit/spec_tests/test_spec_catalog.py +++ b/tests/unit/spec_tests/test_spec_catalog.py @@ -180,3 +180,50 @@ def test_deepcopy_spec_catalog(self): re = copy.deepcopy(self.catalog) self.assertTupleEqual(self.catalog.get_registered_types(), re.get_registered_types()) + + def test_catch_duplicate_spec_nested(self): + spec1 = GroupSpec( + data_type_def='Group1', + doc='This is my new group 1', + ) + spec2 = GroupSpec( + data_type_def='Group2', + doc='This is my new group 2', + groups=[spec1], # nested definition + ) + source = 'test_extension.yaml' + self.catalog.register_spec(spec1, source) + self.catalog.register_spec(spec2, source) # this is OK because Group1 is the same spec + ret = self.catalog.get_spec('Group1') + self.assertIs(ret, spec1) + + def test_catch_duplicate_spec_different(self): + spec1 = GroupSpec( + data_type_def='Group1', + doc='This is my new group 1', + ) + spec2 = GroupSpec( + data_type_def='Group1', + doc='This is my other group 1', + ) + source = 'test_extension.yaml' + self.catalog.register_spec(spec1, source) + msg = "'Group1' - cannot overwrite existing specification" + with self.assertRaisesWith(ValueError, msg): + self.catalog.register_spec(spec2, source) + + def test_catch_duplicate_spec_different_source(self): + spec1 = GroupSpec( + data_type_def='Group1', + doc='This is my new group 1', + ) + spec2 = GroupSpec( + data_type_def='Group1', + doc='This is my new group 1', + ) + source1 = 'test_extension1.yaml' + source2 = 'test_extension2.yaml' + self.catalog.register_spec(spec1, source1) + msg = "'Group1' - cannot overwrite existing specification" + with self.assertRaisesWith(ValueError, msg): + self.catalog.register_spec(spec2, source2)