Skip to content

Commit

Permalink
Fix two major join bugs. wireservice#336.
Browse files Browse the repository at this point in the history
  • Loading branch information
onyxfish committed Oct 27, 2015
1 parent 51fabce commit 18b4396
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
1.1.0
-----


* Fixed two major join issues. (#336)

1.0.0
------
Expand Down
10 changes: 6 additions & 4 deletions agate/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,10 @@ def join(self, right_table, left_key, right_key=None, inner=False):
right_hash = {}

for i, value in enumerate(right_data):
if value not in []:
if value not in right_hash:
right_hash[value] = []

right_hash[value].append(self._rows[i])
right_hash[value].append(right_table._rows[i])

# Collect new rows
rows = []
Expand All @@ -541,13 +541,13 @@ def join(self, right_table, left_key, right_key=None, inner=False):

# Iterate over left column
for left_index, left_value in enumerate(left_data):
new_row = list(self._rows[left_index])

matching_rows = right_hash.get(left_value, None)

# Rows with matches
if matching_rows:
for right_row in matching_rows:
new_row = list(self._rows[left_index])

for k, v in enumerate(right_row):
if k == right_key_index:
continue
Expand All @@ -560,6 +560,8 @@ def join(self, right_table, left_key, right_key=None, inner=False):
row_names.append(self._row_names[left_index])
# Rows without matches
elif not inner:
new_row = list(self._rows[left_index])

for k, v in enumerate(right_table.column_names):
if k == right_key_index:
continue
Expand Down
35 changes: 35 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,41 @@ def test_join(self):
self.assertSequenceEqual(new_table.rows[1], (2, 3, 'b', 3, 'b'))
self.assertSequenceEqual(new_table.rows[2], (None, 2, 'c', 2, 'c'))

def test_join_match_multiple(self):
    # Left key value 2 matches two right rows (both have 2 in the second
    # position), so the join must emit one output row per match.
    left = Table(
        (
            (1, 4, 'a'),
            (2, 3, 'b')
        ),
        self.left_columns
    )
    right = Table(
        (
            (1, 1, 'a'),
            (1, 2, 'a'),
            (2, 2, 'b')
        ),
        self.right_columns
    )

    joined = left.join(right, 'one', 'five')

    self.assertEqual(len(joined.rows), 3)
    self.assertEqual(len(joined.columns), 5)

    # The right-hand key column ('five') is expected to be absent from
    # the joined table, leaving five columns in total.
    expected_names = ('one', 'two', 'three', 'four', 'six')
    expected_types = (Number, Number, Text, Number, Text)
    expected_rows = (
        (1, 4, 'a', 1, 'a'),
        (2, 3, 'b', 1, 'a'),
        (2, 3, 'b', 2, 'b')
    )

    # Names are verified first, then data types, then row contents.
    for position, column_name in enumerate(expected_names):
        self.assertEqual(joined.columns[position].name, column_name)

    for position, type_class in enumerate(expected_types):
        self.assertIsInstance(joined.columns[position].data_type, type_class)

    for position, expected_row in enumerate(expected_rows):
        self.assertSequenceEqual(joined.rows[position], expected_row)

def test_join2(self):
new_table = self.left.join(self.right, 'one', 'five')

Expand Down

0 comments on commit 18b4396

Please sign in to comment.