Skip to content

Commit

Permalink
Merge branch 'dev' into add/aligned_dynamic_table
Browse files Browse the repository at this point in the history
  • Loading branch information
rly authored Apr 13, 2021
2 parents 88d36f3 + 4e32a9f commit 09ce741
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 36 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
- Fix CI testing on Python 3.9. @rly (#523)
- Fix certain edge cases where `GroupValidator` would not validate all of the child groups or datasets
attached to a `GroupBuilder`. @dsleiter (#526)
- Various fixes for dynamic class generation. @rly (#561)

- Fix generation of classes that extends both `MultiContainerInterface` and another class that extends
`MultiContainerInterface`. @rly (#567)

Expand Down
30 changes: 22 additions & 8 deletions src/hdmf/build/classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from ..container import Container, Data, DataRegion, MultiContainerInterface
from ..spec import AttributeSpec, LinkSpec, RefSpec
from ..spec import AttributeSpec, LinkSpec, RefSpec, GroupSpec
from ..spec.spec import BaseStorageSpec, ZERO_OR_MANY, ONE_OR_MANY
from ..utils import docval, getargs, ExtenderMeta, get_docval, fmt_docval_args

Expand Down Expand Up @@ -50,6 +50,8 @@ def generate_class(self, **kwargs):
if k == 'help': # pragma: no cover
# (legacy) do not add field named 'help' to any part of class object
continue
if isinstance(field_spec, GroupSpec) and field_spec.data_type is None: # skip named, untyped groups
continue
if not spec.is_inherited_spec(field_spec):
not_inherited_fields[k] = field_spec
try:
Expand All @@ -61,7 +63,7 @@ def generate_class(self, **kwargs):
# each generator can update classdict and docval_args
if class_generator.apply_generator_to_field(field_spec, bases, type_map):
class_generator.process_field_spec(classdict, docval_args, parent_cls, attr_name,
not_inherited_fields, type_map)
not_inherited_fields, type_map, spec)
break # each field_spec should be processed by only one generator

for class_generator in self.__custom_generators:
Expand Down Expand Up @@ -204,14 +206,15 @@ def apply_generator_to_field(cls, field_spec, bases, type_map):
return True

@classmethod
def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map):
def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec):
"""Add __fields__ to the classdict and update the docval args for the field spec with the given attribute name.
:param classdict: The dict to update with __fields__.
:param docval_args: The list of docval arguments.
:param parent_cls: The parent class.
:param attr_name: The attribute name of the field spec for the container class to generate.
:param spec: The spec for the container class to generate.
:param not_inherited_fields: Dictionary of fields not inherited from the parent class.
:param type_map: The type map to use.
:param spec: The spec for the container class to generate.
"""
field_spec = not_inherited_fields[attr_name]
dtype = cls._get_type(field_spec, type_map)
Expand All @@ -231,10 +234,20 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
shape = getattr(field_spec, 'shape', None)
if shape is not None:
docval_arg['shape'] = shape
if not field_spec.required:
if cls._check_spec_optional(field_spec, spec):
docval_arg['default'] = getattr(field_spec, 'default_value', None)
cls._add_to_docval_args(docval_args, docval_arg)

@classmethod
def _check_spec_optional(cls, field_spec, spec):
"""Returns True if the spec or any of its parents (up to the parent type spec) are optional."""
if not field_spec.required:
return True
if field_spec == spec:
return False
if field_spec.parent is not None:
return cls._check_spec_optional(field_spec.parent, spec)

@classmethod
def _add_to_docval_args(cls, docval_args, arg):
"""Add the docval arg to the list if not present. If present, overwrite it in place."""
Expand Down Expand Up @@ -301,14 +314,15 @@ def apply_generator_to_field(cls, field_spec, bases, type_map):
return getattr(field_spec, 'quantity', None) in (ZERO_OR_MANY, ONE_OR_MANY)

@classmethod
def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map):
def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec):
"""Add __clsconf__ to the classdict and update the docval args for the field spec with the given attribute name.
:param classdict: The dict to update with __clsconf__.
:param docval_args: The list of docval arguments.
:param parent_cls: The parent class.
:param attr_name: The attribute name of the field spec for the container class to generate.
:param spec: The spec for the container class to generate.
:param not_inherited_fields: Dictionary of fields not inherited from the parent class.
:param type_map: The type map to use.
:param spec: The spec for the container class to generate.
"""
field_spec = not_inherited_fields[attr_name]
field_clsconf = dict(
Expand All @@ -326,7 +340,7 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
doc=field_spec.doc,
type=(list, tuple, dict, cls._get_type(field_spec, type_map))
)
if not field_spec.required:
if cls._check_spec_optional(field_spec, spec):
docval_arg['default'] = getattr(field_spec, 'default_value', None)
cls._add_to_docval_args(docval_args, docval_arg)

Expand Down
30 changes: 19 additions & 11 deletions src/hdmf/build/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .builders import DatasetBuilder, GroupBuilder, LinkBuilder, Builder, BaseBuilder
from .classgenerator import ClassGenerator, CustomClassGenerator, MCIClassGenerator
from ..container import AbstractContainer, Container, Data
from ..spec import DatasetSpec, GroupSpec, NamespaceCatalog, SpecReader
from ..spec import DatasetSpec, GroupSpec, LinkSpec, NamespaceCatalog, SpecReader
from ..spec.spec import BaseStorageSpec
from ..utils import docval, getargs, call_docval_func, ExtenderMeta

Expand Down Expand Up @@ -448,7 +448,8 @@ def merge(self, type_map, ns_catalog=False):
self.register_container_type(namespace, data_type, container_cls)
for container_cls in type_map.__mapper_cls:
self.register_map(container_cls, type_map.__mapper_cls[container_cls])
for custom_generators in type_map.__class_generator.custom_generators:
for custom_generators in reversed(type_map.__class_generator.custom_generators):
# iterate in reverse order because generators are stored internally as a stack
self.register_generator(custom_generators)

@docval({"name": "generator", "type": type, "doc": "the CustomClassGenerator class to register"})
Expand Down Expand Up @@ -511,19 +512,26 @@ def get_container_cls(self, **kwargs):
return cls

def __check_dependent_types(self, spec, namespace):
"""Ensure that classes for all types used by this type exist and generate them if not.
"""Ensure that classes for all types used by this type exist in this namespace and generate them if not.
"""
def __check_dependent_types_helper(spec, namespace):
if isinstance(spec, (GroupSpec, DatasetSpec)):
if spec.data_type_inc is not None:
self.get_container_cls(spec.data_type_inc, namespace) # TODO handle recursive definitions
if spec.data_type_def is not None:
self.get_container_cls(spec.data_type_def, namespace)
elif isinstance(spec, LinkSpec):
if spec.target_type is not None:
self.get_container_cls(spec.target_type, namespace)
if isinstance(spec, GroupSpec):
for child_spec in (spec.groups + spec.datasets + spec.links):
__check_dependent_types_helper(child_spec, namespace)

if spec.data_type_inc is not None:
self.get_container_cls(spec.data_type_inc, namespace)
if isinstance(spec, GroupSpec):
for child_spec in (spec.groups + spec.datasets):
if child_spec.data_type_inc is not None:
self.get_container_cls(child_spec.data_type_inc, namespace)
if child_spec.data_type_def is not None:
self.get_container_cls(child_spec.data_type_def, namespace)
for child_spec in spec.links:
if child_spec.target_type is not None:
self.get_container_cls(child_spec.target_type, namespace)
for child_spec in (spec.groups + spec.datasets + spec.links):
__check_dependent_types_helper(child_spec, namespace)

def __get_parent_cls(self, namespace, data_type, spec):
dt_hier = self.__ns_catalog.get_hierarchy(namespace, data_type)
Expand Down
18 changes: 13 additions & 5 deletions src/hdmf/common/io/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,24 @@ class DynamicTableGenerator(CustomClassGenerator):
@classmethod
def apply_generator_to_field(cls, field_spec, bases, type_map):
"""Return True if this is a DynamicTable and the field spec is a column."""
for b in bases:
if issubclass(b, DynamicTable):
break
else: # return False if no base is a subclass of DynamicTable
return False
dtype = cls._get_type(field_spec, type_map)
return DynamicTable in bases and issubclass(dtype, VectorData)
return isinstance(dtype, type) and issubclass(dtype, VectorData)

@classmethod
def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map):
def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec):
"""Add __columns__ to the classdict and update the docval args for the field spec with the given attribute name.
:param classdict: The dict to update with __columns__.
:param docval_args: The list of docval arguments.
:param parent_cls: The parent class.
:param attr_name: The attribute name of the field spec for the container class to generate.
:param spec: The spec for the container class to generate.
:param not_inherited_fields: Dictionary of fields not inherited from the parent class.
:param type_map: The type map to use.
:param spec: The spec for the container class to generate.
"""
if attr_name.endswith('_index'): # do not add index columns to __columns__
return
Expand Down Expand Up @@ -116,14 +122,16 @@ def post_process(cls, classdict, bases, docval_args, spec):

@classmethod
def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name):
base_init = classdict['__init__']
base_init = classdict.get('__init__')
if base_init is None: # pragma: no cover
raise ValueError("Generated class dictionary is missing base __init__ method.")

@docval(*docval_args)
def __init__(self, **kwargs):
base_init(self, **kwargs)

# set target attribute on DTR
target_tables = kwargs['target_tables']
target_tables = kwargs.get('target_tables')
if target_tables:
for colname, table in target_tables.items():
if colname not in self: # column has not yet been added (it is optional)
Expand Down
3 changes: 2 additions & 1 deletion src/hdmf/spec/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def register_spec(self, **kwargs):
self.__parent_types[ndt_def] = ndt
type_name = ndt_def if ndt_def is not None else ndt
if type_name in self.__specs:
raise ValueError("'%s' - cannot overwrite existing specification" % type_name)
if self.__specs[type_name] != spec or self.__spec_source_files[type_name] != source_file:
raise ValueError("'%s' - cannot overwrite existing specification" % type_name)
self.__specs[type_name] = spec
self.__spec_source_files[type_name] = source_file

Expand Down
96 changes: 85 additions & 11 deletions tests/unit/build_tests/test_classgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def apply_generator_to_field(cls, field_spec, bases, type_map):
return True

@classmethod
def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map):
def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map,
spec):
# append attr_name to classdict['__custom_fields__'] list
classdict.setdefault('process_field_spec', list()).append(attr_name)

Expand Down Expand Up @@ -548,7 +549,8 @@ def test_update_docval(self):
parent_cls=EmptyBar, # <-- arbitrary class
attr_name=attr_name,
not_inherited_fields=not_inherited_fields,
type_map=self.type_map
type_map=self.type_map,
spec=spec
)
self.assertListEqual(docval_args, expected[:(i+1)]) # compare with the first i elements of expected

Expand All @@ -570,7 +572,8 @@ def test_update_docval_attr_shape(self):
parent_cls=EmptyBar, # <-- arbitrary class
attr_name='attr1',
not_inherited_fields=not_inherited_fields,
type_map=TypeMap()
type_map=TypeMap(),
spec=spec
)

expected = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [None]}]
Expand All @@ -594,7 +597,8 @@ def test_update_docval_dset_shape(self):
parent_cls=EmptyBar, # <-- arbitrary class
attr_name='dset1',
not_inherited_fields=not_inherited_fields,
type_map=TypeMap()
type_map=TypeMap(),
spec=spec
)

expected = [{'name': 'dset1', 'type': ('array_data', 'data'), 'doc': 'a string dataset', 'shape': [None]}]
Expand All @@ -619,7 +623,8 @@ def test_update_docval_default_value(self):
parent_cls=EmptyBar, # <-- arbitrary class
attr_name='attr1',
not_inherited_fields=not_inherited_fields,
type_map=TypeMap()
type_map=TypeMap(),
spec=spec
)

expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': 'value'}]
Expand All @@ -643,7 +648,71 @@ def test_update_docval_default_value_none(self):
parent_cls=EmptyBar, # <-- arbitrary class
attr_name='attr1',
not_inherited_fields=not_inherited_fields,
type_map=TypeMap()
type_map=TypeMap(),
spec=spec
)

expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': None}]
self.assertListEqual(docval_args, expected)

def test_update_docval_default_value_none_required_parent(self):
"""Test that update_docval_args for an optional field with a required parent sets default: None."""
spec = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Baz',
groups=[
GroupSpec(
name='group1',
doc='required untyped group',
attributes=[
AttributeSpec(name='attr1', doc='a string attribute', dtype='text', required=False)
]
)
]
)
not_inherited_fields = {'attr1': spec.get_group('group1').get_attribute('attr1')}

docval_args = list()
CustomClassGenerator.process_field_spec(
classdict={},
docval_args=docval_args,
parent_cls=EmptyBar, # <-- arbitrary class
attr_name='attr1',
not_inherited_fields=not_inherited_fields,
type_map=TypeMap(),
spec=spec
)

expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': None}]
self.assertListEqual(docval_args, expected)

def test_update_docval_required_field_optional_parent(self):
"""Test that update_docval_args for a required field with an optional parent sets default: None."""
spec = GroupSpec(
doc='A test group specification with a data type',
data_type_def='Baz',
groups=[
GroupSpec(
name='group1',
doc='required untyped group',
attributes=[
AttributeSpec(name='attr1', doc='a string attribute', dtype='text')
],
quantity='?'
)
]
)
not_inherited_fields = {'attr1': spec.get_group('group1').get_attribute('attr1')}

docval_args = list()
CustomClassGenerator.process_field_spec(
classdict={},
docval_args=docval_args,
parent_cls=EmptyBar, # <-- arbitrary class
attr_name='attr1',
not_inherited_fields=not_inherited_fields,
type_map=TypeMap(),
spec=spec
)

expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': None}]
Expand All @@ -670,7 +739,8 @@ def test_process_field_spec_overwrite(self):
parent_cls=EmptyBar, # <-- arbitrary class
attr_name='attr1',
not_inherited_fields=not_inherited_fields,
type_map=TypeMap()
type_map=TypeMap(),
spec=spec
)

expected = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute',
Expand All @@ -689,7 +759,8 @@ def test_process_field_spec_link(self):
parent_cls=EmptyBar, # <-- arbitrary class
attr_name='attr3',
not_inherited_fields=not_inherited_fields,
type_map=self.type_map
type_map=self.type_map,
spec=GroupSpec('dummy', 'doc')
)

expected = {'__fields__': [{'name': 'attr3', 'doc': 'a link'}]}
Expand Down Expand Up @@ -773,7 +844,8 @@ def test_update_docval(self):
parent_cls=Container,
attr_name='empty_bars',
not_inherited_fields=not_inherited_fields,
type_map=self.type_map
type_map=self.type_map,
spec=spec
)

expected = [
Expand All @@ -798,7 +870,8 @@ def test_update_init_zero_or_more(self):
parent_cls=Container,
attr_name='empty_bars',
not_inherited_fields=not_inherited_fields,
type_map=self.type_map
type_map=self.type_map,
spec=spec
)

expected = [{'name': 'empty_bars', 'type': (list, tuple, dict, EmptyBar), 'doc': 'test multi', 'default': None}]
Expand All @@ -815,7 +888,8 @@ def test_update_init_one_or_more(self):
parent_cls=Container,
attr_name='empty_bars',
not_inherited_fields=not_inherited_fields,
type_map=self.type_map
type_map=self.type_map,
spec=spec
)

expected = [{'name': 'empty_bars', 'type': (list, tuple, dict, EmptyBar), 'doc': 'test multi'}]
Expand Down
Loading

0 comments on commit 09ce741

Please sign in to comment.