Skip to content

Commit

Permalink
Add allow_extra=True to NWBFile.add_x_column (#1861)
Browse files Browse the repository at this point in the history
  • Loading branch information
rly authored Mar 22, 2024
1 parent 751983b commit 5b80bea
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 10 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# PyNWB Changelog

## PyNWB 2.6.1 (Upcoming)

### Bug fixes
- Fix bug where extra keyword arguments could not be passed to `NWBFile.add_{x}_column`` for use in custom `VectorData`` classes. @rly [#1861](https://github.com/NeurodataWithoutBorders/pynwb/pull/1861)

## PyNWB 2.6.0 (February 21, 2024)

### Enhancements and minor changes
Expand Down
10 changes: 5 additions & 5 deletions src/pynwb/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def __check_epochs(self):
if self.epochs is None:
self.epochs = TimeIntervals(name='epochs', description='experimental epochs')

@docval(*get_docval(TimeIntervals.add_column))
@docval(*get_docval(TimeIntervals.add_column), allow_extra=True)
def add_epoch_column(self, **kwargs):
"""
Add a column to the epoch table.
Expand Down Expand Up @@ -645,7 +645,7 @@ def __check_electrodes(self):
if self.electrodes is None:
self.electrodes = ElectrodeTable()

@docval(*get_docval(DynamicTable.add_column))
@docval(*get_docval(DynamicTable.add_column), allow_extra=True)
def add_electrode_column(self, **kwargs):
"""
Add a column to the electrode table.
Expand Down Expand Up @@ -747,7 +747,7 @@ def __check_units(self):
if self.units is None:
self.units = Units(name='units', description='Autogenerated by NWBFile')

@docval(*get_docval(Units.add_column))
@docval(*get_docval(Units.add_column), allow_extra=True)
def add_unit_column(self, **kwargs):
"""
Add a column to the unit table.
Expand All @@ -770,7 +770,7 @@ def __check_trials(self):
if self.trials is None:
self.trials = TimeIntervals(name='trials', description='experimental trials')

@docval(*get_docval(DynamicTable.add_column))
@docval(*get_docval(DynamicTable.add_column), allow_extra=True)
def add_trial_column(self, **kwargs):
"""
Add a column to the trial table.
Expand Down Expand Up @@ -798,7 +798,7 @@ def __check_invalid_times(self):
description='time intervals to be removed from analysis'
)

@docval(*get_docval(DynamicTable.add_column))
@docval(*get_docval(DynamicTable.add_column), allow_extra=True)
def add_invalid_times_column(self, **kwargs):
"""
Add a column to the invalid times table.
Expand Down
14 changes: 9 additions & 5 deletions src/pynwb/testing/testh5io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABCMeta, abstractmethod
import warnings

from pynwb import NWBFile, NWBHDF5IO, validate as pynwb_validate
from pynwb import NWBFile, NWBHDF5IO, get_manager, validate as pynwb_validate
from .utils import remove_test_file
from hdmf.backends.warnings import BrokenLinkWarning
from hdmf.build.warnings import MissingRequiredBuildWarning
Expand Down Expand Up @@ -247,7 +247,11 @@ def tearDown(self):
remove_test_file(self.filename)
remove_test_file(self.export_filename)

def getContainerType() -> str:
def get_manager(self):
return get_manager() # get the pynwb manager unless overridden

@abstractmethod
def getContainerType(self) -> str:
"""Return the name of the type of Container being tested, for test ID purposes."""
raise NotImplementedError('Cannot run test unless getContainerType is implemented.')

Expand Down Expand Up @@ -296,13 +300,13 @@ def roundtripContainer(self, cache_spec=True):

# catch all warnings
with warnings.catch_warnings(record=True) as ws:
with NWBHDF5IO(self.filename, mode='w') as write_io:
with NWBHDF5IO(self.filename, mode='w', manager=self.get_manager()) as write_io:
write_io.write(self.nwbfile, cache_spec=cache_spec)

self.validate()

# this is not closed until tearDown() or an exception from self.getContainer below
self.reader = NWBHDF5IO(self.filename, mode='r')
self.reader = NWBHDF5IO(self.filename, mode='r', manager=self.get_manager())
self.read_nwbfile = self.reader.read()

# parse warnings and raise exceptions for certain types of warnings
Expand Down Expand Up @@ -340,7 +344,7 @@ def roundtripExportContainer(self, cache_spec=True):
self.validate()

# this is not closed until tearDown() or an exception from self.getContainer below
self.export_reader = NWBHDF5IO(self.export_filename, mode='r')
self.export_reader = NWBHDF5IO(self.export_filename, mode='r', manager=self.get_manager())
self.read_exported_nwbfile = self.export_reader.read()

# parse warnings and raise exceptions for certain types of warnings
Expand Down
75 changes: 75 additions & 0 deletions tests/integration/hdf5/test_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from hdmf.build import BuildManager
from hdmf.common import VectorData
from hdmf.utils import docval, get_docval, popargs

from pynwb import NWBFile
from pynwb.spec import NWBDatasetSpec, NWBAttributeSpec
from pynwb.testing import NWBH5IOFlexMixin, TestCase

from ..helpers.utils import create_test_extension


class TestDynamicTableCustomColumnWithArgs(NWBH5IOFlexMixin, TestCase):

class SubVectorData(VectorData):
__fields__ = ('extra_kwarg', )

@docval(
*get_docval(VectorData.__init__, "name", "description", "data"),
{'name': 'extra_kwarg', 'type': 'str', 'doc': 'An extra kwarg.'},
)
def __init__(self, **kwargs):
extra_kwarg = popargs('extra_kwarg', kwargs)
super().__init__(**kwargs)
self.extra_kwarg = extra_kwarg

def setUp(self):
"""Set up an extension with a custom VectorData column."""

spec = NWBDatasetSpec(
neurodata_type_def='SubVectorData',
neurodata_type_inc='VectorData',
doc='A custom VectorData column.',
dtype='text',
shape=(None,),
attributes=[
NWBAttributeSpec(
name="extra_kwarg",
doc='An extra kwarg.',
dtype='text'
),
],
)

self.type_map = create_test_extension([spec], {"SubVectorData": self.SubVectorData})
self.manager = BuildManager(self.type_map)
super().setUp()

def get_manager(self):
return self.manager

def getContainerType(self):
return "TrialsWithCustomColumnsWithArgs"

def addContainer(self):
""" Add the test DynamicTable to the given NWBFile """
self.nwbfile.add_trial_column(
name="test",
description="test",
col_cls=self.SubVectorData,
extra_kwarg="test_extra_kwarg"
)
self.nwbfile.add_trial(start_time=1.0, stop_time=2.0, test="test_data")

def getContainer(self, nwbfile: NWBFile):
return nwbfile.trials["test"]

def test_roundtrip(self):
super().test_roundtrip()
assert isinstance(self.read_container, self.SubVectorData)
assert self.read_container.extra_kwarg == "test_extra_kwarg"

def test_roundtrip_export(self):
super().test_roundtrip_export()
assert isinstance(self.read_container, self.SubVectorData)
assert self.read_container.extra_kwarg == "test_extra_kwarg"
Empty file.
31 changes: 31 additions & 0 deletions tests/integration/helpers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Utilities for creating a custom TypeMap for testing so that we don't use the global type map."""
import tempfile
from pynwb import get_type_map
from pynwb.spec import NWBNamespaceBuilder, export_spec


NAMESPACE_NAME = "test_core"


def create_test_extension(specs, container_classes, mappers=None):
ns_builder = NWBNamespaceBuilder(
name=NAMESPACE_NAME,
version="0.1.0",
doc="test extension",
)
ns_builder.include_namespace("core")

output_dir = tempfile.TemporaryDirectory()
export_spec(ns_builder, specs, output_dir.name)

# this will copy the global pynwb TypeMap and add the extension types to the copy
type_map = get_type_map(f"{output_dir.name}/{NAMESPACE_NAME}.namespace.yaml")
for type_name, container_cls in container_classes.items():
type_map.register_container_type(NAMESPACE_NAME, type_name, container_cls)
if mappers:
for type_name, mapper_cls in mappers.items():
container_cls = container_classes[type_name]
type_map.register_map(container_cls, mapper_cls)

output_dir.cleanup()
return type_map
23 changes: 23 additions & 0 deletions tests/unit/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from datetime import datetime, timedelta
from dateutil.tz import tzlocal, tzutc

from hdmf.common import VectorData
from hdmf.utils import docval, get_docval, popargs
from pynwb import NWBFile, TimeSeries, NWBHDF5IO
from pynwb.base import Image, Images
from pynwb.file import Subject, ElectrodeTable, _add_missing_timezone
Expand Down Expand Up @@ -222,6 +224,27 @@ def test_add_trial_column(self):
self.nwbfile.add_trial_column('trial_type', 'the type of trial')
self.assertEqual(self.nwbfile.trials.colnames, ('start_time', 'stop_time', 'trial_type'))

def test_add_trial_column_custom_class(self):
class SubVectorData(VectorData):
__fields__ = ('extra_kwarg', )

@docval(
*get_docval(VectorData.__init__, "name", "description", "data"),
{'name': 'extra_kwarg', 'type': 'str', 'doc': 'An extra kwarg.'},
)
def __init__(self, **kwargs):
extra_kwarg = popargs('extra_kwarg', kwargs)
super().__init__(**kwargs)
self.extra_kwarg = extra_kwarg

self.nwbfile.add_trial_column(
name="test",
description="test",
col_cls=SubVectorData,
extra_kwarg="test_extra_kwarg"
)
self.assertEqual(self.nwbfile.trials["test"].extra_kwarg, "test_extra_kwarg")

def test_add_trial(self):
self.nwbfile.add_trial(start_time=10.0, stop_time=20.0)
self.assertEqual(len(self.nwbfile.trials), 1)
Expand Down

0 comments on commit 5b80bea

Please sign in to comment.