Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add allow_extra=True to NWBFile.add_x_column #1861

Merged
merged 7 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# PyNWB Changelog

## PyNWB 2.6.1 (Upcoming)

### Bug fixes
- Fix bug where extra keyword arguments could not be passed to `NWBFile.add_{x}_column`` for use in custom `VectorData`` classes. @rly [#1861](https://github.com/NeurodataWithoutBorders/pynwb/pull/1861)

## PyNWB 2.6.0 (February 21, 2024)

### Enhancements and minor changes
Expand Down
10 changes: 5 additions & 5 deletions src/pynwb/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def __check_epochs(self):
if self.epochs is None:
self.epochs = TimeIntervals(name='epochs', description='experimental epochs')

@docval(*get_docval(TimeIntervals.add_column))
@docval(*get_docval(TimeIntervals.add_column), allow_extra=True)
def add_epoch_column(self, **kwargs):
"""
Add a column to the epoch table.
Expand Down Expand Up @@ -645,7 +645,7 @@ def __check_electrodes(self):
if self.electrodes is None:
self.electrodes = ElectrodeTable()

@docval(*get_docval(DynamicTable.add_column))
@docval(*get_docval(DynamicTable.add_column), allow_extra=True)
def add_electrode_column(self, **kwargs):
"""
Add a column to the electrode table.
Expand Down Expand Up @@ -747,7 +747,7 @@ def __check_units(self):
if self.units is None:
self.units = Units(name='units', description='Autogenerated by NWBFile')

@docval(*get_docval(Units.add_column))
@docval(*get_docval(Units.add_column), allow_extra=True)
def add_unit_column(self, **kwargs):
"""
Add a column to the unit table.
Expand All @@ -770,7 +770,7 @@ def __check_trials(self):
if self.trials is None:
self.trials = TimeIntervals(name='trials', description='experimental trials')

@docval(*get_docval(DynamicTable.add_column))
@docval(*get_docval(DynamicTable.add_column), allow_extra=True)
def add_trial_column(self, **kwargs):
"""
Add a column to the trial table.
Expand Down Expand Up @@ -798,7 +798,7 @@ def __check_invalid_times(self):
description='time intervals to be removed from analysis'
)

@docval(*get_docval(DynamicTable.add_column))
@docval(*get_docval(DynamicTable.add_column), allow_extra=True)
def add_invalid_times_column(self, **kwargs):
"""
Add a column to the invalid times table.
Expand Down
14 changes: 9 additions & 5 deletions src/pynwb/testing/testh5io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABCMeta, abstractmethod
import warnings

from pynwb import NWBFile, NWBHDF5IO, validate as pynwb_validate
from pynwb import NWBFile, NWBHDF5IO, get_manager, validate as pynwb_validate
from .utils import remove_test_file
from hdmf.backends.warnings import BrokenLinkWarning
from hdmf.build.warnings import MissingRequiredBuildWarning
Expand Down Expand Up @@ -247,7 +247,11 @@ def tearDown(self):
remove_test_file(self.filename)
remove_test_file(self.export_filename)

def getContainerType() -> str:
def get_manager(self):
return get_manager() # get the pynwb manager unless overridden

@abstractmethod
def getContainerType(self) -> str:
"""Return the name of the type of Container being tested, for test ID purposes."""
raise NotImplementedError('Cannot run test unless getContainerType is implemented.')

Expand Down Expand Up @@ -296,13 +300,13 @@ def roundtripContainer(self, cache_spec=True):

# catch all warnings
with warnings.catch_warnings(record=True) as ws:
with NWBHDF5IO(self.filename, mode='w') as write_io:
with NWBHDF5IO(self.filename, mode='w', manager=self.get_manager()) as write_io:
write_io.write(self.nwbfile, cache_spec=cache_spec)

self.validate()

# this is not closed until tearDown() or an exception from self.getContainer below
self.reader = NWBHDF5IO(self.filename, mode='r')
self.reader = NWBHDF5IO(self.filename, mode='r', manager=self.get_manager())
self.read_nwbfile = self.reader.read()

# parse warnings and raise exceptions for certain types of warnings
Expand Down Expand Up @@ -340,7 +344,7 @@ def roundtripExportContainer(self, cache_spec=True):
self.validate()

# this is not closed until tearDown() or an exception from self.getContainer below
self.export_reader = NWBHDF5IO(self.export_filename, mode='r')
self.export_reader = NWBHDF5IO(self.export_filename, mode='r', manager=self.get_manager())
self.read_exported_nwbfile = self.export_reader.read()

# parse warnings and raise exceptions for certain types of warnings
Expand Down
75 changes: 75 additions & 0 deletions tests/integration/hdf5/test_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from hdmf.build import BuildManager
from hdmf.common import VectorData
from hdmf.utils import docval, get_docval, popargs

from pynwb import NWBFile
from pynwb.spec import NWBDatasetSpec, NWBAttributeSpec
from pynwb.testing import NWBH5IOFlexMixin, TestCase

from ..helpers.utils import create_test_extension


class TestDynamicTableCustomColumnWithArgs(NWBH5IOFlexMixin, TestCase):

class SubVectorData(VectorData):
__fields__ = ('extra_kwarg', )

@docval(
*get_docval(VectorData.__init__, "name", "description", "data"),
{'name': 'extra_kwarg', 'type': 'str', 'doc': 'An extra kwarg.'},
)
def __init__(self, **kwargs):
extra_kwarg = popargs('extra_kwarg', kwargs)
super().__init__(**kwargs)
self.extra_kwarg = extra_kwarg

def setUp(self):
"""Set up an extension with a custom VectorData column."""

spec = NWBDatasetSpec(
neurodata_type_def='SubVectorData',
neurodata_type_inc='VectorData',
doc='A custom VectorData column.',
dtype='text',
shape=(None,),
attributes=[
NWBAttributeSpec(
name="extra_kwarg",
doc='An extra kwarg.',
dtype='text'
),
],
)

self.type_map = create_test_extension([spec], {"SubVectorData": self.SubVectorData})
self.manager = BuildManager(self.type_map)
super().setUp()

def get_manager(self):
return self.manager

def getContainerType(self):
return "TrialsWithCustomColumnsWithArgs"

def addContainer(self):
""" Add the test DynamicTable to the given NWBFile """
self.nwbfile.add_trial_column(
name="test",
description="test",
col_cls=self.SubVectorData,
extra_kwarg="test_extra_kwarg"
)
self.nwbfile.add_trial(start_time=1.0, stop_time=2.0, test="test_data")

def getContainer(self, nwbfile: NWBFile):
return nwbfile.trials["test"]

def test_roundtrip(self):
super().test_roundtrip()
assert isinstance(self.read_container, self.SubVectorData)
assert self.read_container.extra_kwarg == "test_extra_kwarg"

def test_roundtrip_export(self):
super().test_roundtrip_export()
assert isinstance(self.read_container, self.SubVectorData)
assert self.read_container.extra_kwarg == "test_extra_kwarg"
Empty file.
31 changes: 31 additions & 0 deletions tests/integration/helpers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Utilities for creating a custom TypeMap for testing so that we don't use the global type map."""
import tempfile
from pynwb import get_type_map
from pynwb.spec import NWBNamespaceBuilder, export_spec


NAMESPACE_NAME = "test_core"


def create_test_extension(specs, container_classes, mappers=None):
ns_builder = NWBNamespaceBuilder(
name=NAMESPACE_NAME,
version="0.1.0",
doc="test extension",
)
ns_builder.include_namespace("core")

output_dir = tempfile.TemporaryDirectory()
export_spec(ns_builder, specs, output_dir.name)

# this will copy the global pynwb TypeMap and add the extension types to the copy
type_map = get_type_map(f"{output_dir.name}/{NAMESPACE_NAME}.namespace.yaml")
for type_name, container_cls in container_classes.items():
type_map.register_container_type(NAMESPACE_NAME, type_name, container_cls)
if mappers:
for type_name, mapper_cls in mappers.items():
container_cls = container_classes[type_name]
type_map.register_map(container_cls, mapper_cls)

output_dir.cleanup()
return type_map
23 changes: 23 additions & 0 deletions tests/unit/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from datetime import datetime, timedelta
from dateutil.tz import tzlocal, tzutc

from hdmf.common import VectorData
from hdmf.utils import docval, get_docval, popargs
from pynwb import NWBFile, TimeSeries, NWBHDF5IO
from pynwb.base import Image, Images
from pynwb.file import Subject, ElectrodeTable, _add_missing_timezone
Expand Down Expand Up @@ -222,6 +224,27 @@ def test_add_trial_column(self):
self.nwbfile.add_trial_column('trial_type', 'the type of trial')
self.assertEqual(self.nwbfile.trials.colnames, ('start_time', 'stop_time', 'trial_type'))

def test_add_trial_column_custom_class(self):
class SubVectorData(VectorData):
__fields__ = ('extra_kwarg', )

@docval(
*get_docval(VectorData.__init__, "name", "description", "data"),
{'name': 'extra_kwarg', 'type': 'str', 'doc': 'An extra kwarg.'},
)
def __init__(self, **kwargs):
extra_kwarg = popargs('extra_kwarg', kwargs)
super().__init__(**kwargs)
self.extra_kwarg = extra_kwarg

self.nwbfile.add_trial_column(
name="test",
description="test",
col_cls=SubVectorData,
extra_kwarg="test_extra_kwarg"
)
self.assertEqual(self.nwbfile.trials["test"].extra_kwarg, "test_extra_kwarg")

def test_add_trial(self):
self.nwbfile.add_trial(start_time=10.0, stop_time=20.0)
self.assertEqual(len(self.nwbfile.trials), 1)
Expand Down
Loading