Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add waveforms to units table #1330

Merged
merged 11 commits into from
Jan 14, 2021
15 changes: 11 additions & 4 deletions src/pynwb/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ class Units(DynamicTable):
'resolution'
)

waveforms_desc = ('Individual waveforms for each spike. If the dataset is three-dimensional, the third dimension '
'shows the response from different electrodes that all observe this unit simultaneously. In this'
' case, the `electrodes` column of this Units table should be used to indicate which electrodes '
'are associated with this unit, and the electrodes dimension here should be in the same order as'
' the electrodes referenced in the `electrodes` column of this table.')
__columns__ = (
{'name': 'spike_times', 'description': 'the spike times for each unit', 'index': True},
{'name': 'obs_intervals', 'description': 'the observation intervals for each unit',
Expand All @@ -148,7 +153,8 @@ class Units(DynamicTable):
'index': True, 'table': True},
{'name': 'electrode_group', 'description': 'the electrode group that each spike unit came from'},
{'name': 'waveform_mean', 'description': 'the spike waveform mean for each spike unit'},
{'name': 'waveform_sd', 'description': 'the spike waveform standard deviation for each spike unit'}
{'name': 'waveform_sd', 'description': 'the spike waveform standard deviation for each spike unit'},
{'name': 'waveforms', 'description': waveforms_desc, 'index': 2}
)

@docval({'name': 'name', 'type': str, 'doc': 'Name of this Units interface', 'default': 'Units'},
Expand All @@ -161,7 +167,7 @@ class Units(DynamicTable):
{'name': 'waveform_unit', 'type': str,
'doc': 'Unit of measurement of the waveform means', 'default': 'volts'},
{'name': 'resolution', 'type': 'float',
'doc': 'The smallest possible difference between two spike times', 'default': None},
'doc': 'The smallest possible difference between two spike times', 'default': None}
)
def __init__(self, **kwargs):
if kwargs.get('description', None) is None:
Expand Down Expand Up @@ -189,8 +195,9 @@ def __init__(self, **kwargs):
'default': None},
{'name': 'waveform_sd', 'type': 'array_data', 'default': None,
'doc': 'the spike waveform standard deviation for each unit. Shape is (time,) or (time, electrodes)'},
{'name': 'id', 'type': int, 'default': None,
'doc': 'the id for each unit'},
{'name': 'waveforms', 'type': 'array_data', 'default': None, 'doc': waveforms_desc,
'shape': ((None, None), (None, None, None))},
{'name': 'id', 'type': int, 'default': None, 'doc': 'the id for each unit'},
allow_extra=True)
def add_unit(self, **kwargs):
"""
Expand Down
33 changes: 31 additions & 2 deletions tests/integration/hdf5/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,38 @@ def setUpContainer(self):
""" Return the test Units to read/write """
ut = Units(name='UnitsTest', description='a simple table for testing Units')
ut.add_unit(spike_times=[0, 1, 2], obs_intervals=[[0, 1], [2, 3]],
waveform_mean=[1., 2., 3.], waveform_sd=[4., 5., 6.])
waveform_mean=[1., 2., 3.], waveform_sd=[4., 5., 6.],
waveforms=[
[ # elec 1
[1, 2, 3],
[1, 2, 3],
[1, 2, 3]
], [ # elec 2
[1, 2, 3],
[1, 2, 3],
[1, 2, 3]
]
])
ut.add_unit(spike_times=[3, 4, 5], obs_intervals=[[2, 5], [6, 7]],
waveform_mean=[1., 2., 3.], waveform_sd=[4., 5., 6.])
waveform_mean=[1., 2., 3.], waveform_sd=[4., 5., 6.],
waveforms=np.array([
[ # elec 1
[1, 2, 3], # spike 1, [sample 1, sample 2, sample 3]
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
], [ # elec 2
[1, 2, 3], # spike 1
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
], [ # elec 3
[1, 2, 3], # spike 1
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
]
]))
ut.waveform_rate = 40000.
ut.resolution = 1/40000
return ut
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,45 @@ def test_add_spike_times(self):
self.assertEqual(ut['spike_times'][0], [0, 1, 2])
self.assertEqual(ut['spike_times'][1], [3, 4, 5])

def test_add_waveforms(self):
ut = Units()
wf1 = [
[ # elec 1
[1, 2, 3],
[1, 2, 3],
[1, 2, 3]
], [ # elec 2
[1, 2, 3],
[1, 2, 3],
[1, 2, 3]
]
]
wf2 = [
[ # elec 1
[1, 2, 3], # spike 1, [sample 1, sample 2, sample 3]
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
], [ # elec 2
[1, 2, 3], # spike 1
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
], [ # elec 3
[1, 2, 3], # spike 1
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
]
]
ut.add_unit(waveforms=wf1)
ut.add_unit(waveforms=wf2)
self.assertEqual(ut.id.data, [0, 1])
self.assertEqual(ut['waveforms'].target.data, [3, 6, 10, 14, 18])
self.assertEqual(ut['waveforms'].data, [2, 5])
self.assertListEqual(ut['waveforms'][0], wf1)
self.assertListEqual(ut['waveforms'][1], wf2)

def test_get_spike_times(self):
ut = Units()
ut.add_unit(spike_times=[0, 1, 2])
Expand Down