8 changes: 7 additions & 1 deletion cpp/src/arrow/dataset/scanner.cc
@@ -91,6 +91,7 @@ const FieldVector kAugmentedFields{
     field("__fragment_index", int32()),
     field("__batch_index", int32()),
     field("__last_in_fragment", boolean()),
+    field("__filename", utf8()),
 };
 
 // Scan options has a number of options that we can infer from the dataset
@@ -708,8 +709,12 @@ Result<ProjectionDescr> ProjectionDescr::FromNames(std::vector<std::string> name
   for (size_t i = 0; i < exprs.size(); ++i) {
     exprs[i] = compute::field_ref(names[i]);
   }
+  auto fields = dataset_schema.fields();
+  for (const auto& aug_field : kAugmentedFields) {
+    fields.push_back(aug_field);
+  }
   return ProjectionDescr::FromExpressions(std::move(exprs), std::move(names),
-                                          dataset_schema);
+                                          Schema(fields, dataset_schema.metadata()));
 }
 
 Result<ProjectionDescr> ProjectionDescr::Default(const Schema& dataset_schema) {
@@ -877,6 +882,7 @@ Result<compute::ExecNode*> MakeScanNode(compute::ExecPlan* plan,
         batch->values.emplace_back(partial.fragment.index);
         batch->values.emplace_back(partial.record_batch.index);
         batch->values.emplace_back(partial.record_batch.last);
+        batch->values.emplace_back(partial.fragment.value->ToString());
         return batch;
       });
 
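With `kAugmentedFields` appended to the projection schema in `ProjectionDescr::FromNames`, the special fields become selectable by name just like ordinary columns. A minimal Python sketch of the resulting behavior, assuming a hypothetical Parquet dataset at `data/` with a column `a`:

```python
import pyarrow.dataset as ds

# Hypothetical on-disk dataset; any format the Datasets API supports works.
dataset = ds.dataset("data/", format="parquet")

# "__filename" can now be requested by name, alongside regular columns.
table = dataset.scanner(columns=["a", "__filename"]).to_table()
print(table.column_names)  # ['a', '__filename']
```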
15 changes: 14 additions & 1 deletion cpp/src/arrow/dataset/scanner_test.cc
@@ -128,6 +128,15 @@ class TestScanner : public DatasetFixtureMixinWithParam<TestScannerParams> {
     AssertScanBatchesEquals(expected.get(), scanner.get());
   }
 
+  void AssertNoAugmentedFields(std::shared_ptr<Scanner> scanner) {
+    ASSERT_OK_AND_ASSIGN(auto table, scanner.get()->ToTable());
+    auto columns = table.get()->ColumnNames();
+    EXPECT_TRUE(std::none_of(columns.begin(), columns.end(), [](std::string& x) {
+      return x == "__fragment_index" || x == "__batch_index" ||
+             x == "__last_in_fragment" || x == "__filename";
+    }));
+  }
+
   void AssertScanBatchesUnorderedEqualRepetitionsOf(
       std::shared_ptr<Scanner> scanner, std::shared_ptr<RecordBatch> batch,
       const int64_t total_batches = GetParam().num_child_datasets *
@@ -257,6 +266,7 @@ TEST_P(TestScanner, ProjectionDefaults) {
     options_->projection = literal(true);
     options_->projected_schema = nullptr;
     AssertScanBatchesEqualRepetitionsOf(MakeScanner(batch_in), batch_in);
+    AssertNoAugmentedFields(MakeScanner(batch_in));
   }
   // If we only specify a projection expression then infer the projected schema
   // from the projection expression
@@ -1386,6 +1396,7 @@ DatasetAndBatches DatasetAndBatchesFromJSON(
       // ... and with the last-in-fragment flag
       batches.back().values.emplace_back(batch_index ==
                                          fragment_batch_strs[fragment_index].size() - 1);
+      batches.back().values.emplace_back(fragments[fragment_index]->ToString());
 
       // each batch carries a guarantee inherited from its Fragment's partition expression
       batches.back().guarantee = fragments[fragment_index]->partition_expression();
@@ -1472,7 +1483,8 @@ DatasetAndBatches MakeNestedDataset() {
 compute::Expression Materialize(std::vector<std::string> names,
                                 bool include_aug_fields = false) {
   if (include_aug_fields) {
-    for (auto aug_name : {"__fragment_index", "__batch_index", "__last_in_fragment"}) {
+    for (auto aug_name :
+         {"__fragment_index", "__batch_index", "__last_in_fragment", "__filename"}) {
       names.emplace_back(aug_name);
     }
   }
@@ -1502,6 +1514,7 @@ TEST(ScanNode, Schema) {
   fields.push_back(field("__fragment_index", int32()));
   fields.push_back(field("__batch_index", int32()));
   fields.push_back(field("__last_in_fragment", boolean()));
+  fields.push_back(field("__filename", utf8()));
   // output_schema is *always* the full augmented dataset schema, regardless of
   // projection (but some columns *may* be placeholder null Scalars if not projected)
   AssertSchemaEqual(Schema(fields), *scan->output_schema());
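`AssertNoAugmentedFields` pins down the flip side: a default projection still yields only the dataset schema. A hedged Python equivalent of that check, with the dataset path hypothetical:

```python
import pyarrow.dataset as ds

dataset = ds.dataset("data/", format="parquet")  # hypothetical path

# With no explicit projection, the augmented fields stay out of the output.
names = dataset.scanner().to_table().column_names
assert "__filename" not in names
assert "__fragment_index" not in names
```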
37 changes: 31 additions & 6 deletions python/pyarrow/_dataset.pyx
@@ -2042,15 +2042,24 @@ cdef class Scanner(_Weakrefable):
     dataset : Dataset
         Dataset to scan.
     columns : list of str or dict, default None
-        The columns to project. This can be a list of column names to include
-        (order and duplicates will be preserved), or a dictionary with
-        {new_column_name: expression} values for more advanced projections.
+        The columns to project. This can be a list of column names to
+        include (order and duplicates will be preserved), or a dictionary
+        with {{new_column_name: expression}} values for more advanced
+        projections.
+
+        The list of columns or expressions may use the special fields
+        `__batch_index` (the index of the batch within the fragment),
+        `__fragment_index` (the index of the fragment within the dataset),
+        `__last_in_fragment` (whether the batch is last in fragment), and
+        `__filename` (the name of the source file or a description of the
+        source fragment).
+
         The columns will be passed down to Datasets and corresponding data
         fragments to avoid loading, copying, and deserializing columns
         that will not be required further down the compute chain.
-        By default all of the available columns are projected. Raises
-        an exception if any of the referenced column names does not exist
-        in the dataset's Schema.
+        By default all of the available columns are projected.
+        Raises an exception if any of the referenced column names does
+        not exist in the dataset's Schema.
     filter : Expression, default None
         Scan will return only the rows matching the filter.
         If possible the predicate will be pushed down to exploit the
@@ -2111,6 +2120,14 @@ cdef class Scanner(_Weakrefable):
         include (order and duplicates will be preserved), or a dictionary
         with {new_column_name: expression} values for more advanced
         projections.
+
+        The list of columns or expressions may use the special fields
+        `__batch_index` (the index of the batch within the fragment),
+        `__fragment_index` (the index of the fragment within the dataset),
+        `__last_in_fragment` (whether the batch is last in fragment), and
+        `__filename` (the name of the source file or a description of the
+        source fragment).
+
         The columns will be passed down to Datasets and corresponding data
         fragments to avoid loading, copying, and deserializing columns
         that will not be required further down the compute chain.
@@ -2181,6 +2198,14 @@ cdef class Scanner(_Weakrefable):
         include (order and duplicates will be preserved), or a dictionary
         with {new_column_name: expression} values for more advanced
         projections.
+
+        The list of columns or expressions may use the special fields
+        `__batch_index` (the index of the batch within the fragment),
+        `__fragment_index` (the index of the fragment within the dataset),
+        `__last_in_fragment` (whether the batch is last in fragment), and
+        `__filename` (the name of the source file or a description of the
+        source fragment).
+
         The columns will be passed down to Datasets and corresponding data
         fragments to avoid loading, copying, and deserializing columns
         that will not be required further down the compute chain.
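The dict form of `columns` described in these docstrings accepts expressions, so the special fields can also be renamed or combined with computed columns. A sketch assuming a hypothetical dataset with a column `a`:

```python
import pyarrow.dataset as ds

dataset = ds.dataset("data/", format="parquet")  # hypothetical path
table = dataset.scanner(columns={
    "a": ds.field("a"),
    "source": ds.field("__filename"),  # expose the special field under a new name
}).to_table()
```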
18 changes: 18 additions & 0 deletions python/pyarrow/tests/test_dataset.py
@@ -422,6 +422,24 @@ def test_scanner(dataset, dataset_reader):
 
     assert table.num_rows == scanner.count_rows()
 
+    scanner = dataset_reader.scanner(dataset, columns=['__filename',
+                                                       '__fragment_index',
+                                                       '__batch_index',
+                                                       '__last_in_fragment'],
+                                     memory_pool=pa.default_memory_pool())
+    table = scanner.to_table()
+    expected_names = ['__filename', '__fragment_index',
+                      '__batch_index', '__last_in_fragment']
+    assert table.column_names == expected_names
+
+    sorted_table = table.sort_by('__fragment_index')
+    assert sorted_table['__filename'].to_pylist() == (
+        ['subdir/1/xxx/file0.parquet'] * 5 +
+        ['subdir/2/yyy/file1.parquet'] * 5)
+    assert sorted_table['__fragment_index'].to_pylist() == ([0] * 5 + [1] * 5)
+    assert sorted_table['__batch_index'].to_pylist() == [0] * 10
+    assert sorted_table['__last_in_fragment'].to_pylist() == [True] * 10
+
 
 @pytest.mark.parquet
 def test_scanner_async_deprecated(dataset):
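The test sorts by `__fragment_index` because a parallel scan may deliver batches in any order. The same fields can restore a deterministic order downstream; a sketch under the same hypothetical-dataset assumption as above:

```python
import pyarrow.dataset as ds

dataset = ds.dataset("data/", format="parquet")  # hypothetical path
table = dataset.scanner(
    columns=["a", "__fragment_index", "__batch_index"]).to_table()

# Restore fragment/batch order after an unordered scan.
ordered = table.sort_by([("__fragment_index", "ascending"),
                         ("__batch_index", "ascending")])
```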