Skip to content

Commit

Permalink
Implement case-insensitive key comparison for csvjoin
Browse files Browse the repository at this point in the history
  • Loading branch information
anorth committed May 24, 2016
1 parent ffe5f44 commit a4f93c7
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 26 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ docs/_build
.coverage
.tox
cover
env
84 changes: 65 additions & 19 deletions csvkit/join.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
#!/usr/bin/env python


def _get_ordered_keys(rows, column_index):
def _get_keys(rows, column_index, lowercase=False):
"""
Get ordered keys from rows, given the key column index.
Get keys from rows as keys in a dictionary (i.e. unordered), given the key column index.
"""
return [r[column_index] for r in rows]
pairs = ((r[column_index], True) for r in rows)
return CaseInsensitiveDict(pairs) if lowercase else dict(pairs)


def _get_mapped_keys(rows, column_index):
mapped_keys = {}
def _get_mapped_keys(rows, column_index, case_insensitive=False):
mapped_keys = CaseInsensitiveDict() if case_insensitive else {}

for r in rows:
key = r[column_index]
Expand All @@ -21,6 +22,11 @@ def _get_mapped_keys(rows, column_index):

return mapped_keys

def _lower(key):
"""Transforms a string to lowercase, leaves other types alone."""
keyfn = getattr(key, 'lower', None)
return keyfn() if keyfn else key


def sequential_join(left_rows, right_rows, header=True):
"""
Expand Down Expand Up @@ -49,7 +55,7 @@ def sequential_join(left_rows, right_rows, header=True):
return output


def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
"""
Execute an inner join on two tables and return the combined table.
"""
Expand All @@ -63,7 +69,7 @@ def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=Tr
output = []

# Map right rows to keys
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)

for left_row in left_rows:
len_left_row = len(left_row)
Expand All @@ -80,7 +86,7 @@ def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=Tr
return output


def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
"""
Execute full outer join on two tables and return the combined table.
"""
Expand All @@ -94,11 +100,11 @@ def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
else:
output = []

# Get ordered keys
left_ordered_keys = _get_ordered_keys(left_rows, left_column_id)
# Get left keys
left_keys = _get_keys(left_rows, left_column_id, ignore_case)

# Get mapped keys
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)

for left_row in left_rows:
len_left_row = len(left_row)
Expand All @@ -116,13 +122,13 @@ def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
for right_row in right_rows:
right_key = right_row[right_column_id]

if right_key not in left_ordered_keys:
if right_key not in left_keys:
output.append(([u''] * len_left_headers) + right_row)

return output


def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
"""
Execute left outer join on two tables and return the combined table.
"""
Expand All @@ -137,7 +143,7 @@ def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
output = []

# Get mapped keys
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)

for left_row in left_rows:
len_left_row = len(left_row)
Expand All @@ -155,7 +161,7 @@ def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
return output


def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
"""
Execute right outer join on two tables and return the combined table.
"""
Expand All @@ -168,11 +174,11 @@ def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, hea
else:
output = []

# Get ordered keys
left_ordered_keys = _get_ordered_keys(left_rows, left_column_id)
# Get left keys
left_keys = _get_keys(left_rows, left_column_id, ignore_case)

# Get mapped keys
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)

for left_row in left_rows:
len_left_row = len(left_row)
Expand All @@ -188,7 +194,47 @@ def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, hea
for right_row in right_rows:
right_key = right_row[right_column_id]

if right_key not in left_ordered_keys:
if right_key not in left_keys:
output.append(([u''] * len_left_headers) + right_row)

return output



class CaseInsensitiveDict(dict):
"""
Adapted from http://stackoverflow.com/a/32888599/1583437
"""
def __init__(self, *args, **kwargs):
super(CaseInsensitiveDict, self).__init__(*args, **kwargs)
self._convert_keys()

def __getitem__(self, key):
return super(CaseInsensitiveDict, self).__getitem__(_lower(key))

def __setitem__(self, key, value):
super(CaseInsensitiveDict, self).__setitem__(_lower(key), value)

def __delitem__(self, key):
return super(CaseInsensitiveDict, self).__delitem__(_lower(key))

def __contains__(self, key):
return super(CaseInsensitiveDict, self).__contains__(_lower(key))

def pop(self, key, *args, **kwargs):
return super(CaseInsensitiveDict, self).pop(_lower(key), *args, **kwargs)

def get(self, key, *args, **kwargs):
return super(CaseInsensitiveDict, self).get(_lower(key), *args, **kwargs)

def setdefault(self, key, *args, **kwargs):
return super(CaseInsensitiveDict, self).setdefault(_lower(key), *args, **kwargs)

def update(self, single_arg=None, **kwargs):
super(CaseInsensitiveDict, self).update(self.__class__(single_arg))
super(CaseInsensitiveDict, self).update(self.__class__(**kwargs))

def _convert_keys(self):
for k in list(self.keys()):
v = super(CaseInsensitiveDict, self).pop(k)
self.__setitem__(k, v)
11 changes: 7 additions & 4 deletions csvkit/utilities/csvjoin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def add_arguments(self):
help='Perform a left outer join, rather than the default inner join. If more than two files are provided this will be executed as a sequence of left outer joins, starting at the left.')
self.argparser.add_argument('--right', dest='right_join', action='store_true',
help='Perform a right outer join, rather than the default inner join. If more than two files are provided this will be executed as a sequence of right outer joins, starting at the right.')
self.argparser.add_argument('--ignorecase', dest='ignore_case', action='store_true',
help='Whether to ignore string case when comparing keys.')

def main(self):
self.input_files = []
Expand Down Expand Up @@ -62,10 +64,11 @@ def main(self):

jointab = tables[0]

ignore_case = self.args.ignore_case
if self.args.left_join:
# Left outer join
for i, t in enumerate(tables[1:]):
jointab = join.left_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header)
jointab = join.left_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case)
elif self.args.right_join:
# Right outer join
jointab = tables[-1]
Expand All @@ -74,15 +77,15 @@ def main(self):
remaining_tables.reverse()

for i, t in enumerate(remaining_tables):
jointab = join.right_outer_join(t, join_column_ids[-(i + 2)], jointab, join_column_ids[-1], header=header)
jointab = join.right_outer_join(t, join_column_ids[-(i + 2)], jointab, join_column_ids[-1], header=header, ignore_case=ignore_case)
elif self.args.outer_join:
# Full outer join
for i, t in enumerate(tables[1:]):
jointab = join.full_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header)
jointab = join.full_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case)
elif self.args.columns:
# Inner join
for i, t in enumerate(tables[1:]):
jointab = join.inner_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header)
jointab = join.inner_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case)
else:
# Sequential join
for t in tables[1:]:
Expand Down
38 changes: 35 additions & 3 deletions tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,23 @@ def setUp(self):
[u'1', u'second', u'0'],
[u'2', u'only', u'0', u'0']] # Note extra value in this column

def test_get_ordered_keys(self):
self.assertEqual(join._get_ordered_keys(self.tab1[1:], 0), [u'1', u'2', u'3', u'1'])
self.assertEqual(join._get_ordered_keys(self.tab2[1:], 0), [u'1', u'4', u'1', u'2'])
def test_get_keys(self):
self.assertEqual(join._get_keys(self.tab1[1:], 0).keys(), set([u'1', u'2', u'3', u'1']))
self.assertEqual(join._get_keys(self.tab2[1:], 0).keys(), set([u'1', u'4', u'1', u'2']))

def test_get_mapped_keys(self):
self.assertEqual(join._get_mapped_keys(self.tab1[1:], 0), {
u'1': [[u'1', u'Chicago Reader', u'first'], [u'1', u'Chicago Reader', u'second']],
u'2': [[u'2', u'Chicago Sun-Times', u'only']],
u'3': [[u'3', u'Chicago Tribune', u'only']]})

def test_get_mapped_keys_ignore_case(self):
mapped_keys = join._get_mapped_keys(self.tab1[1:], 1, case_insensitive=True)
assert u'Chicago Reader' in mapped_keys
assert u'chicago reader' in mapped_keys
assert u'CHICAGO SUN-TIMES' in mapped_keys
assert u'1' not in mapped_keys

def test_sequential_join(self):
self.assertEqual(join.sequential_join(self.tab1, self.tab2), [
['id', 'name', 'i_work_here', 'id', 'age', 'i_work_here'],
Expand Down Expand Up @@ -82,3 +89,28 @@ def test_right_outer_join(self):
[u'1', u'Chicago Reader', u'second', u'1', u'first', u'0'],
[u'1', u'Chicago Reader', u'second', u'1', u'second', u'0'],
[u'', u'', u'', u'4', u'only', u'0']])

def test_right_outer_join_ignore_case(self):
# Right outer join exercises all the case dependencies
tab1 = [
['id', 'name', 'i_work_here'],
[u'a', u'Chicago Reader', u'first'],
[u'b', u'Chicago Sun-Times', u'only'],
[u'c', u'Chicago Tribune', u'only'],
[u'a', u'Chicago Reader', u'second']]

tab2 = [
['id', 'age', 'i_work_here'],
[u'A', u'first', u'0'],
[u'D', u'only', u'0'],
[u'A', u'second', u'0'],
[u'B', u'only', u'0', u'0']] # Note extra value in this column

self.assertEqual(join.right_outer_join(tab1, 0, tab2, 0, ignore_case=True), [
['id', 'name', 'i_work_here', 'id', 'age', 'i_work_here'],
[u'a', u'Chicago Reader', u'first', u'A', u'first', u'0'],
[u'a', u'Chicago Reader', u'first', u'A', u'second', u'0'],
[u'b', u'Chicago Sun-Times', u'only', u'B', u'only', u'0', u'0'],
[u'a', u'Chicago Reader', u'second', u'A', u'first', u'0'],
[u'a', u'Chicago Reader', u'second', u'A', u'second', u'0'],
[u'', u'', u'', u'D', u'only', u'0']])

0 comments on commit a4f93c7

Please sign in to comment.