Skip to content

Commit

Permalink
Add waveforms to units table (#1330)
Browse files Browse the repository at this point in the history
Co-authored-by: Ryan Ly <[email protected]>
Co-authored-by: Ben Dichter <[email protected]>
  • Loading branch information
3 people authored Jan 14, 2021
1 parent 9d32f5e commit 4ae60bc
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 6 deletions.
15 changes: 11 additions & 4 deletions src/pynwb/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ class Units(DynamicTable):
'resolution'
)

waveforms_desc = ('Individual waveforms for each spike. If the dataset is three-dimensional, the third dimension '
'shows the response from different electrodes that all observe this unit simultaneously. In this'
' case, the `electrodes` column of this Units table should be used to indicate which electrodes '
'are associated with this unit, and the electrodes dimension here should be in the same order as'
' the electrodes referenced in the `electrodes` column of this table.')
__columns__ = (
{'name': 'spike_times', 'description': 'the spike times for each unit', 'index': True},
{'name': 'obs_intervals', 'description': 'the observation intervals for each unit',
Expand All @@ -148,7 +153,8 @@ class Units(DynamicTable):
'index': True, 'table': True},
{'name': 'electrode_group', 'description': 'the electrode group that each spike unit came from'},
{'name': 'waveform_mean', 'description': 'the spike waveform mean for each spike unit'},
{'name': 'waveform_sd', 'description': 'the spike waveform standard deviation for each spike unit'}
{'name': 'waveform_sd', 'description': 'the spike waveform standard deviation for each spike unit'},
{'name': 'waveforms', 'description': waveforms_desc, 'index': 2}
)

@docval({'name': 'name', 'type': str, 'doc': 'Name of this Units interface', 'default': 'Units'},
Expand All @@ -161,7 +167,7 @@ class Units(DynamicTable):
{'name': 'waveform_unit', 'type': str,
'doc': 'Unit of measurement of the waveform means', 'default': 'volts'},
{'name': 'resolution', 'type': 'float',
'doc': 'The smallest possible difference between two spike times', 'default': None},
'doc': 'The smallest possible difference between two spike times', 'default': None}
)
def __init__(self, **kwargs):
if kwargs.get('description', None) is None:
Expand Down Expand Up @@ -189,8 +195,9 @@ def __init__(self, **kwargs):
'default': None},
{'name': 'waveform_sd', 'type': 'array_data', 'default': None,
'doc': 'the spike waveform standard deviation for each unit. Shape is (time,) or (time, electrodes)'},
{'name': 'id', 'type': int, 'default': None,
'doc': 'the id for each unit'},
{'name': 'waveforms', 'type': 'array_data', 'default': None, 'doc': waveforms_desc,
'shape': ((None, None), (None, None, None))},
{'name': 'id', 'type': int, 'default': None, 'doc': 'the id for each unit'},
allow_extra=True)
def add_unit(self, **kwargs):
"""
Expand Down
33 changes: 31 additions & 2 deletions tests/integration/hdf5/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,38 @@ def setUpContainer(self):
""" Return the test Units to read/write """
ut = Units(name='UnitsTest', description='a simple table for testing Units')
ut.add_unit(spike_times=[0, 1, 2], obs_intervals=[[0, 1], [2, 3]],
waveform_mean=[1., 2., 3.], waveform_sd=[4., 5., 6.])
waveform_mean=[1., 2., 3.], waveform_sd=[4., 5., 6.],
waveforms=[
[ # elec 1
[1, 2, 3],
[1, 2, 3],
[1, 2, 3]
], [ # elec 2
[1, 2, 3],
[1, 2, 3],
[1, 2, 3]
]
])
ut.add_unit(spike_times=[3, 4, 5], obs_intervals=[[2, 5], [6, 7]],
waveform_mean=[1., 2., 3.], waveform_sd=[4., 5., 6.])
waveform_mean=[1., 2., 3.], waveform_sd=[4., 5., 6.],
waveforms=np.array([
[ # elec 1
[1, 2, 3], # spike 1, [sample 1, sample 2, sample 3]
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
], [ # elec 2
[1, 2, 3], # spike 1
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
], [ # elec 3
[1, 2, 3], # spike 1
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
]
]))
ut.waveform_rate = 40000.
ut.resolution = 1/40000
return ut
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,45 @@ def test_add_spike_times(self):
self.assertEqual(ut['spike_times'][0], [0, 1, 2])
self.assertEqual(ut['spike_times'][1], [3, 4, 5])

def test_add_waveforms(self):
ut = Units()
wf1 = [
[ # elec 1
[1, 2, 3],
[1, 2, 3],
[1, 2, 3]
], [ # elec 2
[1, 2, 3],
[1, 2, 3],
[1, 2, 3]
]
]
wf2 = [
[ # elec 1
[1, 2, 3], # spike 1, [sample 1, sample 2, sample 3]
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
], [ # elec 2
[1, 2, 3], # spike 1
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
], [ # elec 3
[1, 2, 3], # spike 1
[1, 2, 3], # spike 2
[1, 2, 3], # spike 3
[1, 2, 3] # spike 4
]
]
ut.add_unit(waveforms=wf1)
ut.add_unit(waveforms=wf2)
self.assertEqual(ut.id.data, [0, 1])
self.assertEqual(ut['waveforms'].target.data, [3, 6, 10, 14, 18])
self.assertEqual(ut['waveforms'].data, [2, 5])
self.assertListEqual(ut['waveforms'][0], wf1)
self.assertListEqual(ut['waveforms'][1], wf2)

def test_get_spike_times(self):
ut = Units()
ut.add_unit(spike_times=[0, 1, 2])
Expand Down

0 comments on commit 4ae60bc

Please sign in to comment.