1818#include " velox/expression/VectorFunction.h"
1919#include " velox/functions/lib/LambdaFunctionUtil.h"
2020#include " velox/functions/lib/RowsTranslationUtil.h"
21+ #include " velox/vector/FlatMapVector.h"
2122#include " velox/vector/FunctionVector.h"
2223
2324namespace facebook ::velox::functions {
@@ -34,7 +35,7 @@ class FilterFunctionBase : public exec::VectorFunction {
3435 static vector_size_t doApply (
3536 const SelectivityVector& rows,
3637 const std::shared_ptr<T>& input,
37- const VectorPtr& lambdas ,
38+ const VectorPtr& lambda ,
3839 const std::vector<VectorPtr>& lambdaArgs,
3940 exec::EvalCtx& context,
4041 BufferPtr& resultOffsets,
@@ -60,7 +61,7 @@ class FilterFunctionBase : public exec::VectorFunction {
6061 getElementToTopLevelRows (numElements, rows, input.get (), pool);
6162
6263 exec::LocalDecodedVector bitsDecoder (context);
63- auto iter = lambdas ->asUnchecked <FunctionVector>()->iterator (&rows);
64+ auto iter = lambda ->asUnchecked <FunctionVector>()->iterator (&rows);
6465 while (auto entry = iter.next ()) {
6566 auto elementRows =
6667 toElementRows<T>(numElements, *entry.rows , input.get ());
@@ -173,27 +174,177 @@ class ArrayFilterFunction : public FilterFunctionBase {
173174// - https://prestodb.io/docs/current/functions/lambda.html
174175// - https://prestodb.io/blog/2020/03/02/presto-lambda
175176class MapFilterFunction : public FilterFunctionBase {
176- public:
177- void apply (
177+ private:
178+ // Builds a SelectivityVector based upon FlatMapVector's inMap buffer. This
179+ // buffer indicates whether or not a particular key is present in that map's
180+ // row. When applying lambda functions we should avoid executing on key-value
181+ // pairs that are not present in that particular row. SelectivityVector
182+ // rowsToFilterOn is used to filter out those rows.
183+ void buildInMapSelectivityVector (
184+ SelectivityVector& rowsToFilterOn,
185+ BufferPtr flattenedInMap,
186+ BufferPtr inMap,
187+ vector_size_t inMapSize,
188+ const SelectivityVector& rows,
189+ const vector_size_t * decodedIndices) const {
190+ // Flatten inMap buffer.
191+ auto * mutableFlattedInMap = flattenedInMap->asMutable <uint64_t >();
192+ bits::fillBits (mutableFlattedInMap, 0 , inMapSize, false );
193+ auto * mutableInMap = inMap->asMutable <uint64_t >();
194+ rows.applyToSelected ([&](vector_size_t row) {
195+ if (bits::isBitSet (mutableInMap, decodedIndices[row])) {
196+ bits::setBit (mutableFlattedInMap, decodedIndices[row]);
197+ }
198+ });
199+
200+ // Extract flattened inMap buffer values for next lambda call.
201+ rowsToFilterOn.clearAll ();
202+ auto bits = rowsToFilterOn.asMutableRange ().bits ();
203+ bits::orBits (bits, mutableFlattedInMap, 0 , rowsToFilterOn.size ());
204+ rowsToFilterOn.updateBounds ();
205+ }
206+
207+ // Apply filter function to vector of encoding FlatMapVector. Because the
208+ // entirety of the map values are stored in in a list of vectors (one vector
209+ // per key), we will need to apply the filter function on each vector and
210+ // associated inMap buffer. Additionally, we will have to reduce the number of
211+ // distinct keys stored in the FlatMapVector if they key list changes.
212+ void applyFlatMapVector (
213+ DecodedVector& decodedMap,
178214 const SelectivityVector& rows,
179215 std::vector<VectorPtr>& args,
180216 const TypePtr& outputType,
181217 exec::EvalCtx& context,
182- VectorPtr& result) const override {
183- VELOX_CHECK_EQ (args.size (), 2 );
184- exec::LocalDecodedVector mapDecoder (context, *args[0 ], rows);
185- auto & decodedMap = *mapDecoder.get ();
218+ VectorPtr& result) const {
219+ // Current map and fields
220+ const FlatMapVector& flatMap =
221+ *(decodedMap.base ())->template as <FlatMapVector>();
222+ auto distinctKeys = flatMap.distinctKeys ();
223+ auto mapValues = flatMap.mapValues ();
224+ auto numRows = rows.size ();
225+ BufferPtr decodedIndices =
226+ AlignedBuffer::allocate<vector_size_t >(numRows, flatMap.pool ());
227+ auto mutableIndices = decodedIndices->asMutable <vector_size_t >();
228+ for (int i = 0 ; i < decodedMap.size (); i++) {
229+ mutableIndices[i] = decodedMap.indices ()[i];
230+ }
231+
232+ // Result map fields
233+ auto filteredKeysIndices = AlignedBuffer::allocate<vector_size_t >(
234+ distinctKeys->size (), context.pool ());
235+ std::vector<VectorPtr> filteredMapValues;
236+ std::vector<BufferPtr> filteredInMaps;
237+ uint64_t * filteredInMap;
238+ auto numDistinct = 0 ;
239+ auto rawIndices = filteredKeysIndices->asMutable <vector_size_t >();
240+
241+ // Lambda function
242+ auto iter = args[1 ]->asUnchecked <FunctionVector>()->iterator (&rows);
243+ exec::LocalDecodedVector bitsDecoder (context);
244+ SelectivityVector rowsToFilterOn (flatMap.size ());
245+ // Selectivity vector to help ignore filtering for key-value pairs
246+ // identified by inMap buffer. Let's allocate here to avoid during each
247+ // iteration.
248+ auto flattenedInMap =
249+ AlignedBuffer::allocate<bool >(flatMap.size (), context.pool (), 0 );
250+
251+ // Apply lambda function to each map value vector and its associated key
252+ // from our flat map vector. If the key is not filtered out, we will copy it
253+ // to our result vector.
254+ while (auto entry = iter.next ()) {
255+ for (int channel = 0 ; channel < mapValues.size (); ++channel) {
256+ // Only apply lambda function to values that are in the map.
257+ buildInMapSelectivityVector (
258+ rowsToFilterOn,
259+ flattenedInMap,
260+ flatMap.inMaps ()[channel],
261+ flatMap.size (),
262+ *entry.rows ,
263+ decodedMap.indices ());
264+
265+ // Call lambda function and decode its output bit vector. We will
266+ // use it to determine what will persist to the final result vector.
267+ VectorPtr lambdaResultBits;
268+ entry.callable ->apply (
269+ rowsToFilterOn,
270+ nullptr ,
271+ nullptr ,
272+ &context,
273+ {
274+ BaseVector::wrapInConstant (
275+ flatMap.size (), channel, distinctKeys),
276+ mapValues[channel],
277+ },
278+ decodedIndices,
279+ &lambdaResultBits);
280+ bitsDecoder.get ()->decode (*lambdaResultBits);
281+
282+ bool isFilteredIn = false ;
283+ entry.rows ->applyToSelected ([&](vector_size_t row) {
284+ row = decodedMap.indices ()[row];
285+ if (rowsToFilterOn.isValid (row) &&
286+ !bitsDecoder.get ()->isNullAt (row) &&
287+ bitsDecoder.get ()->valueAt <bool >(row)) {
288+ // First time seeing this key; let's copy over its associated values
289+ // vector and define a new filtered inMap buffer. Let's also note
290+ // the index of this key for key filtering.
291+ if (!isFilteredIn) {
292+ filteredMapValues.push_back (
293+ BaseVector::copy (*mapValues[channel]));
294+ filteredInMaps.push_back (
295+ AlignedBuffer::allocate<bool >(numRows, context.pool (), 0 ));
296+ filteredInMap = filteredInMaps.back ()->asMutable <uint64_t >();
297+ rawIndices[numDistinct++] = channel;
298+ isFilteredIn = true ;
299+ }
300+ bits::setBit (filteredInMap, row);
301+ }
302+ });
303+ }
304+ }
305+
306+ // Resize filtered distinct keys indices in order to wrap in dictionary and
307+ // create our result filtered flat map vector
308+ filteredKeysIndices->setSize (numDistinct * sizeof (vector_size_t ));
309+ auto localResult = std::make_shared<FlatMapVector>(
310+ context.pool (),
311+ outputType,
312+ nullptr ,
313+ flatMap.size (),
314+ BaseVector::wrapInDictionary (
315+ BufferPtr (nullptr ), filteredKeysIndices, numDistinct, distinctKeys),
316+ std::move (filteredMapValues),
317+ std::move (filteredInMaps));
186318
187- auto flatMap = flattenMap (rows, args[0 ], decodedMap);
319+ // Handle wrapped encoding if necessary
320+ if (decodedMap.isIdentityMapping ()) {
321+ context.moveOrCopyResult (localResult, rows, result);
322+ } else {
323+ context.moveOrCopyResult (
324+ BaseVector::wrapInDictionary (
325+ nullptr , decodedIndices, decodedMap.size (), localResult),
326+ rows,
327+ result);
328+ }
329+ }
188330
189- VectorPtr keys = flatMap->mapKeys ();
190- VectorPtr values = flatMap->mapValues ();
331+ // Applies filter function on traditional map vector.
332+ void applyMapVector (
333+ DecodedVector& decodedMap,
334+ const SelectivityVector& rows,
335+ std::vector<VectorPtr>& args,
336+ const TypePtr& outputType,
337+ exec::EvalCtx& context,
338+ VectorPtr& result) const {
339+ auto mapVector = flattenMap (rows, args[0 ], decodedMap);
340+ VectorPtr keys = mapVector->mapKeys ();
341+ VectorPtr values = mapVector->mapValues ();
191342 BufferPtr resultSizes;
192343 BufferPtr resultOffsets;
193344 BufferPtr selectedIndices;
194345 auto numSelected = doApply (
195346 rows,
196- flatMap ,
347+ mapVector ,
197348 args[1 ],
198349 {keys, values},
199350 context,
@@ -220,9 +371,9 @@ class MapFilterFunction : public FilterFunctionBase {
220371 true /* flattenIfRedundant*/ )
221372 : nullptr ;
222373 // Set nulls for rows not present in 'rows'.
223- BufferPtr newNulls = addNullsForUnselectedRows (flatMap , rows);
374+ BufferPtr newNulls = addNullsForUnselectedRows (mapVector , rows);
224375 auto localResult = std::make_shared<MapVector>(
225- flatMap ->pool (),
376+ mapVector ->pool (),
226377 outputType,
227378 std::move (newNulls),
228379 rows.end (),
@@ -233,6 +384,35 @@ class MapFilterFunction : public FilterFunctionBase {
233384 context.moveOrCopyResult (localResult, rows, result);
234385 }
235386
387+ public:
388+ void apply (
389+ const SelectivityVector& rows,
390+ std::vector<VectorPtr>& args,
391+ const TypePtr& outputType,
392+ exec::EvalCtx& context,
393+ VectorPtr& result) const override {
394+ VELOX_CHECK_EQ (args.size (), 2 );
395+ exec::LocalDecodedVector mapDecoder (context, *args[0 ], rows);
396+ auto & decodedMap = *mapDecoder.get ();
397+
398+ // Flattening input maps will peel if possible, but may simply cast if the
399+ // vector is an identify mapping.
400+ switch (decodedMap.base ()->encoding ()) {
401+ case VectorEncoding::Simple::FLAT_MAP: {
402+ applyFlatMapVector (decodedMap, rows, args, outputType, context, result);
403+ break ;
404+ }
405+ case VectorEncoding::Simple::MAP: {
406+ applyMapVector (decodedMap, rows, args, outputType, context, result);
407+ break ;
408+ }
409+ default :
410+ VELOX_UNSUPPORTED (
411+ " map_filter not supported for encoding: {}" ,
412+ decodedMap.base ()->encoding ());
413+ }
414+ }
415+
236416 static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures () {
237417 // map(K,V), function(K,V,boolean) -> map(K,V)
238418 return {exec::FunctionSignatureBuilder ()
0 commit comments