Skip to content

Commit 9cb66ee

Browse files
author
James McKinney
authored
Merge pull request #785 from wireservice/performance
Avoid exponential cleanup algorithm, closes #575
2 parents a2df82c + cf636ba commit 9cb66ee

File tree

2 files changed

+29
-73
lines changed

2 files changed

+29
-73
lines changed

csvkit/cleanup.py

+28-33
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,6 @@ def join_rows(rows, joiner=' '):
2020
return fixed_row
2121

2222

23-
def extract_joinable_row_errors(errs):
24-
joinable = []
25-
26-
for err in reversed(errs):
27-
if type(err) is not LengthMismatchError:
28-
break
29-
30-
if joinable and err.line_number != joinable[-1].line_number - 1:
31-
break
32-
33-
joinable.append(err)
34-
35-
joinable.reverse()
36-
37-
return joinable
38-
39-
4023
class RowChecker(object):
4124
"""
4225
Iterate over rows of a CSV producing cleaned rows and storing error rows.
@@ -45,7 +28,6 @@ class RowChecker(object):
4528
def __init__(self, reader):
4629
self.reader = reader
4730
self.column_names = next(reader)
48-
4931
self.errors = []
5032
self.rows_joined = 0
5133
self.joins = 0
@@ -54,39 +36,52 @@ def checked_rows(self):
5436
"""
5537
A generator which yields rows which are ready to write to output.
5638
"""
39+
length = len(self.column_names)
5740
line_number = self.reader.line_num
41+
joinable_row_errors = []
5842

5943
for row in self.reader:
6044
try:
61-
if len(row) != len(self.column_names):
62-
raise LengthMismatchError(line_number, row, len(self.column_names))
45+
if len(row) != length:
46+
raise LengthMismatchError(line_number, row, length)
6347

6448
yield row
49+
50+
# Don't join rows across valid rows.
51+
joinable_row_errors = []
6552
except LengthMismatchError as e:
6653
self.errors.append(e)
6754

68-
joinable_row_errors = extract_joinable_row_errors(self.errors)
55+
# Don't join with long rows.
56+
if len(row) > length:
57+
joinable_row_errors = []
58+
else:
59+
joinable_row_errors.append(e)
6960

70-
while joinable_row_errors:
71-
fixed_row = join_rows([err.row for err in joinable_row_errors], joiner=' ')
61+
while joinable_row_errors:
62+
fixed_row = join_rows([error.row for error in joinable_row_errors], joiner=' ')
7263

73-
if len(fixed_row) < len(self.column_names):
74-
break
64+
if len(fixed_row) < length:
65+
break
7566

76-
if len(fixed_row) == len(self.column_names):
77-
self.rows_joined += len(joinable_row_errors)
78-
self.joins += 1
67+
if len(fixed_row) == length:
68+
self.rows_joined += len(joinable_row_errors)
69+
self.joins += 1
7970

80-
yield fixed_row
71+
yield fixed_row
8172

82-
for fixed in joinable_row_errors:
83-
self.errors.remove(fixed)
73+
for fixed in joinable_row_errors:
74+
joinable_row_errors.remove(fixed)
75+
self.errors.remove(fixed)
8476

85-
break
77+
break
8678

87-
joinable_row_errors = joinable_row_errors[1:] # keep trying in case we're too long because of a straggler
79+
joinable_row_errors = joinable_row_errors[1:] # keep trying in case we're too long because of a straggler
8880

8981
except CSVTestException as e:
9082
self.errors.append(e)
9183

84+
# Don't join rows across other errors.
85+
joinable_row_errors = []
86+
9287
line_number = self.reader.line_num

tests/test_cleanup.py

+1-40
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
except ImportError:
66
import unittest
77

8-
from csvkit.cleanup import extract_joinable_row_errors, join_rows
8+
from csvkit.cleanup import join_rows
99
from csvkit.exceptions import CSVTestException, LengthMismatchError
1010

1111

@@ -25,45 +25,6 @@ def test_fix_rows(self):
2525
self.assertEqual(" ".join([start[0][-1], start[1][0], start[2][0], start[3][0]]), fixed[2])
2626
self.assertEqual(start[3][1], fixed[3])
2727

28-
def test_extract_joinable_row_errors(self):
29-
e1 = LengthMismatchError(1, ['foo', 'bar', 'baz'], 10)
30-
e2 = LengthMismatchError(2, ['foo', 'bar', 'baz'], 10)
31-
e3 = LengthMismatchError(3, ['foo', 'bar', 'baz'], 10)
32-
errs = [e1, e2, e3]
33-
joinable = extract_joinable_row_errors(errs)
34-
self.assertEqual(3, len(joinable))
35-
for e, j in zip(errs, joinable):
36-
self.assertTrue(e is j)
37-
38-
def test_extract_joinable_row_errors_2(self):
39-
e1 = LengthMismatchError(1, ['foo', 'bar', 'baz'], 10)
40-
e2 = CSVTestException(2, ['foo', 'bar', 'baz'], "A throwaway message.")
41-
e3 = LengthMismatchError(3, ['foo', 'bar', 'baz'], 10)
42-
errs = [e1, e2, e3]
43-
joinable = extract_joinable_row_errors(errs)
44-
self.assertEqual(1, len(joinable))
45-
self.assertTrue(next(iter(joinable)) is e3)
46-
47-
def test_extract_joinable_row_errors_3(self):
48-
e1 = CSVTestException(1, ['foo', 'bar', 'baz'], "A throwaway message.")
49-
e2 = LengthMismatchError(2, ['foo', 'bar', 'baz'], 10)
50-
e3 = LengthMismatchError(3, ['foo', 'bar', 'baz'], 10)
51-
errs = [e1, e2, e3]
52-
joinable = extract_joinable_row_errors(errs)
53-
self.assertEqual(2, len(joinable))
54-
joinable = list(joinable)
55-
self.assertTrue(joinable[0] is e2)
56-
self.assertTrue(joinable[1] is e3)
57-
58-
def test_extract_joinable_row_errors_4(self):
59-
e1 = CSVTestException(1, ['foo', 'bar', 'baz'], "A throwaway message.")
60-
e2 = LengthMismatchError(2, ['foo', 'bar', 'baz'], 10)
61-
e3 = LengthMismatchError(4, ['foo', 'bar', 'baz'], 10)
62-
errs = [e1, e2, e3]
63-
joinable = extract_joinable_row_errors(errs)
64-
self.assertEqual(1, len(joinable))
65-
self.assertTrue(next(iter(joinable)) is e3)
66-
6728
def test_real_world_join_fail(self):
6829
start = [['168772', '1102', '$0.23 TO $0.72', 'HOUR', '1.5%'],
6930
['GROSS', '1.5% '],

0 commit comments

Comments
 (0)