Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 31 additions & 51 deletions cpp/src/arrow/dataset/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ namespace {

std::string PrintDatum(const Datum& datum) {
if (datum.is_scalar()) {
if (!datum.scalar()->is_valid) return "null";

switch (datum.type()->id()) {
case Type::STRING:
case Type::LARGE_STRING:
Expand All @@ -114,6 +116,7 @@ std::string PrintDatum(const Datum& datum) {
default:
break;
}

return datum.scalar()->ToString();
}
return datum.ToString();
Expand Down Expand Up @@ -688,27 +691,27 @@ std::vector<Expression> GuaranteeConjunctionMembers(
// conjunction_members
Status ExtractKnownFieldValuesImpl(
std::vector<Expression>* conjunction_members,
std::unordered_map<FieldRef, KnownFieldValue, FieldRef::Hash>* known_values) {
auto unconsumed_end = std::partition(
conjunction_members->begin(), conjunction_members->end(),
[](const Expression& expr) {
// search for an equality conditions between a field and a literal
auto call = expr.call();
if (!call) return true;

if (call->function_name == "equal") {
auto ref = call->arguments[0].field_ref();
auto lit = call->arguments[1].literal();
return !(ref && lit);
}

if (call->function_name == "is_null" || call->function_name == "is_valid") {
auto ref = call->arguments[0].field_ref();
return !ref;
}

return true;
});
std::unordered_map<FieldRef, Datum, FieldRef::Hash>* known_values) {
auto unconsumed_end =
std::partition(conjunction_members->begin(), conjunction_members->end(),
[](const Expression& expr) {
// search for an equality conditions between a field and a literal
auto call = expr.call();
if (!call) return true;

if (call->function_name == "equal") {
auto ref = call->arguments[0].field_ref();
auto lit = call->arguments[1].literal();
return !(ref && lit);
}

if (call->function_name == "is_null") {
auto ref = call->arguments[0].field_ref();
return !ref;
}

return true;
});

for (auto it = unconsumed_end; it != conjunction_members->end(); ++it) {
auto call = CallNotNull(*it);
Expand All @@ -719,10 +722,7 @@ Status ExtractKnownFieldValuesImpl(
known_values->emplace(*ref, *lit);
} else if (call->function_name == "is_null") {
auto ref = call->arguments[0].field_ref();
known_values->emplace(*ref, false);
} else if (call->function_name == "is_valid") {
auto ref = call->arguments[0].field_ref();
known_values->emplace(*ref, true);
known_values->emplace(*ref, std::make_shared<NullScalar>());
}
}

Expand All @@ -733,16 +733,16 @@ Status ExtractKnownFieldValuesImpl(

} // namespace

Result<std::unordered_map<FieldRef, KnownFieldValue, FieldRef::Hash>>
ExtractKnownFieldValues(const Expression& guaranteed_true_predicate) {
Result<std::unordered_map<FieldRef, Datum, FieldRef::Hash>> ExtractKnownFieldValues(
const Expression& guaranteed_true_predicate) {
auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);
std::unordered_map<FieldRef, KnownFieldValue, FieldRef::Hash> known_values;
std::unordered_map<FieldRef, Datum, FieldRef::Hash> known_values;
RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values));
return known_values;
}

Result<Expression> ReplaceFieldsWithKnownValues(
const std::unordered_map<FieldRef, KnownFieldValue, FieldRef::Hash>& known_values,
const std::unordered_map<FieldRef, Datum, FieldRef::Hash>& known_values,
Expression expr) {
if (!expr.IsBound()) {
return Status::Invalid(
Expand All @@ -755,11 +755,7 @@ Result<Expression> ReplaceFieldsWithKnownValues(
if (auto ref = expr.field_ref()) {
auto it = known_values.find(*ref);
if (it != known_values.end()) {
const auto& known_value = it->second;
if (!known_value.concrete()) {
return expr;
}
auto lit = known_value.datum;
Datum lit = it->second;
if (expr.type()->id() == Type::DICTIONARY) {
if (lit.is_scalar()) {
// FIXME the "right" way to support this is adding support for scalars to
Expand All @@ -779,22 +775,6 @@ Result<Expression> ReplaceFieldsWithKnownValues(
ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type()));
return literal(std::move(lit));
}
} else if (auto call = expr.call()) {
if (call->function_name == "is_null") {
if (auto ref = call->arguments[0].field_ref()) {
auto it = known_values.find(*ref);
if (it != known_values.end()) {
return literal(!it->second.valid);
}
}
} else if (call->function_name == "is_valid") {
if (auto ref = call->arguments[0].field_ref()) {
auto it = known_values.find(*ref);
if (it != known_values.end()) {
return literal(it->second.valid);
}
}
}
}
return expr;
},
Expand Down Expand Up @@ -971,7 +951,7 @@ Result<Expression> SimplifyWithGuarantee(Expression expr,
const Expression& guaranteed_true_predicate) {
auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);

std::unordered_map<FieldRef, KnownFieldValue, FieldRef::Hash> known_values;
std::unordered_map<FieldRef, Datum, FieldRef::Hash> known_values;
RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values));

ARROW_ASSIGN_OR_RAISE(expr,
Expand Down
24 changes: 3 additions & 21 deletions cpp/src/arrow/dataset/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,27 +162,10 @@ Expression call(std::string function, std::vector<Expression> arguments,
ARROW_DS_EXPORT
std::vector<FieldRef> FieldsInExpression(const Expression&);

/// Represents either a concrete value or a hint that a field is valid/invalid
struct KnownFieldValue {
Datum datum;
bool valid;

KnownFieldValue() : datum(), valid(false) {}
KnownFieldValue(const Datum& datum) // NOLINT implicit conversion
: datum(datum), valid(datum.length() != datum.null_count()) {}
KnownFieldValue(bool is_valid) // NOLINT implicit conversion
: datum(), valid(is_valid) {}

inline bool concrete() const { return datum.kind() != Datum::Kind::NONE; }
bool operator==(const KnownFieldValue& other) const {
return datum == other.datum && valid == other.valid;
}
};

/// Assemble a mapping from field references to known values.
ARROW_DS_EXPORT
Result<std::unordered_map<FieldRef, KnownFieldValue, FieldRef::Hash>>
ExtractKnownFieldValues(const Expression& guaranteed_true_predicate);
Result<std::unordered_map<FieldRef, Datum, FieldRef::Hash>> ExtractKnownFieldValues(
const Expression& guaranteed_true_predicate);

/// \defgroup expression-passes Functions for modification of Expressions
///
Expand Down Expand Up @@ -211,8 +194,7 @@ Result<Expression> FoldConstants(Expression);
/// Simplify Expressions by replacing with known values of the fields which it references.
ARROW_DS_EXPORT
Result<Expression> ReplaceFieldsWithKnownValues(
const std::unordered_map<FieldRef, KnownFieldValue, FieldRef::Hash>& known_values,
Expression);
const std::unordered_map<FieldRef, Datum, FieldRef::Hash>& known_values, Expression);

/// Simplify an expression by replacing subexpressions based on a guarantee:
/// a boolean expression which is guaranteed to evaluate to `true`. For example, this is
Expand Down
48 changes: 26 additions & 22 deletions cpp/src/arrow/dataset/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -676,9 +676,8 @@ TEST(Expression, FoldConstantsBoolean) {

TEST(Expression, ExtractKnownFieldValues) {
struct {
void operator()(
Expression guarantee,
std::unordered_map<FieldRef, KnownFieldValue, FieldRef::Hash> expected) {
void operator()(Expression guarantee,
std::unordered_map<FieldRef, Datum, FieldRef::Hash> expected) {
ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee));
EXPECT_THAT(actual, UnorderedElementsAreArray(expected))
<< " guarantee: " << guarantee.ToString();
Expand Down Expand Up @@ -726,20 +725,20 @@ TEST(Expression, ExtractKnownFieldValues) {
}

TEST(Expression, ReplaceFieldsWithKnownValues) {
auto ExpectReplacesTo = [](Expression expr,
const std::unordered_map<FieldRef, KnownFieldValue,
FieldRef::Hash>& known_values,
Expression unbound_expected) {
ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema));
ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema));
ASSERT_OK_AND_ASSIGN(auto replaced, ReplaceFieldsWithKnownValues(known_values, expr));
auto ExpectReplacesTo =
[](Expression expr,
const std::unordered_map<FieldRef, Datum, FieldRef::Hash>& known_values,
Expression unbound_expected) {
ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema));
ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema));
ASSERT_OK_AND_ASSIGN(auto replaced,
ReplaceFieldsWithKnownValues(known_values, expr));

EXPECT_EQ(replaced, expected);
ExpectIdenticalIfUnchanged(replaced, expr);
};
EXPECT_EQ(replaced, expected);
ExpectIdenticalIfUnchanged(replaced, expr);
};

std::unordered_map<FieldRef, KnownFieldValue, FieldRef::Hash> i32_is_3{
{"i32", Datum(3)}};
std::unordered_map<FieldRef, Datum, FieldRef::Hash> i32_is_3{{"i32", Datum(3)}};

ExpectReplacesTo(literal(1), i32_is_3, literal(1));

Expand Down Expand Up @@ -772,13 +771,18 @@ TEST(Expression, ReplaceFieldsWithKnownValues) {
literal(2),
}));

std::unordered_map<FieldRef, KnownFieldValue, FieldRef::Hash> a_valid_b_invalid{
{"a", true}, {"b", false}};
std::unordered_map<FieldRef, Datum, FieldRef::Hash> i32_valid_str_null{
{"i32", Datum(3)}, {"str", MakeNullScalar(utf8())}};

ExpectReplacesTo(is_null(field_ref("i32")), i32_valid_str_null, is_null(literal(3)));

ExpectReplacesTo(is_null(field_ref("a")), a_valid_b_invalid, literal(false));
ExpectReplacesTo(is_valid(field_ref("a")), a_valid_b_invalid, literal(true));
ExpectReplacesTo(is_null(field_ref("b")), a_valid_b_invalid, literal(true));
ExpectReplacesTo(is_valid(field_ref("b")), a_valid_b_invalid, literal(false));
ExpectReplacesTo(is_valid(field_ref("i32")), i32_valid_str_null, is_valid(literal(3)));

ExpectReplacesTo(is_null(field_ref("str")), i32_valid_str_null,
is_null(null_literal(utf8())));

ExpectReplacesTo(is_valid(field_ref("str")), i32_valid_str_null,
is_valid(null_literal(utf8())));
}

struct {
Expand Down Expand Up @@ -1042,7 +1046,7 @@ TEST(Expression, SimplifyWithGuarantee) {

Simplify{is_valid(field_ref("i32"))}
.WithGuarantee(is_valid(field_ref("i32")))
.Expect(literal(true));
.Expect(is_valid(field_ref("i32")));
}

TEST(Expression, SimplifyThenExecute) {
Expand Down
60 changes: 26 additions & 34 deletions cpp/src/arrow/dataset/partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,31 +74,20 @@ Status KeyValuePartitioning::SetDefaultValuesFromKeys(const Expression& expr,
RecordBatchProjector* projector) {
ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr));
for (const auto& ref_value : known_values) {
const auto& known_value = ref_value.second;
if (known_value.concrete() && !known_value.datum.is_scalar()) {
return Status::Invalid("non-scalar partition key ", known_value.datum.ToString());
if (!ref_value.second.is_scalar()) {
return Status::Invalid("non-scalar partition key ", ref_value.second.ToString());
}

ARROW_ASSIGN_OR_RAISE(auto match,
ref_value.first.FindOneOrNone(*projector->schema()));

if (match.empty()) continue;

const auto& field = projector->schema()->field(match[0]);
if (known_value.concrete()) {
RETURN_NOT_OK(projector->SetDefaultValue(match, known_value.datum.scalar()));
} else if (known_value.valid) {
// We know some information about the value but nothing concrete enough to set. Can
// happen if expression is something like is_valid(field_ref("a"))
continue;
} else {
RETURN_NOT_OK(projector->SetDefaultValue(match, MakeNullScalar(field->type())));
}
RETURN_NOT_OK(projector->SetDefaultValue(match, ref_value.second.scalar()));
}
return Status::OK();
}

Expression ConjunctionFromGroupingRow(Scalar* row) {
inline Expression ConjunctionFromGroupingRow(Scalar* row) {
ScalarVector* values = &checked_cast<StructScalar*>(row)->value;
std::vector<Expression> equality_expressions(values->size());
for (size_t i = 0; i < values->size(); ++i) {
Expand Down Expand Up @@ -213,34 +202,37 @@ Result<std::string> KeyValuePartitioning::Format(const Expression& expr) const {

ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr));
for (const auto& ref_value : known_values) {
const auto& known_value = ref_value.second;
if (known_value.concrete() && !known_value.datum.is_scalar()) {
return Status::Invalid("non-scalar partition key ", known_value.datum.ToString());
if (!ref_value.second.is_scalar()) {
return Status::Invalid("non-scalar partition key ", ref_value.second.ToString());
}

ARROW_ASSIGN_OR_RAISE(auto match, ref_value.first.FindOneOrNone(*schema_));
if (match.empty()) continue;

const auto& field = schema_->field(match[0]);

if (known_value.concrete()) {
auto value = known_value.datum.scalar();
if (!value->type->Equals(field->type())) {
return Status::TypeError("scalar ", value->ToString(), " (of type ", *value->type,
") is invalid for ", field->ToString());
}
auto value = ref_value.second.scalar();

if (value->type->id() == Type::DICTIONARY) {
ARROW_ASSIGN_OR_RAISE(
value, checked_cast<const DictionaryScalar&>(*value).GetEncodedValue());
const auto& field = schema_->field(match[0]);
if (!value->type->Equals(field->type())) {
if (value->is_valid) {
auto maybe_converted = compute::Cast(value, field->type());
if (!maybe_converted.ok()) {
return Status::TypeError("Error converting scalar ", value->ToString(),
" (of type ", *value->type,
") to a partition key for ", field->ToString(), ": ",
maybe_converted.status().message());
}
value = maybe_converted->scalar();
} else {
value = MakeNullScalar(field->type());
}
}

values[match[0]] = std::move(value);
} else {
if (!known_value.valid) {
values[match[0]] = MakeNullScalar(field->type());
}
if (value->type->id() == Type::DICTIONARY) {
ARROW_ASSIGN_OR_RAISE(
value, checked_cast<const DictionaryScalar&>(*value).GetEncodedValue());
}

values[match[0]] = std::move(value);
}

return FormatValues(values);
Expand Down
16 changes: 13 additions & 3 deletions cpp/src/arrow/dataset/projector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <vector>

#include "arrow/array.h"
#include "arrow/compute/cast.h"
#include "arrow/dataset/type_fwd.h"
#include "arrow/record_batch.h"
#include "arrow/result.h"
Expand Down Expand Up @@ -88,9 +89,18 @@ Status RecordBatchProjector::SetDefaultValue(FieldRef ref,

auto field_type = to_->field(index)->type();
if (!field_type->Equals(scalar->type)) {
return Status::TypeError("field ", to_->field(index)->ToString(),
" cannot be materialized from scalar of type ",
*scalar->type);
if (scalar->is_valid) {
auto maybe_converted = compute::Cast(scalar, field_type);
if (!maybe_converted.ok()) {
return Status::TypeError("Field ", to_->field(index)->ToString(),
" cannot be materialized from scalar of type ",
*scalar->type,
". Cast error: ", maybe_converted.status().message());
}
scalar = maybe_converted->scalar();
} else {
scalar = MakeNullScalar(field_type);
}
}

scalars_[index] = std::move(scalar);
Expand Down
Loading