Fix indexing issues when duplicate indices are present. (#235)
* Support duplicate indices in fetch method

Enhanced the fetch method to handle duplicate indices while preserving the order of the request (see the sketch below). Added logging that warns users about duplicate requests, for transparency and easier debugging.

* Add test for random access with duplicate indices

* Remove unnecessary sort

---------

Co-authored-by: Altay Sansal <[email protected]>
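For illustration, a minimal sketch of the behavior this commit enables (the file path is hypothetical; SegyFile and the trace accessor are this library's API):

    import numpy as np
    from segy import SegyFile

    sgy = SegyFile("data/stack.sgy")  # hypothetical path

    # Duplicate, unsorted indices are now valid; traces come back
    # in exactly the requested order, repeats included.
    indices = np.array([7, 2, 7, 0])
    traces = sgy.trace[indices]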
tasansal and Altay Sansal authored Nov 20, 2024
1 parent 6d00bfc commit e2f1a25
Showing 3 changed files with 24 additions and 5 deletions.
1 change: 0 additions & 1 deletion docs/tutorials/quickstart.ipynb
@@ -376,7 +376,6 @@
"\n",
"rng = default_rng()\n",
"indices = rng.choice(sgy.num_traces, size=5_000, replace=False)\n",
"indices.sort()\n",
"\n",
"traces = sgy.trace[indices]"
]
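(With fetch now preserving the requested order itself, the quickstart no longer needs to pre-sort the random indices before reading them.)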
22 changes: 18 additions & 4 deletions src/segy/indexing.py
@@ -169,6 +169,10 @@ def __getitem__(self, item: int | list[int] | NDArray[IntDType] | slice) -> Any:

    def fetch(self, indices: NDArray[IntDType]) -> NDArray[Any]:
        """Fetches and decodes binary data from the given indices.

+       It supports duplicates in the indices, and it will also preserve
+       the order of the request. If you want a sorted order, please sort
+       the trace indices first.
+
        Args:
            indices: A list of integers representing the indices.
@@ -182,12 +186,22 @@ def fetch(self, indices: NDArray[IntDType]) -> NDArray[Any]:
          file specified by the 'url' parameter. However, this is fastest
          if we minimize the number of reads. Here we combine starts and
          stops that are adjacent to each other. This requires a sort.
-       - The indices users request may be out of order, so we ensure we
-         save the index order and then use it to sort the read buffer back
-         to user's requested shape.
        - The fetched data is then decoded and squeezed before being returned.
        """
-       index_order = np.argsort(indices)
+       unique_indices, index_order, counts = np.unique(
+           indices,
+           return_inverse=True,
+           return_counts=True,
+       )
+
+       # Warn user about duplicates in the request
+       if len(indices) != len(unique_indices):
+           duplicate_mask = counts > 1
+           values = unique_indices[duplicate_mask]
+           counts = counts[duplicate_mask]
+           duplicates = {int(v): int(c) for v, c in zip(values, counts)}
+           logger.warning("Duplicate indices requested with counts %s:", duplicates)
+
        starts, ends = self.indices_to_byte_ranges(indices)
        buffer = merge_cat_file(self.fs, self.url, starts.tolist(), ends.tolist())
        array = self.decode(buffer)
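For reference, a minimal, self-contained sketch of the strategy the notes describe, using fake in-memory bytes. The trace size, data, and the range merging below are illustrative assumptions; the real method reads from the file via merge_cat_file:

    import numpy as np

    trace_nbytes = 4
    data = np.arange(40, dtype=np.uint8)  # ten fake 4-byte "traces"

    indices = np.asarray([5, 3, 0, 5, 5, 3])
    unique_indices, inverse = np.unique(indices, return_inverse=True)

    # Byte ranges for the unique indices; np.unique returns them sorted.
    starts = unique_indices * trace_nbytes
    ends = starts + trace_nbytes

    # Merge ranges that touch, so adjacent traces cost a single read.
    merged = [(int(starts[0]), int(ends[0]))]
    for s, e in zip(starts[1:].tolist(), ends[1:].tolist()):
        if s == merged[-1][1]:
            merged[-1] = (merged[-1][0], e)
        else:
            merged.append((s, e))

    # One buffer from the merged reads, reshaped into unique traces.
    buffer = np.concatenate([data[s:e] for s, e in merged])
    traces = buffer.reshape(-1, trace_nbytes)

    # The inverse map restores the requested order, duplicates included.
    restored = traces[inverse]
    assert np.array_equal(restored[:, 0] // trace_nbytes, indices)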
6 changes: 6 additions & 0 deletions tests/test_segy_file.py
@@ -289,6 +289,12 @@ def test_trace_accessor(
        assert_array_equal(traces.header, test_config.expected_headers[index])
        assert_array_almost_equal(traces.sample, test_config.expected_samples[index])

+       # Test random access with duplicates
+       index = [5, 3, 0, 5, 5, 3]
+       traces = segy_file.trace[index]
+       assert_array_equal(traces.header, test_config.expected_headers[index])
+       assert_array_almost_equal(traces.sample, test_config.expected_samples[index])
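For reference, the duplicate counts the new warning reports for this request can be reproduced directly (a sketch mirroring the logic added to fetch):

    import numpy as np

    index = [5, 3, 0, 5, 5, 3]
    values, counts = np.unique(index, return_counts=True)
    duplicates = {int(v): int(c) for v, c in zip(values, counts) if c > 1}
    print(duplicates)  # {3: 2, 5: 3}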


class TestSegyFileExceptions:
"""Test exceptions for SegyFile."""
