From 18b4396df4c66509da5ae341e99c4e7ed757bba2 Mon Sep 17 00:00:00 2001 From: Christopher Groskopf Date: Tue, 27 Oct 2015 16:42:00 -0500 Subject: [PATCH] Fix two major join bugs. #336. --- CHANGELOG | 2 +- agate/table.py | 10 ++++++---- tests/test_table.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index f530197b6..4e025a54b 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,7 +1,7 @@ 1.1.0 ----- - +* Fixed two major join issues. (#336) 1.0.0 ------ diff --git a/agate/table.py b/agate/table.py index c0fbfb5c8..ebae5498f 100644 --- a/agate/table.py +++ b/agate/table.py @@ -526,10 +526,10 @@ def join(self, right_table, left_key, right_key=None, inner=False): right_hash = {} for i, value in enumerate(right_data): - if value not in []: + if value not in right_hash: right_hash[value] = [] - right_hash[value].append(self._rows[i]) + right_hash[value].append(right_table._rows[i]) # Collect new rows rows = [] @@ -541,13 +541,13 @@ def join(self, right_table, left_key, right_key=None, inner=False): # Iterate over left column for left_index, left_value in enumerate(left_data): - new_row = list(self._rows[left_index]) - matching_rows = right_hash.get(left_value, None) # Rows with matches if matching_rows: for right_row in matching_rows: + new_row = list(self._rows[left_index]) + for k, v in enumerate(right_row): if k == right_key_index: continue @@ -560,6 +560,8 @@ def join(self, right_table, left_key, right_key=None, inner=False): row_names.append(self._row_names[left_index]) # Rows without matches elif not inner: + new_row = list(self._rows[left_index]) + for k, v in enumerate(right_table.column_names): if k == right_key_index: continue diff --git a/tests/test_table.py b/tests/test_table.py index da6a6e1aa..fa1f3024f 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -1005,6 +1005,41 @@ def test_join(self): self.assertSequenceEqual(new_table.rows[1], (2, 3, 'b', 3, 'b')) self.assertSequenceEqual(new_table.rows[2], (None, 2, 'c', 2, 'c')) + def test_join_match_multiple(self): + left_rows = ( + (1, 4, 'a'), + (2, 3, 'b') + ) + + right_rows = ( + (1, 1, 'a'), + (1, 2, 'a'), + (2, 2, 'b') + ) + + left = Table(left_rows, self.left_columns) + right = Table(right_rows, self.right_columns) + new_table = left.join(right, 'one', 'five') + + self.assertEqual(len(new_table.rows), 3) + self.assertEqual(len(new_table.columns), 5) + + self.assertEqual(new_table.columns[0].name, 'one') + self.assertEqual(new_table.columns[1].name, 'two') + self.assertEqual(new_table.columns[2].name, 'three') + self.assertEqual(new_table.columns[3].name, 'four') + self.assertEqual(new_table.columns[4].name, 'six') + + self.assertIsInstance(new_table.columns[0].data_type, Number) + self.assertIsInstance(new_table.columns[1].data_type, Number) + self.assertIsInstance(new_table.columns[2].data_type, Text) + self.assertIsInstance(new_table.columns[3].data_type, Number) + self.assertIsInstance(new_table.columns[4].data_type, Text) + + self.assertSequenceEqual(new_table.rows[0], (1, 4, 'a', 1, 'a')) + self.assertSequenceEqual(new_table.rows[1], (2, 3, 'b', 1, 'a')) + self.assertSequenceEqual(new_table.rows[2], (2, 3, 'b', 2, 'b')) + def test_join2(self): new_table = self.left.join(self.right, 'one', 'five')