Skip to content

Commit d300ba5

Browse files
authored
Improve Auto Interpretation Performance (#814)
* cythonize merge_plateaus
* add a threshold for plateau count to prevent running forever
1 parent 217a814 commit d300ba5

File tree

3 files changed

+55
-41
lines changed

3 files changed

+55
-41
lines changed

src/urh/ainterpretation/AutoInterpretation.py

+2-23
Original file line numberDiff line numberDiff line change
@@ -280,28 +280,7 @@ def merge_plateau_lengths(plateau_lengths, tolerance=None) -> list:
280280
if tolerance == 0 or tolerance is None:
281281
return plateau_lengths
282282

283-
result = []
284-
if len(plateau_lengths) == 0:
285-
return result
286-
287-
if plateau_lengths[0] <= tolerance:
288-
result.append(0)
289-
290-
i = 0
291-
while i < len(plateau_lengths):
292-
if plateau_lengths[i] <= tolerance:
293-
# Look forward to see if we need to merge a larger window e.g. for 67, 1, 10, 1, 21
294-
n = 2
295-
while i + n < len(plateau_lengths) and plateau_lengths[i + n] <= tolerance:
296-
n += 2
297-
298-
result[-1] = sum(plateau_lengths[max(i - 1, 0):i + n])
299-
i += n
300-
else:
301-
result.append(plateau_lengths[i])
302-
i += 1
303-
304-
return result
283+
return c_auto_interpretation.merge_plateaus(plateau_lengths, tolerance, max_count=10000)
305284

306285

307286
def round_plateau_lengths(plateau_lengths: list):
@@ -343,8 +322,8 @@ def get_bit_length_from_plateau_lengths(merged_plateau_lengths) -> int:
343322
return int(merged_plateau_lengths[0])
344323

345324
round_plateau_lengths(merged_plateau_lengths)
325+
histogram = c_auto_interpretation.get_threshold_divisor_histogram(merged_plateau_lengths)
346326

347-
histogram = c_auto_interpretation.get_threshold_divisor_histogram(np.array(merged_plateau_lengths, dtype=np.uint64))
348327
if len(histogram) == 0:
349328
return 0
350329
else:

src/urh/cythonext/auto_interpretation.pyx

+41-8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ from cpython cimport array
55
import array
66
import cython
77

8+
from cython.parallel import prange
9+
from libc.stdlib cimport malloc, free
10+
from libcpp.algorithm cimport sort
11+
from libc.stdint cimport uint64_t
12+
813
cpdef tuple k_means(float[:] data, unsigned int k=2):
914
cdef float[:] centers = np.empty(k, dtype=np.float32)
1015
cdef list clusters = []
@@ -105,7 +110,7 @@ def segment_messages_from_magnitudes(cython.floating[:] magnitudes, float noise_
105110

106111
return result
107112

108-
cpdef unsigned long long[:] get_threshold_divisor_histogram(unsigned long long[:] plateau_lengths, float threshold=0.2):
113+
cpdef uint64_t[:] get_threshold_divisor_histogram(uint64_t[:] plateau_lengths, float threshold=0.2):
109114
"""
110115
Get a histogram (i.e. count) how many times a value is a threshold divisor for other values in given data
111116
@@ -114,12 +119,10 @@ cpdef unsigned long long[:] get_threshold_divisor_histogram(unsigned long long[:
114119
:param plateau_lengths:
115120
:return:
116121
"""
117-
cdef unsigned long long num_lengths = len(plateau_lengths)
122+
cdef uint64_t i, j, x, y, minimum, maximum, num_lengths = len(plateau_lengths)
118123

119124
cdef np.ndarray[np.uint64_t, ndim=1] histogram = np.zeros(int(np.max(plateau_lengths)) + 1, dtype=np.uint64)
120125

121-
cdef unsigned long long i, j, x, y, minimum, maximum
122-
123126
for i in range(0, num_lengths):
124127
for j in range(i+1, num_lengths):
125128
x = plateau_lengths[i]
@@ -139,6 +142,40 @@ cpdef unsigned long long[:] get_threshold_divisor_histogram(unsigned long long[:
139142

140143
return histogram
141144

145+
cpdef np.ndarray[np.uint64_t, ndim=1] merge_plateaus(np.ndarray[np.uint64_t, ndim=1] plateaus,
                                                     uint64_t tolerance,
                                                     uint64_t max_count):
    """
    Merge plateaus that are not longer than ``tolerance`` into their neighbours.

    A plateau ``plateaus[i] <= tolerance`` is treated as a glitch: it is summed together
    with its left neighbour and its right neighbour into the current result slot.
    If further short plateaus follow at a distance of two (e.g. 67, 1, 10, 1, 21),
    the merge window is extended so that the whole run collapses into a single value.

    If the very first plateau is within tolerance, result slot 0 starts at 0 and is only
    overwritten when ``plateaus[1]`` is within tolerance as well.

    :param plateaus: plateau lengths as a one-dimensional uint64 array
    :param tolerance: plateaus with a length <= tolerance get merged
    :param max_count: processing stops as soon as the index of the current result slot
                      reaches this value, so at most ``max_count + 1`` values are returned
    :return: merged plateau lengths as a uint64 array (empty for empty input)
    """
    cdef uint64_t j, n, stop, L = len(plateaus), current = 0, i = 1, tmp_sum
    if L == 0:
        return np.zeros(0, dtype=np.uint64)

    # Merging can only shrink the number of plateaus, so L slots are always sufficient
    cdef np.ndarray[np.uint64_t, ndim=1] result = np.empty(L, dtype=np.uint64)
    if plateaus[0] <= tolerance:
        result[0] = 0
    else:
        result[0] = plateaus[0]

    while i < L and current < max_count:
        if plateaus[i] <= tolerance:
            # Look ahead to see whether we need to merge a larger window e.g. for 67, 1, 10, 1, 21
            n = 2
            while i + n < L and plateaus[i + n] <= tolerance:
                n += 2

            # Clamp the window to the end of the data: when the short plateau is the
            # last element there is no right neighbour and i + n would point past the array.
            stop = i + n
            if stop > L:
                stop = L

            # i starts at 1 and only grows, so i - 1 never underflows the unsigned type
            tmp_sum = 0
            for j in range(i - 1, stop):
                tmp_sum += plateaus[j]

            # Overwrite the current slot: it holds the left neighbour, which is part of the sum
            result[current] = tmp_sum
            i += n
        else:
            current += 1
            result[current] = plateaus[i]
            i += 1

    return result[:current+1]
178+
142179
cpdef np.ndarray[np.uint64_t, ndim=1] get_plateau_lengths(float[:] rect_data, float center, int percentage=25):
143180
if len(rect_data) == 0 or center is None:
144181
return np.array([], dtype=np.uint64)
@@ -171,10 +208,6 @@ cpdef np.ndarray[np.uint64_t, ndim=1] get_plateau_lengths(float[:] rect_data, fl
171208
return np.array(result, dtype=np.uint64)
172209

173210

174-
from cython.parallel import prange
175-
from libc.stdlib cimport malloc, free
176-
from libcpp.algorithm cimport sort
177-
178211
cdef float median(double[:] data, unsigned long start, unsigned long data_len, unsigned int k=3) nogil:
179212
cdef unsigned long i, j
180213

tests/auto_interpretation/test_bit_length_detection.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
import unittest
2+
import numpy as np
23

34
from urh.ainterpretation import AutoInterpretation
45

5-
66
class TestAutoInterpretation(unittest.TestCase):
7+
def __run_merge(self, data):
8+
return list(AutoInterpretation.merge_plateau_lengths(np.array(data, dtype=np.uint64)))
9+
710
def test_merge_plateau_lengths(self):
811
self.assertEqual(AutoInterpretation.merge_plateau_lengths([]), [])
912
self.assertEqual(AutoInterpretation.merge_plateau_lengths([42]), [42])
1013
self.assertEqual(AutoInterpretation.merge_plateau_lengths([100, 100, 100]), [100, 100, 100])
11-
self.assertEqual(AutoInterpretation.merge_plateau_lengths([100, 49, 1, 50, 100]), [100, 100, 100])
12-
self.assertEqual(AutoInterpretation.merge_plateau_lengths([100, 48, 2, 50, 100]), [100, 100, 100])
13-
self.assertEqual(AutoInterpretation.merge_plateau_lengths([100, 100, 67, 1, 10, 1, 21]), [100, 100, 100])
14-
self.assertEqual(AutoInterpretation.merge_plateau_lengths([100, 100, 67, 1, 10, 1, 21, 100, 50, 1, 49]),
15-
[100, 100, 100, 100, 100])
14+
self.assertEqual(self.__run_merge([100, 49, 1, 50, 100]), [100, 100, 100])
15+
self.assertEqual(self.__run_merge([100, 48, 2, 50, 100]), [100, 100, 100])
16+
self.assertEqual(self.__run_merge([100, 100, 67, 1, 10, 1, 21]), [100, 100, 100])
17+
self.assertEqual(self.__run_merge([100, 100, 67, 1, 10, 1, 21, 100, 50, 1, 49]), [100, 100, 100, 100, 100])
1618

1719
def test_estimate_tolerance_from_plateau_lengths(self):
1820
self.assertEqual(AutoInterpretation.estimate_tolerance_from_plateau_lengths([]), None)
@@ -34,19 +36,19 @@ def test_tolerant_greatest_common_divisor(self):
3436
def test_get_bit_length_from_plateau_length(self):
3537
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths([]), 0)
3638
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths([42]), 42)
37-
plateau_lengths = [2, 1, 2, 73, 1, 26, 100, 40, 1, 59, 100, 47, 1, 52, 67, 1, 10, 1, 21, 33, 1, 66, 100, 5, 1, 3, 1, 48, 1, 27, 1, 8]
39+
plateau_lengths = np.array([2, 1, 2, 73, 1, 26, 100, 40, 1, 59, 100, 47, 1, 52, 67, 1, 10, 1, 21, 33, 1, 66, 100, 5, 1, 3, 1, 48, 1, 27, 1, 8], dtype=np.uint64)
3840
merged_lengths = AutoInterpretation.merge_plateau_lengths(plateau_lengths)
3941
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths(merged_lengths), 100)
4042

4143

42-
plateau_lengths = [1, 292, 331, 606, 647, 286, 645, 291, 334, 601, 339, 601, 338, 602, 337, 603, 338, 604, 336, 605, 337, 600, 338, 605, 646]
44+
plateau_lengths = np.array([1, 292, 331, 606, 647, 286, 645, 291, 334, 601, 339, 601, 338, 602, 337, 603, 338, 604, 336, 605, 337, 600, 338, 605, 646], dtype=np.uint64)
4345
merged_lengths = AutoInterpretation.merge_plateau_lengths(plateau_lengths)
4446
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths(merged_lengths), 300)
4547

46-
plateau_lengths = [3, 8, 8, 8, 8, 8, 8, 8, 8, 16, 8, 8, 16, 32, 8, 8, 8, 8, 8, 24, 8, 24, 8, 24, 8, 24, 8, 24, 16, 16, 24, 8]
48+
plateau_lengths = np.array([3, 8, 8, 8, 8, 8, 8, 8, 8, 16, 8, 8, 16, 32, 8, 8, 8, 8, 8, 24, 8, 24, 8, 24, 8, 24, 8, 24, 16, 16, 24, 8], dtype=np.uint64)
4749
merged_lengths = AutoInterpretation.merge_plateau_lengths(plateau_lengths)
4850
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths(merged_lengths), 8)
4951

5052
def test_get_bit_length_from_merged_plateau_lengths(self):
51-
merged_lengths = [40, 40, 40, 40, 40, 30, 50, 30, 90, 40, 40, 80, 160, 30, 50, 30]
53+
merged_lengths = np.array([40, 40, 40, 40, 40, 30, 50, 30, 90, 40, 40, 80, 160, 30, 50, 30], dtype=np.uint64)
5254
self.assertEqual(AutoInterpretation.get_bit_length_from_plateau_lengths(merged_lengths), 40)

0 commit comments

Comments
 (0)