Skip to content

Commit

Permalink
[Data] fix RandomAccessDataset.multiget returning unexpected values f…
Browse files Browse the repository at this point in the history
…or missing keys

Signed-off-by: Wu Yufei <[email protected]>
  • Loading branch information
tespent committed Apr 17, 2024
1 parent d8c7234 commit ea62f91
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
6 changes: 4 additions & 2 deletions python/ray/data/random_access_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,10 @@ def multiget(self, block_indices, keys):
col = block[self.key_field]
indices = np.searchsorted(col, keys)
acc = BlockAccessor.for_block(block)
result = [acc._get_row(i) for i in indices]
# assert result == [self._get(i, k) for i, k in zip(block_indices, keys)]
result = [
acc._get_row(i) if k1.as_py() == k2 else None
for i, k1, k2 in zip(indices, col.take(indices), keys)
]
else:
result = [self._get(i, k) for i, k in zip(block_indices, keys)]
self.total_time += time.perf_counter() - start
Expand Down
18 changes: 10 additions & 8 deletions python/ray/data/tests/test_random_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,28 @@
@pytest.mark.parametrize("pandas", [False, True])
def test_basic(ray_start_regular_shared, pandas):
ds = ray.data.range(100, override_num_blocks=10)
ds = ds.add_column("key", lambda b: b["id"] * 2)
ds = ds.add_column("embedding", lambda b: b["id"] ** 2)
if not pandas:
ds = ds.map_batches(
lambda df: pyarrow.Table.from_pandas(df), batch_format="pandas"
)

rad = ds.to_random_access_dataset("id", num_workers=1)
rad = ds.to_random_access_dataset("key", num_workers=1)

def expected(i):
return {"id": i, "key": i * 2, "embedding": i**2}

# Test get.
assert ray.get(rad.get_async(-1)) is None
assert ray.get(rad.get_async(100)) is None
assert ray.get(rad.get_async(200)) is None
for i in range(100):
assert ray.get(rad.get_async(i)) == {"id": i, "embedding": i**2}

def expected(i):
return {"id": i, "embedding": i**2}
assert ray.get(rad.get_async(i * 2 + 1)) is None
assert ray.get(rad.get_async(i * 2)) == expected(i)

# Test multiget.
results = rad.multiget([-1] + list(range(10)) + [100])
assert results == [None] + [expected(i) for i in range(10)] + [None]
results = rad.multiget([-1] + list(range(0, 20, 2)) + list(range(1, 21, 2)) + [200])
assert results == [None] + [expected(i) for i in range(10)] + [None] * 10 + [None]


def test_empty_blocks(ray_start_regular_shared):
Expand Down

0 comments on commit ea62f91

Please sign in to comment.