Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

csvjoin: Implement case-insensitive key comparison #610

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ docs/_build
.coverage
.tox
cover
env
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for a virtualenv, but let me know if you have another convention you strongly prefer.

84 changes: 65 additions & 19 deletions csvkit/join.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
#!/usr/bin/env python


def _get_ordered_keys(rows, column_index):
def _get_keys(rows, column_index, lowercase=False):
"""
Get ordered keys from rows, given the key column index.
Get keys from rows as keys in a dictionary (i.e. unordered), given the key column index.
"""
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the order of these keys was never used, only membership tests (`__contains__`), so switching from a list to a dict also speeds up those lookups.

return [r[column_index] for r in rows]
pairs = ((r[column_index], True) for r in rows)
return CaseInsensitiveDict(pairs) if lowercase else dict(pairs)


def _get_mapped_keys(rows, column_index):
mapped_keys = {}
def _get_mapped_keys(rows, column_index, case_insensitive=False):
mapped_keys = CaseInsensitiveDict() if case_insensitive else {}

for r in rows:
key = r[column_index]
Expand All @@ -21,6 +22,11 @@ def _get_mapped_keys(rows, column_index):

return mapped_keys

def _lower(key):
"""Transforms a string to lowercase, leaves other types alone."""
keyfn = getattr(key, 'lower', None)
return keyfn() if keyfn else key


def sequential_join(left_rows, right_rows, header=True):
"""
Expand Down Expand Up @@ -49,7 +55,7 @@ def sequential_join(left_rows, right_rows, header=True):
return output


def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
"""
Execute an inner join on two tables and return the combined table.
"""
Expand All @@ -63,7 +69,7 @@ def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=Tr
output = []

# Map right rows to keys
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)

for left_row in left_rows:
len_left_row = len(left_row)
Expand All @@ -80,7 +86,7 @@ def inner_join(left_rows, left_column_id, right_rows, right_column_id, header=Tr
return output


def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
"""
Execute full outer join on two tables and return the combined table.
"""
Expand All @@ -94,11 +100,11 @@ def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
else:
output = []

# Get ordered keys
left_ordered_keys = _get_ordered_keys(left_rows, left_column_id)
# Get left keys
left_keys = _get_keys(left_rows, left_column_id, ignore_case)

# Get mapped keys
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)

for left_row in left_rows:
len_left_row = len(left_row)
Expand All @@ -116,13 +122,13 @@ def full_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
for right_row in right_rows:
right_key = right_row[right_column_id]

if right_key not in left_ordered_keys:
if right_key not in left_keys:
output.append(([u''] * len_left_headers) + right_row)

return output


def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
"""
Execute left outer join on two tables and return the combined table.
"""
Expand All @@ -137,7 +143,7 @@ def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
output = []

# Get mapped keys
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)

for left_row in left_rows:
len_left_row = len(left_row)
Expand All @@ -155,7 +161,7 @@ def left_outer_join(left_rows, left_column_id, right_rows, right_column_id, head
return output


def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True):
def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, header=True, ignore_case=False):
"""
Execute right outer join on two tables and return the combined table.
"""
Expand All @@ -168,11 +174,11 @@ def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, hea
else:
output = []

# Get ordered keys
left_ordered_keys = _get_ordered_keys(left_rows, left_column_id)
# Get left keys
left_keys = _get_keys(left_rows, left_column_id, ignore_case)

# Get mapped keys
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id)
right_mapped_keys = _get_mapped_keys(right_rows, right_column_id, ignore_case)

for left_row in left_rows:
len_left_row = len(left_row)
Expand All @@ -188,7 +194,47 @@ def right_outer_join(left_rows, left_column_id, right_rows, right_column_id, hea
for right_row in right_rows:
right_key = right_row[right_column_id]

if right_key not in left_ordered_keys:
if right_key not in left_keys:
output.append(([u''] * len_left_headers) + right_row)

return output



class CaseInsensitiveDict(dict):
    """
    A ``dict`` that passes every key through ``_lower`` before storing it or
    looking it up, so string keys differing only by case share one entry.

    Keys are stored in their normalized form, so iterating over the dict
    yields the normalized keys rather than the spelling originally supplied.

    Adapted from http://stackoverflow.com/a/32888599/1583437
    """

    def __init__(self, *args, **kwargs):
        # Let dict interpret the arguments, then re-insert each key normalized.
        super(CaseInsensitiveDict, self).__init__(*args, **kwargs)
        self._convert_keys()

    def __getitem__(self, key):
        return super(CaseInsensitiveDict, self).__getitem__(_lower(key))

    def __setitem__(self, key, value):
        super(CaseInsensitiveDict, self).__setitem__(_lower(key), value)

    def __delitem__(self, key):
        return super(CaseInsensitiveDict, self).__delitem__(_lower(key))

    def __contains__(self, key):
        return super(CaseInsensitiveDict, self).__contains__(_lower(key))

    def pop(self, key, *args, **kwargs):
        return super(CaseInsensitiveDict, self).pop(_lower(key), *args, **kwargs)

    def get(self, key, *args, **kwargs):
        return super(CaseInsensitiveDict, self).get(_lower(key), *args, **kwargs)

    def setdefault(self, key, *args, **kwargs):
        return super(CaseInsensitiveDict, self).setdefault(_lower(key), *args, **kwargs)

    def update(self, single_arg=None, **kwargs):
        """
        Merge ``single_arg`` (a mapping or an iterable of key/value pairs) and
        then ``kwargs`` into this dict, normalizing every key.

        Either source may be omitted. An absent ``single_arg`` is skipped
        rather than handed to the ``dict`` constructor, which would raise
        ``TypeError`` for ``None``.
        """
        if single_arg is not None:
            # Build a normalized copy first so colliding keys are resolved
            # the same way as in the constructor.
            super(CaseInsensitiveDict, self).update(self.__class__(single_arg))
        if kwargs:
            super(CaseInsensitiveDict, self).update(self.__class__(**kwargs))

    def _convert_keys(self):
        """Re-insert every existing entry under its normalized key."""
        # Snapshot the keys: entries are removed and re-added while looping.
        for k in list(self.keys()):
            v = super(CaseInsensitiveDict, self).pop(k)
            self.__setitem__(k, v)
11 changes: 7 additions & 4 deletions csvkit/utilities/csvjoin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def add_arguments(self):
help='Perform a left outer join, rather than the default inner join. If more than two files are provided this will be executed as a sequence of left outer joins, starting at the left.')
self.argparser.add_argument('--right', dest='right_join', action='store_true',
help='Perform a right outer join, rather than the default inner join. If more than two files are provided this will be executed as a sequence of right outer joins, starting at the right.')
self.argparser.add_argument('--ignorecase', dest='ignore_case', action='store_true',
help='Whether to ignore string case when comparing keys.')

def main(self):
self.input_files = []
Expand Down Expand Up @@ -62,10 +64,11 @@ def main(self):

jointab = tables[0]

ignore_case = self.args.ignore_case
if self.args.left_join:
# Left outer join
for i, t in enumerate(tables[1:]):
jointab = join.left_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header)
jointab = join.left_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case)
elif self.args.right_join:
# Right outer join
jointab = tables[-1]
Expand All @@ -74,15 +77,15 @@ def main(self):
remaining_tables.reverse()

for i, t in enumerate(remaining_tables):
jointab = join.right_outer_join(t, join_column_ids[-(i + 2)], jointab, join_column_ids[-1], header=header)
jointab = join.right_outer_join(t, join_column_ids[-(i + 2)], jointab, join_column_ids[-1], header=header, ignore_case=ignore_case)
elif self.args.outer_join:
# Full outer join
for i, t in enumerate(tables[1:]):
jointab = join.full_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header)
jointab = join.full_outer_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case)
elif self.args.columns:
# Inner join
for i, t in enumerate(tables[1:]):
jointab = join.inner_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header)
jointab = join.inner_join(jointab, join_column_ids[0], t, join_column_ids[i + 1], header=header, ignore_case=ignore_case)
else:
# Sequential join
for t in tables[1:]:
Expand Down
38 changes: 35 additions & 3 deletions tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,23 @@ def setUp(self):
[u'1', u'second', u'0'],
[u'2', u'only', u'0', u'0']] # Note extra value in this column

def test_get_ordered_keys(self):
self.assertEqual(join._get_ordered_keys(self.tab1[1:], 0), [u'1', u'2', u'3', u'1'])
self.assertEqual(join._get_ordered_keys(self.tab2[1:], 0), [u'1', u'4', u'1', u'2'])
def test_get_keys(self):
    # Compare sets on both sides: the keys are unordered, and wrapping the
    # result in set() keeps the check valid whether dict.keys() would give a
    # list (Python 2) or a view (Python 3).
    self.assertEqual(set(join._get_keys(self.tab1[1:], 0)), set([u'1', u'2', u'3']))
    self.assertEqual(set(join._get_keys(self.tab2[1:], 0)), set([u'1', u'4', u'2']))

def test_get_mapped_keys(self):
    # Rows are grouped under the value found in column 0, in row order.
    rows_by_id = join._get_mapped_keys(self.tab1[1:], 0)
    grouped = {
        u'1': [
            [u'1', u'Chicago Reader', u'first'],
            [u'1', u'Chicago Reader', u'second'],
        ],
        u'2': [[u'2', u'Chicago Sun-Times', u'only']],
        u'3': [[u'3', u'Chicago Tribune', u'only']],
    }
    self.assertEqual(rows_by_id, grouped)

def test_get_mapped_keys_ignore_case(self):
    # Map on column 1 with case-insensitive keys.
    data_rows = self.tab1[1:]
    lookup = join._get_mapped_keys(data_rows, 1, case_insensitive=True)

    # Every spelling of a value from column 1 must be found.
    for spelling in (u'Chicago Reader', u'chicago reader', u'CHICAGO SUN-TIMES'):
        assert spelling in lookup

    # A value that only occurs in column 0 was never used as a key.
    assert u'1' not in lookup

def test_sequential_join(self):
self.assertEqual(join.sequential_join(self.tab1, self.tab2), [
['id', 'name', 'i_work_here', 'id', 'age', 'i_work_here'],
Expand Down Expand Up @@ -82,3 +89,28 @@ def test_right_outer_join(self):
[u'1', u'Chicago Reader', u'second', u'1', u'first', u'0'],
[u'1', u'Chicago Reader', u'second', u'1', u'second', u'0'],
[u'', u'', u'', u'4', u'only', u'0']])

def test_right_outer_join_ignore_case(self):
    # Right outer join exercises all the case dependencies: the left ids are
    # lowercase, the right ids are uppercase.
    left = [
        ['id', 'name', 'i_work_here'],
        [u'a', u'Chicago Reader', u'first'],
        [u'b', u'Chicago Sun-Times', u'only'],
        [u'c', u'Chicago Tribune', u'only'],
        [u'a', u'Chicago Reader', u'second'],
    ]
    right = [
        ['id', 'age', 'i_work_here'],
        [u'A', u'first', u'0'],
        [u'D', u'only', u'0'],
        [u'A', u'second', u'0'],
        [u'B', u'only', u'0', u'0'],  # Note extra value in this column
    ]

    joined = join.right_outer_join(left, 0, right, 0, ignore_case=True)

    wanted = [
        ['id', 'name', 'i_work_here', 'id', 'age', 'i_work_here'],
        [u'a', u'Chicago Reader', u'first', u'A', u'first', u'0'],
        [u'a', u'Chicago Reader', u'first', u'A', u'second', u'0'],
        [u'b', u'Chicago Sun-Times', u'only', u'B', u'only', u'0', u'0'],
        [u'a', u'Chicago Reader', u'second', u'A', u'first', u'0'],
        [u'a', u'Chicago Reader', u'second', u'A', u'second', u'0'],
        [u'', u'', u'', u'D', u'only', u'0'],
    ]
    self.assertEqual(joined, wanted)