diff --git a/CHANGELOG.md b/CHANGELOG.md index d3f762419..76a32b02c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # PyNWB Changelog +## PyNWB 2.6.1 (Upcoming) + +### Bug fixes +- Fix bug where extra keyword arguments could not be passed to `NWBFile.add_{x}_column`` for use in custom `VectorData`` classes. @rly [#1861](https://github.com/NeurodataWithoutBorders/pynwb/pull/1861) + ## PyNWB 2.6.0 (February 21, 2024) ### Enhancements and minor changes diff --git a/src/pynwb/file.py b/src/pynwb/file.py index 8bd7ff447..a8e9dd1b6 100644 --- a/src/pynwb/file.py +++ b/src/pynwb/file.py @@ -608,7 +608,7 @@ def __check_epochs(self): if self.epochs is None: self.epochs = TimeIntervals(name='epochs', description='experimental epochs') - @docval(*get_docval(TimeIntervals.add_column)) + @docval(*get_docval(TimeIntervals.add_column), allow_extra=True) def add_epoch_column(self, **kwargs): """ Add a column to the epoch table. @@ -645,7 +645,7 @@ def __check_electrodes(self): if self.electrodes is None: self.electrodes = ElectrodeTable() - @docval(*get_docval(DynamicTable.add_column)) + @docval(*get_docval(DynamicTable.add_column), allow_extra=True) def add_electrode_column(self, **kwargs): """ Add a column to the electrode table. @@ -747,7 +747,7 @@ def __check_units(self): if self.units is None: self.units = Units(name='units', description='Autogenerated by NWBFile') - @docval(*get_docval(Units.add_column)) + @docval(*get_docval(Units.add_column), allow_extra=True) def add_unit_column(self, **kwargs): """ Add a column to the unit table. @@ -770,7 +770,7 @@ def __check_trials(self): if self.trials is None: self.trials = TimeIntervals(name='trials', description='experimental trials') - @docval(*get_docval(DynamicTable.add_column)) + @docval(*get_docval(DynamicTable.add_column), allow_extra=True) def add_trial_column(self, **kwargs): """ Add a column to the trial table. @@ -798,7 +798,7 @@ def __check_invalid_times(self): description='time intervals to be removed from analysis' ) - @docval(*get_docval(DynamicTable.add_column)) + @docval(*get_docval(DynamicTable.add_column), allow_extra=True) def add_invalid_times_column(self, **kwargs): """ Add a column to the invalid times table. diff --git a/src/pynwb/testing/testh5io.py b/src/pynwb/testing/testh5io.py index c7b3bfdcc..7234e79f5 100644 --- a/src/pynwb/testing/testh5io.py +++ b/src/pynwb/testing/testh5io.py @@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod import warnings -from pynwb import NWBFile, NWBHDF5IO, validate as pynwb_validate +from pynwb import NWBFile, NWBHDF5IO, get_manager, validate as pynwb_validate from .utils import remove_test_file from hdmf.backends.warnings import BrokenLinkWarning from hdmf.build.warnings import MissingRequiredBuildWarning @@ -247,7 +247,11 @@ def tearDown(self): remove_test_file(self.filename) remove_test_file(self.export_filename) - def getContainerType() -> str: + def get_manager(self): + return get_manager() # get the pynwb manager unless overridden + + @abstractmethod + def getContainerType(self) -> str: """Return the name of the type of Container being tested, for test ID purposes.""" raise NotImplementedError('Cannot run test unless getContainerType is implemented.') @@ -296,13 +300,13 @@ def roundtripContainer(self, cache_spec=True): # catch all warnings with warnings.catch_warnings(record=True) as ws: - with NWBHDF5IO(self.filename, mode='w') as write_io: + with NWBHDF5IO(self.filename, mode='w', manager=self.get_manager()) as write_io: write_io.write(self.nwbfile, cache_spec=cache_spec) self.validate() # this is not closed until tearDown() or an exception from self.getContainer below - self.reader = NWBHDF5IO(self.filename, mode='r') + self.reader = NWBHDF5IO(self.filename, mode='r', manager=self.get_manager()) self.read_nwbfile = self.reader.read() # parse warnings and raise exceptions for certain types of warnings @@ -340,7 +344,7 @@ def roundtripExportContainer(self, cache_spec=True): self.validate() # this is not closed until tearDown() or an exception from self.getContainer below - self.export_reader = NWBHDF5IO(self.export_filename, mode='r') + self.export_reader = NWBHDF5IO(self.export_filename, mode='r', manager=self.get_manager()) self.read_exported_nwbfile = self.export_reader.read() # parse warnings and raise exceptions for certain types of warnings diff --git a/tests/integration/hdf5/test_extension.py b/tests/integration/hdf5/test_extension.py new file mode 100644 index 000000000..90ee1e468 --- /dev/null +++ b/tests/integration/hdf5/test_extension.py @@ -0,0 +1,75 @@ +from hdmf.build import BuildManager +from hdmf.common import VectorData +from hdmf.utils import docval, get_docval, popargs + +from pynwb import NWBFile +from pynwb.spec import NWBDatasetSpec, NWBAttributeSpec +from pynwb.testing import NWBH5IOFlexMixin, TestCase + +from ..helpers.utils import create_test_extension + + +class TestDynamicTableCustomColumnWithArgs(NWBH5IOFlexMixin, TestCase): + + class SubVectorData(VectorData): + __fields__ = ('extra_kwarg', ) + + @docval( + *get_docval(VectorData.__init__, "name", "description", "data"), + {'name': 'extra_kwarg', 'type': 'str', 'doc': 'An extra kwarg.'}, + ) + def __init__(self, **kwargs): + extra_kwarg = popargs('extra_kwarg', kwargs) + super().__init__(**kwargs) + self.extra_kwarg = extra_kwarg + + def setUp(self): + """Set up an extension with a custom VectorData column.""" + + spec = NWBDatasetSpec( + neurodata_type_def='SubVectorData', + neurodata_type_inc='VectorData', + doc='A custom VectorData column.', + dtype='text', + shape=(None,), + attributes=[ + NWBAttributeSpec( + name="extra_kwarg", + doc='An extra kwarg.', + dtype='text' + ), + ], + ) + + self.type_map = create_test_extension([spec], {"SubVectorData": self.SubVectorData}) + self.manager = BuildManager(self.type_map) + super().setUp() + + def get_manager(self): + return self.manager + + def getContainerType(self): + return "TrialsWithCustomColumnsWithArgs" + + def addContainer(self): + """ Add the test DynamicTable to the given NWBFile """ + self.nwbfile.add_trial_column( + name="test", + description="test", + col_cls=self.SubVectorData, + extra_kwarg="test_extra_kwarg" + ) + self.nwbfile.add_trial(start_time=1.0, stop_time=2.0, test="test_data") + + def getContainer(self, nwbfile: NWBFile): + return nwbfile.trials["test"] + + def test_roundtrip(self): + super().test_roundtrip() + assert isinstance(self.read_container, self.SubVectorData) + assert self.read_container.extra_kwarg == "test_extra_kwarg" + + def test_roundtrip_export(self): + super().test_roundtrip_export() + assert isinstance(self.read_container, self.SubVectorData) + assert self.read_container.extra_kwarg == "test_extra_kwarg" diff --git a/tests/integration/helpers/__init__.py b/tests/integration/helpers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/helpers/utils.py b/tests/integration/helpers/utils.py new file mode 100644 index 000000000..c8c9770a6 --- /dev/null +++ b/tests/integration/helpers/utils.py @@ -0,0 +1,31 @@ +"""Utilities for creating a custom TypeMap for testing so that we don't use the global type map.""" +import tempfile +from pynwb import get_type_map +from pynwb.spec import NWBNamespaceBuilder, export_spec + + +NAMESPACE_NAME = "test_core" + + +def create_test_extension(specs, container_classes, mappers=None): + ns_builder = NWBNamespaceBuilder( + name=NAMESPACE_NAME, + version="0.1.0", + doc="test extension", + ) + ns_builder.include_namespace("core") + + output_dir = tempfile.TemporaryDirectory() + export_spec(ns_builder, specs, output_dir.name) + + # this will copy the global pynwb TypeMap and add the extension types to the copy + type_map = get_type_map(f"{output_dir.name}/{NAMESPACE_NAME}.namespace.yaml") + for type_name, container_cls in container_classes.items(): + type_map.register_container_type(NAMESPACE_NAME, type_name, container_cls) + if mappers: + for type_name, mapper_cls in mappers.items(): + container_cls = container_classes[type_name] + type_map.register_map(container_cls, mapper_cls) + + output_dir.cleanup() + return type_map diff --git a/tests/unit/test_file.py b/tests/unit/test_file.py index 756009ff3..17870acd4 100644 --- a/tests/unit/test_file.py +++ b/tests/unit/test_file.py @@ -4,6 +4,8 @@ from datetime import datetime, timedelta from dateutil.tz import tzlocal, tzutc +from hdmf.common import VectorData +from hdmf.utils import docval, get_docval, popargs from pynwb import NWBFile, TimeSeries, NWBHDF5IO from pynwb.base import Image, Images from pynwb.file import Subject, ElectrodeTable, _add_missing_timezone @@ -222,6 +224,27 @@ def test_add_trial_column(self): self.nwbfile.add_trial_column('trial_type', 'the type of trial') self.assertEqual(self.nwbfile.trials.colnames, ('start_time', 'stop_time', 'trial_type')) + def test_add_trial_column_custom_class(self): + class SubVectorData(VectorData): + __fields__ = ('extra_kwarg', ) + + @docval( + *get_docval(VectorData.__init__, "name", "description", "data"), + {'name': 'extra_kwarg', 'type': 'str', 'doc': 'An extra kwarg.'}, + ) + def __init__(self, **kwargs): + extra_kwarg = popargs('extra_kwarg', kwargs) + super().__init__(**kwargs) + self.extra_kwarg = extra_kwarg + + self.nwbfile.add_trial_column( + name="test", + description="test", + col_cls=SubVectorData, + extra_kwarg="test_extra_kwarg" + ) + self.assertEqual(self.nwbfile.trials["test"].extra_kwarg, "test_extra_kwarg") + def test_add_trial(self): self.nwbfile.add_trial(start_time=10.0, stop_time=20.0) self.assertEqual(len(self.nwbfile.trials), 1)