Skip to content

Commit

Permalink
Automatically migrate TimeIntervals.timeseries column from VectorData…
Browse files Browse the repository at this point in the history
… to TimeSeriesReferenceVectorData
  • Loading branch information
oruebel committed Aug 11, 2021
1 parent 472be82 commit fe87879
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
37 changes: 35 additions & 2 deletions src/pynwb/epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from hdmf.data_utils import DataIO

from . import register_class, CORE_NAMESPACE
from .base import TimeSeries
from .base import TimeSeries, TimeSeriesReferenceVectorData, TimeSeriesReference
from hdmf.common import DynamicTable


Expand All @@ -29,6 +29,39 @@ class TimeIntervals(DynamicTable):
*get_docval(DynamicTable.__init__, 'id', 'columns', 'colnames'))
def __init__(self, **kwargs):
call_docval_func(super(TimeIntervals, self).__init__, kwargs)
self.__timeseries_column_type_migrated = False
self.__migrate_timeseries_column_type()

def __migrate_timeseries_column_type(self):
"""
Internal helper function used to migrate the self.timeseries column
from a regular VectorData to a TimeSeriesReferenceVectorData object
if necessary.
"""
if getattr(self, 'timeseries', None) is not None:
if not isinstance(self.timeseries, TimeSeriesReferenceVectorData):
self.timeseries.__class__ = TimeSeriesReferenceVectorData
self.__timeseries_column_type_migrated = True

@docval(*get_docval(DynamicTable.add_column))
def add_column(self, **kwargs):
"""
Overwrite :py:meth:~hdmf.common.table.VectorData.add_column` to
automatically migrate the :py:meth:`~pynwb.epoch.TimeIntervals.timeseries`
from a regular :py:class:`~hdmf.common.table.VectorData` to a
:py:class:`~pynwb.base.TimeSeriesReferenceVectorData` if necessary.
"""
super(TimeIntervals, self).add_column(**kwargs)
self.__migrate_timeseries_column_type()

@property
def timeseries_column_type_migrated(self):
"""
Check whether the :py:meth:`~pynwb.epoch.TimeIntervals.timeseries` column has been automatically migrated
from a regular :py:class:`~hdmf.common.table.VectorData` to a
:py:class:`~pynwb.base.TimeSeriesReferenceVectorData`
"""
return self.__timeseries_column_type_migrated

@docval({'name': 'start_time', 'type': 'float', 'doc': 'Start time of epoch, in seconds'},
{'name': 'stop_time', 'type': 'float', 'doc': 'Stop time of epoch, in seconds'},
Expand All @@ -51,7 +84,7 @@ def add_interval(self, **kwargs):
tmp = list()
for ts in timeseries:
idx_start, count = self.__calculate_idx_count(start_time, stop_time, ts)
tmp.append((idx_start, count, ts))
tmp.append(TimeSeriesReference(idx_start, count, ts))
timeseries = tmp
rkwargs['timeseries'] = timeseries
return super(TimeIntervals, self).add_row(**rkwargs)
Expand Down
9 changes: 7 additions & 2 deletions tests/unit/test_epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pynwb.epoch import TimeIntervals
from pynwb import TimeSeries, NWBFile
from pynwb.base import TimeSeriesReference
from pynwb.testing import TestCase


Expand Down Expand Up @@ -36,7 +37,11 @@ def get_dataframe(self):
'bar': ['fish', 'fowl', 'dog', 'cat'],
'start_time': [0.2, 0.25, 0.30, 0.35],
'stop_time': [0.25, 0.30, 0.40, 0.45],
'timeseries': [[tsa], [tsb], [], [tsb, tsa]],
'timeseries': [[TimeSeriesReference(idx_start=0, count=11, timeseries=tsa)],
[TimeSeriesReference(idx_start=0, count=13, timeseries=tsb)],
[],
[TimeSeriesReference(idx_start=4, count=6, timeseries=tsb),
TimeSeriesReference(idx_start=3, count=4, timeseries=tsa)]],
'keys': ['q', 'w', 'e', 'r'],
'tags': [[], [], ['fizz', 'buzz'], ['qaz']]
})
Expand All @@ -46,7 +51,7 @@ def test_dataframe_roundtrip(self):
epochs = TimeIntervals.from_dataframe(df, name='test epochs')
obtained = epochs.to_dataframe()

self.assertIs(obtained.loc[3, 'timeseries'][1], df.loc[3, 'timeseries'][1])
self.assertTupleEqual(obtained.loc[3, 'timeseries'][1], df.loc[3, 'timeseries'][1])
self.assertEqual(obtained.loc[2, 'foo'], df.loc[2, 'foo'])

def test_dataframe_roundtrip_drop_ts(self):
Expand Down

0 comments on commit fe87879

Please sign in to comment.