Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warn instead of raise error if SpatialSeries has more than 3 columns #1480

Merged
merged 4 commits into from
May 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## PyNWB 2.1.0 (Upcoming)

### Breaking changes:
- Restrict `SpatialSeries.data` to have no more than 3 columns (#1455)
- Raise a warning if `SpatialSeries.data` has more than 3 columns (#1455, #1480)
- Updated ``TimeIntervals`` to use the new ``TimeSeriesReferenceVectorData`` type. This does not alter the overall
structure of ``TimeIntervals`` in a major way aside from changing the value of the ``neurodata_type`` attribute of the
``TimeIntervals.timeseries`` column from ``VectorData`` to ``TimeSeriesReferenceVectorData``. This change facilitates
Expand Down
26 changes: 22 additions & 4 deletions src/pynwb/behavior.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from hdmf.utils import docval, popargs, get_docval
import warnings

from hdmf.utils import docval, popargs, get_docval, get_data_shape

from . import register_class, CORE_NAMESPACE
from .core import MultiContainerInterface
Expand All @@ -21,9 +23,7 @@ class SpatialSeries(TimeSeries):
__nwbfields__ = ('reference_frame',)

@docval(*get_docval(TimeSeries.__init__, 'name'), # required
{'name': 'data', 'type': ('array_data', 'data', TimeSeries), 'shape': (
(None, ), (None, 1), (None, 2), (None, 3)
), # required
{'name': 'data', 'type': ('array_data', 'data', TimeSeries), 'shape': ((None, ), (None, None)), # required
'doc': ('The data values. Can be 1D or 2D. The first dimension must be time. If 2D, there can be 1, 2, '
'or 3 columns, which represent x, y, and z.')},
{'name': 'reference_frame', 'type': str, # required
Expand All @@ -38,8 +38,26 @@ def __init__(self, **kwargs):
"""
name, data, reference_frame, unit = popargs('name', 'data', 'reference_frame', 'unit', kwargs)
super(SpatialSeries, self).__init__(name, data, unit, **kwargs)

# NWB 2.5 restricts length of second dimension to be <= 3
allowed_data_shapes = ((None, ), (None, 1), (None, 2), (None, 3))
data_shape = get_data_shape(data)
if not any(self._validate_data_shape(data_shape, a) for a in allowed_data_shapes):
warnings.warn("SpatialSeries '%s' has data shape %s which is not compliant with NWB 2.5 and greater. "
"The second dimension should have length <= 3 to represent at most x, y, z." %
(name, str(data_shape)))

self.reference_frame = reference_frame

@staticmethod
def _validate_data_shape(valshape, argshape):
if not len(valshape) == len(argshape):
return False
for a, b in zip(valshape, argshape):
if b not in (a, None):
return False
return True


@register_class('BehavioralEpochs', CORE_NAMESPACE)
class BehavioralEpochs(MultiContainerInterface):
Expand Down
10 changes: 3 additions & 7 deletions tests/unit/test_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,11 @@ def test_set_unit(self):
self.assertEqual(sS.unit, 'degrees')

def test_gt_3_cols(self):
with self.assertRaises(ValueError) as error:
msg = ("SpatialSeries 'test_sS' has data shape (5, 4) which is not compliant with NWB 2.5 and greater. "
"The second dimension should have length <= 3 to represent at most x, y, z.")
with self.assertWarnsWith(UserWarning, msg):
SpatialSeries("test_sS", np.ones((5, 4)), "reference_frame", "meters", rate=30.)

self.assertEqual(
"SpatialSeries.__init__: incorrect shape for 'data' (got '(5, 4)', expected "
"'((None,), (None, 1), (None, 2), (None, 3))')",
str(error.exception)
)


class BehavioralEpochsConstructor(TestCase):
def test_init(self):
Expand Down