diff --git a/velox/exec/Expand.cpp b/velox/exec/Expand.cpp index 5d866b888c1..44b6215819e 100644 --- a/velox/exec/Expand.cpp +++ b/velox/exec/Expand.cpp @@ -31,6 +31,7 @@ Expand::Expand( const auto numRows = expandNode->projections().size(); fieldProjections_.reserve(numRows); constantProjections_.reserve(numRows); + constantOutputs_.reserve(numRows); const auto numColumns = expandNode->names().size(); for (const auto& rowProjections : expandNode->projections()) { std::vector rowProjection; @@ -58,6 +59,25 @@ Expand::Expand( } } +void Expand::initialize() { + if (constantProjections_.empty()) { + return; + } + const auto numColumns = constantProjections_[0].size(); + for (const auto& projections : constantProjections_) { + std::vector constantOutput; + constantOutput.reserve(numColumns); + for (const auto& constant : projections) { + if (constant) { + constantOutput.push_back(constant->toConstantVector(pool())); + } else { + constantOutput.push_back(nullptr); + } + } + constantOutputs_.emplace_back(std::move(constantOutput)); + } +} + bool Expand::needsInput() const { return !noMoreInput_ && input_ == nullptr; } @@ -81,21 +101,13 @@ RowVectorPtr Expand::getOutput() { std::vector outputColumns(outputType_->size()); const auto& rowProjection = fieldProjections_[rowIndex_]; - const auto& constantProjection = constantProjections_[rowIndex_]; + const auto& constantProjection = constantOutputs_[rowIndex_]; const auto numColumns = rowProjection.size(); for (auto i = 0; i < numColumns; ++i) { if (rowProjection[i] == kConstantChannel) { - const auto& constantExpr = constantProjection[i]; - if (constantExpr->value().isNull()) { - // Add null column. - outputColumns[i] = BaseVector::createNullConstant( - outputType_->childAt(i), numInput, pool()); - } else { - // Add constant column. - outputColumns[i] = BaseVector::createConstant( - constantExpr->type(), constantExpr->value(), numInput, pool()); - } + outputColumns[i] = + BaseVector::wrapInConstant(numInput, 0, constantProjection[i]); } else { outputColumns[i] = input_->childAt(rowProjection[i]); } diff --git a/velox/exec/Expand.h b/velox/exec/Expand.h index 97c737c1d1f..adf87a71526 100644 --- a/velox/exec/Expand.h +++ b/velox/exec/Expand.h @@ -42,11 +42,15 @@ class Expand : public Operator { } private: + void initialize() override; + std::vector> fieldProjections_; std::vector>> constantProjections_; + std::vector> constantOutputs_; + // Used to indicate the index of fieldProjections_. int32_t rowIndex_{0}; }; diff --git a/velox/exec/tests/ExpandTest.cpp b/velox/exec/tests/ExpandTest.cpp index f541d20c04a..d1161b374fc 100644 --- a/velox/exec/tests/ExpandTest.cpp +++ b/velox/exec/tests/ExpandTest.cpp @@ -21,7 +21,6 @@ using namespace facebook::velox; using namespace facebook::velox::exec::test; namespace facebook::velox::exec { - namespace { class ExpandTest : public OperatorTestBase { public: @@ -37,7 +36,31 @@ class ExpandTest : public OperatorTestBase { }); } }; -} // anonymous namespace + +TEST_F(ExpandTest, complexConstant) { + auto data = makeRowVectorData(3); + auto children = data->children(); + auto arrayVector = + makeArrayVector({{1, 2, 3}, {1, 2, 3}, {1, 2, 3}}); + children.push_back(arrayVector); + children.push_back(makeAllNullArrayVector(3, INTEGER())); + children.push_back(makeNullConstant(TypeKind::INTEGER, 3)); + auto expected = makeRowVector(children); + + auto plan = PlanBuilder(pool()) + .values({data}) + .expand( + {{"k1", + "k2", + "a", + "b", + "ARRAY[1, 2, 3] as c", + "null::integer[] as d", + "null::integer as e"}}) + .planNode(); + + assertQuery(plan, expected); +} TEST_F(ExpandTest, groupingSets) { auto data = makeRowVectorData(1'000); @@ -151,4 +174,5 @@ TEST_F(ExpandTest, invalidUseCases) { "projections must not be empty."); } +} // namespace } // namespace facebook::velox::exec