Skip to content

Commit 7869350

Browse files
author
Fabian-Robert Stöter
committed
add splitting to contrib
1 parent 516fde7 commit 7869350

File tree

4 files changed

+299
-3
lines changed

4 files changed

+299
-3
lines changed

norbert/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ class Processor(object):
1818
1919
Parameters
2020
----------
21-
pipeline : list of norbert objects
21+
pipeline : list of norbert modules
2222
2323
"""
2424
def __init__(self, pipeline):
2525
super(Processor, self).__init__()
26-
# set up modules
2726

27+
# set up modules
2828
self.pipeline = pipeline
2929

3030
def forward(self, input):

norbert/contrib.py

+250
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,253 @@ def compress_filter(W, eps, thresh=0.6, slope=15, multichannel=True):
128128
else:
129129
W = _logit(W, thresh, slope)
130130
return W
131+
132+
133+
134+
import numpy as np
135+
import itertools
136+
137+
138+
def splitinfo(sigShape, frameShape, hop):
139+
140+
# making sure input shapes are tuples, not simple integers
141+
if np.isscalar(frameShape):
142+
frameShape = (frameShape,)
143+
if np.isscalar(hop):
144+
hop = (hop,)
145+
146+
# converting frameShape to array, and building an aligned frameshape,
147+
# which is 1 whenever the frame dimension is not given. For instance, if
148+
# frameShape=(1024,) and sigShape=(10000,2), frameShapeAligned is set
149+
# to (1024,1)
150+
frameShape = np.array(frameShape)
151+
fdim = len(frameShape)
152+
frameShapeAligned = np.append(
153+
frameShape, np.ones(
154+
(len(sigShape) - len(frameShape)))).astype(int)
155+
156+
# same thing for hop
157+
hop = np.array(hop)
158+
hop = np.append(hop, np.ones((len(sigShape) - len(hop)))).astype(int)
159+
160+
# building the positions of the frames. For each dimension, gridding from
161+
# 0 to sigShape[dim] every hop[dim]
162+
framesPos = np.ogrid[[slice(0, size, step)
163+
for (size, step) in zip(sigShape, hop)]]
164+
165+
# number of dimensions
166+
nDim = len(framesPos)
167+
168+
# now making sure we have at most one frame going out of the signal. This
169+
# is possible, for instance if the overlap is very large between the frames
170+
for dim in range(nDim):
171+
# for each dimension, we remove all frames that go beyond the signal
172+
framesPos[dim] = framesPos[dim][
173+
np.nonzero(
174+
np.add(
175+
framesPos[dim],
176+
frameShapeAligned[dim]) < sigShape[dim])]
177+
# are there frames positions left in this dimension ?
178+
if len(framesPos[dim]):
179+
# yes. we then add a last frame (the one going beyond the signal),
180+
# if it is possible. (it may NOT be possible in some exotic cases
181+
# such as hopSize[dim]>1 and frameShapeAligned[dim]==1
182+
if framesPos[dim][-1] + hop[dim] < sigShape[dim]:
183+
framesPos[dim] = np.append(
184+
framesPos[dim], framesPos[dim][-1] + hop[dim])
185+
else:
186+
# if there is no more frames in this dimension (short signal in
187+
# this dimension), then at least consider 0
188+
framesPos[dim] = [0]
189+
190+
# constructing the shape of the framed signal
191+
framedShape = np.append(frameShape, [len(x) for x in framesPos])
192+
return (framesPos, framedShape, frameShape, hop,
193+
fdim, nDim, frameShapeAligned)
194+
195+
196+
def split(sig, frames_shape, hop, weight_frames=False, verbose=False):
197+
"""splits a ndarray into overlapping frames
198+
sig : ndarray
199+
frameShape : tuple giving the size of each frame. If its shape is
200+
smaller than that of sig, assume the frame is of size 1
201+
for all missing dimensions
202+
hop : tuple giving the hopsize in each dimension. If its shape is
203+
smaller than that of sig, assume the hopsize is 1 for all
204+
missing dimensions
205+
weightFrames : return frames weighted by a ND hamming window
206+
verbose : whether to output progress during computation"""
207+
208+
# signal shape
209+
sigShape = np.array(sig.shape)
210+
211+
(framesPos, framedShape, frameShape,
212+
hop, fdim, nDim, frameShapeAligned) = splitinfo(
213+
sigShape, frames_shape, hop)
214+
215+
if weight_frames:
216+
# constructing the weighting window. Choosing hamming for convenience
217+
# (never 0)
218+
win = 1
219+
for dim in range(len(frameShape) - 1, -1, -1):
220+
win = np.outer(np.hamming(frameShapeAligned[dim]), win)
221+
win = np.squeeze(win)
222+
223+
# alocating memory for framed signal
224+
framed = np.zeros(framedShape, dtype=sig.dtype)
225+
226+
# total number of frames (for displaying)
227+
nFrames = np.prod([len(x) for x in framesPos])
228+
229+
# for each frame
230+
for iframe, index in enumerate(
231+
itertools.product(*[range(len(x)) for x in framesPos])):
232+
# display from time to time if asked for
233+
if verbose and (not iframe % 100):
234+
print('Splitting : frame ' + str(iframe) + '/' + str(nFrames))
235+
236+
# build the slice to use for extracting the signal of this frame.
237+
frameRange = [Ellipsis]
238+
for dim in range(nDim):
239+
frameRange += [slice(framesPos[dim][index[dim]],
240+
min(sigShape[dim],
241+
framesPos[dim][index[dim]]
242+
+ frameShapeAligned[dim]),
243+
1)]
244+
245+
# extract the signal
246+
sigFrame = sig[tuple(frameRange)]
247+
sigFrame.shape = sigFrame.shape[:fdim]
248+
249+
# the signal may be shorter than the normal size of a frame (at the
250+
# end of the signal). We build a slice that corresponds to the actual
251+
# size we got here
252+
sigFrameRange = [slice(0, x, 1) for x in sigFrame.shape[:fdim]]
253+
254+
# puts the signal in the output variable
255+
framed[tuple(sigFrameRange + list(index))] = sigFrame
256+
257+
if weight_frames:
258+
# multiply by the weighting window
259+
framed[(Ellipsis,) + tuple(index)] *= win
260+
261+
frameShape = [int(x) for x in frameShape]
262+
return framed
263+
264+
265+
def overlapadd(S, fdim, hop, shape=None, weighted_frames=True, verbose=False):
266+
"""n-dimensional overlap-add
267+
S : ndarray containing the stft to be inverted
268+
fdim : the number of dimensions in S corresponding to
269+
frame indices.
270+
hop : tuple containing hopsizes along dimensions.
271+
Missing hopsizes are assumed to be 1
272+
shape: Indicating the original shape of the
273+
signal for truncating. If None: no truncating is done
274+
weightedFrames: True if we need to compensate for the analysis weighting
275+
(weightFrames of the split function)
276+
verbose: whether or not to display progress
277+
"""
278+
279+
# number of dimensions
280+
nDim = len(S.shape)
281+
282+
frameShape = S.shape[:fdim]
283+
trueFrameShape = np.append(
284+
frameShape,
285+
np.ones(
286+
(nDim - len(frameShape)))).astype(int)
287+
288+
# same thing for hop
289+
if np.isscalar(hop):
290+
hop = (hop,)
291+
hop = np.array(hop)
292+
hop = np.append(hop, np.ones((nDim - len(hop)))).astype(int)
293+
294+
sigShape = [
295+
(nframedim - 1) * hopdim + frameshapedim for (
296+
nframedim,
297+
hopdim,
298+
frameshapedim) in zip(S.shape[fdim:], hop, trueFrameShape)]
299+
300+
# building the positions of the frames. For each dimension, gridding from
301+
# 0 to sigShape[dim] every hop[dim]
302+
framesPos = [np.arange(size) * step for (size, step)
303+
in zip(S.shape[fdim:], hop)]
304+
305+
# constructing the weighting window. Choosing hamming for convenience
306+
# (never 0)
307+
win = np.array(1)
308+
for dim in range(fdim):
309+
if trueFrameShape[dim] == 1:
310+
win = win[..., None]
311+
else:
312+
key = ((None,) * len(win.shape) + (Ellipsis,))
313+
win = (win[..., None]
314+
* np.hamming(trueFrameShape[dim]).__getitem__(key))
315+
316+
# if we need to compensate for analysis weighting, simply square window
317+
if weighted_frames:
318+
win2 = win ** 2
319+
else:
320+
win2 = win
321+
322+
sig = np.zeros(sigShape, dtype=S.dtype)
323+
324+
# will also store the sum of all weighting windows applied during
325+
# overlap and add. Traditionally, window function and overlap are chosen
326+
# so that these weights end up being 1 everywhere. However, we here are
327+
# not restricted here to any particular hopsize. Hence, the price to pay
328+
# is this further memory burden
329+
weights = np.zeros(sigShape)
330+
331+
# total number of frames (for displaying)
332+
nFrames = np.prod(S.shape[fdim:])
333+
334+
# could use memmap or stuff
335+
S *= win[tuple([Ellipsis] + [None] * (len(S.shape) - len(win.shape)))]
336+
337+
# for each frame
338+
for iframe, index in enumerate(
339+
itertools.product(*[range(len(x)) for x in framesPos])):
340+
# display from time to time if asked for
341+
if verbose and (not iframe % 100):
342+
print('overlap-add : frame ' + str(iframe) + '/' + str(nFrames))
343+
344+
# build the slice to use for overlap-adding the signal of this frame.
345+
frameRange = [Ellipsis]
346+
for dim in range(nDim-fdim):
347+
frameRange += [slice(framesPos[dim][index[dim]],
348+
min(sigShape[dim],
349+
framesPos[dim][index[dim]]
350+
+ trueFrameShape[dim]),
351+
1)]
352+
353+
# put back the reconstructed weighted frame into place
354+
frameSig = S[tuple([Ellipsis] + list(index))]
355+
sig[tuple(frameRange)] += frameSig[
356+
tuple([Ellipsis] +
357+
[None] *
358+
(len(sig[tuple(frameRange)].shape) -
359+
len(frameSig.shape)))]
360+
361+
# also store the corresponding window contribution
362+
weights[tuple(frameRange)] += win2[
363+
tuple([Ellipsis] +
364+
[None] *
365+
(len(weights[tuple(frameRange)].shape) -
366+
len(win2.shape)))]
367+
368+
# account for different weighting at different places
369+
sig /= weights
370+
371+
# truncate the signal if asked for
372+
if shape is not None:
373+
sig_res = np.zeros(shape, S.dtype)
374+
truncateRange = [slice(0, min(x, sig.shape[i]), 1)
375+
for (i, x) in enumerate(shape)]
376+
sig_res[tuple(truncateRange)] = sig[tuple(truncateRange)]
377+
sig = sig_res
378+
379+
# finished
380+
return sig

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
setup(
1717
name='norbert',
18-
version='0.1.0b',
18+
version='0.1.0c',
1919
description='Spectrogram Models',
2020
long_description=long_description,
2121
long_description_content_type='text/markdown',

tests/test_contrib.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
import pytest
3+
from norbert.contrib import split, overlapadd
4+
5+
6+
@pytest.fixture(params=[100, 256, 1001])
7+
def nb_frames(request):
8+
return int(request.param)
9+
10+
11+
@pytest.fixture(params=[1024, 777])
12+
def nb_bins(request):
13+
return request.param
14+
15+
16+
@pytest.fixture(params=[1, 2])
17+
def nb_channels(request):
18+
return request.param
19+
20+
21+
@pytest.fixture(params=[np.float])
22+
def dtype(request):
23+
return request.param
24+
25+
26+
@pytest.fixture(params=[1, 2, 3])
27+
def test_len(request, nb_frames):
28+
return int(nb_frames/request.param)
29+
30+
31+
@pytest.fixture(params=[1, 2, 3])
32+
def test_hop(request, test_len):
33+
return int(test_len/request.param)
34+
35+
36+
@pytest.fixture
37+
def X(request, nb_frames, nb_bins, nb_channels, dtype):
38+
np.random.seed(0)
39+
X = np.random.random((nb_frames, nb_bins, nb_channels)).astype(dtype)
40+
return X
41+
42+
43+
def test_split(X, test_len, test_hop):
44+
patches = split(X, test_len, test_hop)
45+
X_out = overlapadd(patches, 1, test_len, X.shape)
46+
assert np.allclose(X, X_out)

0 commit comments

Comments
 (0)