diff --git a/velox/expression/tests/SimpleFunctionTest.cpp b/velox/expression/tests/SimpleFunctionTest.cpp index 694737a7f5ff..4fde2ee1f447 100644 --- a/velox/expression/tests/SimpleFunctionTest.cpp +++ b/velox/expression/tests/SimpleFunctionTest.cpp @@ -174,6 +174,47 @@ TEST_F(SimpleFunctionTest, arrayReader) { assertEqualVectors(expected, result); } +// Function that takes an array of arrays as input. +template +struct ArrayArrayReaderFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE bool call( + int64_t& out, + const arg_type>>& input) { + out = 0; + for (const auto& inner : input) { + if (inner.has_value()) { + for (const auto& v : inner.value()) { + if (v.has_value()) { + out += v.value(); + } + } + } + } + return true; + } +}; + +TEST_F(SimpleFunctionTest, arrayArrayReader) { + registerFunction>>( + {"array_array_reader_func"}); + + const size_t rows = arrayData.size(); + auto arrayVector = makeArrayVector(arrayData); + auto result = evaluate>( + "array_array_reader_func(array_constructor(c0, c0))", + makeRowVector({arrayVector})); + + auto arrayDataLocal = arrayData; + auto expected = makeFlatVector(rows, [&arrayDataLocal](auto row) { + return 2 * + std::accumulate( + arrayDataLocal[row].begin(), arrayDataLocal[row].end(), 0); + }); + assertEqualVectors(expected, result); +} + // Some input data for the rowVector. static std::vector rowVectorCol1 = {0, 22, 44, 55, 99, 101, 9, 0}; static std::vector rowVectorCol2 =