@@ -22,10 +22,12 @@ import org.apache.spark.sql.catalyst.InternalRow
2222import org .apache .spark .sql .catalyst .analysis .TypeCheckResult
2323import org .apache .spark .sql .catalyst .expressions .ArraySortLike .NullOrder
2424import org .apache .spark .sql .catalyst .expressions .codegen ._
25- import org .apache .spark .sql .catalyst .util .{ ArrayData , GenericArrayData , MapData , TypeUtils }
25+ import org .apache .spark .sql .catalyst .util ._
2626import org .apache .spark .sql .types ._
27+ import org .apache .spark .unsafe .Platform
2728import org .apache .spark .unsafe .array .ByteArrayMethods
2829import org .apache .spark .unsafe .types .{ByteArray , UTF8String }
30+ import org .apache .spark .util .collection .OpenHashSet
2931
3032/**
3133 * Given an array or map, returns its size. Returns -1 if null.
@@ -118,6 +120,229 @@ case class MapValues(child: Expression)
118120 override def prettyName : String = " map_values"
119121}
120122
123+ /**
124+ * Returns a map created from the given array of entries.
125+ */
126+ @ ExpressionDescription (
127+ usage = " _FUNC_(arrayOfEntries) - Returns a map created from the given array of entries." ,
128+ examples = """
129+ Examples:
130+ > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));
131+ {1:"a",2:"b"}
132+ """ ,
133+ since = " 2.4.0" )
134+ case class MapFromEntries (child : Expression ) extends UnaryExpression
135+ {
136+ private lazy val resolvedDataType : Option [MapType ] = child.dataType match {
137+ case ArrayType (
138+ StructType (Array (
139+ StructField (_, keyType, false , _),
140+ StructField (_, valueType, valueNullable, _))),
141+ false ) => Some (MapType (keyType, valueType, valueNullable))
142+ case _ => None
143+ }
144+
145+ override def dataType : MapType = resolvedDataType.get
146+
147+ override def checkInputDataTypes (): TypeCheckResult = resolvedDataType match {
148+ case Some (_) => TypeCheckResult .TypeCheckSuccess
149+ case None => TypeCheckResult .TypeCheckFailure (s " ' ${child.sql}' is of " +
150+ s " ${child.dataType.simpleString} type. $prettyName accepts only null-free arrays " +
151+ " of pair structs. Values of the first struct field can't contain nulls and produce " +
152+ " duplicates." )
153+ }
154+
155+ override protected def nullSafeEval (input : Any ): Any = {
156+ val arrayData = input.asInstanceOf [ArrayData ]
157+ val length = arrayData.numElements()
158+ val keyArray = new Array [AnyRef ](length)
159+ val keySet = new OpenHashSet [AnyRef ]()
160+ val valueArray = new Array [AnyRef ](length)
161+ var i = 0 ;
162+ while (i < length) {
163+ val entry = arrayData.getStruct(i, 2 )
164+ val key = entry.get(0 , dataType.keyType)
165+ if (key == null ) {
166+ throw new RuntimeException (" The first field from a struct (key) can't be null." )
167+ }
168+ if (keySet.contains(key)) {
169+ throw new RuntimeException (" The first field from a struct (key) can't produce duplicates." )
170+ }
171+ keySet.add(key)
172+ keyArray.update(i, key)
173+ val value = entry.get(1 , dataType.valueType)
174+ valueArray.update(i, value)
175+ i += 1
176+ }
177+ ArrayBasedMapData (keyArray, valueArray)
178+ }
179+
180+ private def getHashSetDetails (): (String , String ) = dataType.keyType match {
181+ case ByteType | ShortType | IntegerType => (" $mcI$sp" , " Int" )
182+ case LongType => (" $mcJ$sp" , " Long" )
183+ case _ => (" " , " Object" )
184+ }
185+
186+ override protected def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
187+ nullSafeCodeGen(ctx, ev, c => {
188+ val numElements = ctx.freshName(" numElements" )
189+ val keySet = ctx.freshName(" keySet" )
190+ val hsClass = classOf [OpenHashSet [_]].getName
191+ val tagPrefix = " scala.reflect.ClassTag$.MODULE$."
192+ val (hsSuffix, tagSuffix) = getHashSetDetails()
193+ val isKeyPrimitive = CodeGenerator .isPrimitiveType(dataType.keyType)
194+ val isValuePrimitive = CodeGenerator .isPrimitiveType(dataType.valueType)
195+ val code = if (isKeyPrimitive && isValuePrimitive) {
196+ genCodeForPrimitiveElements(ctx, c, ev.value, keySet, numElements)
197+ } else {
198+ genCodeForAnyElements(ctx, c, ev.value, keySet, numElements)
199+ }
200+ s """
201+ |final int $numElements = $c.numElements();
202+ |final $hsClass$hsSuffix $keySet = new $hsClass$hsSuffix( $tagPrefix$tagSuffix());
203+ | $code
204+ """ .stripMargin
205+ })
206+ }
207+
208+ private def genCodeForAssignmentLoop (
209+ ctx : CodegenContext ,
210+ childVariable : String ,
211+ numElements : String ,
212+ keySet : String ,
213+ keyAssignment : (String , String ) => String ,
214+ valueAssignment : (String , String ) => String ): String = {
215+ val entry = ctx.freshName(" entry" )
216+ val key = ctx.freshName(" key" )
217+ val idx = ctx.freshName(" idx" )
218+ val keyType = CodeGenerator .javaType(dataType.keyType)
219+
220+ s """
221+ |for (int $idx = 0; $idx < $numElements; $idx++) {
222+ | InternalRow $entry = $childVariable.getStruct( $idx, 2);
223+ | if ( $entry.isNullAt(0)) {
224+ | throw new RuntimeException("The first field from a struct (key) can't be null.");
225+ | }
226+ | $keyType $key = ${CodeGenerator .getValue(entry, dataType.keyType, " 0" )};
227+ | if ( $keySet.contains( $key)) {
228+ | throw new RuntimeException(
229+ | "The first field from a struct (key) can't produce duplicates.");
230+ | }
231+ | $keySet.add( $key);
232+ | ${keyAssignment(key, idx)}
233+ | ${valueAssignment(entry, idx)}
234+ |}
235+ """ .stripMargin
236+ }
237+
238+ private def genCodeForPrimitiveElements (
239+ ctx : CodegenContext ,
240+ childVariable : String ,
241+ mapData : String ,
242+ keySet : String ,
243+ numElements : String ): String = {
244+ val byteArraySize = ctx.freshName(" byteArraySize" )
245+ val keySectionSize = ctx.freshName(" keySectionSize" )
246+ val valueSectionSize = ctx.freshName(" valueSectionSize" )
247+ val data = ctx.freshName(" byteArray" )
248+ val unsafeMapData = ctx.freshName(" unsafeMapData" )
249+ val keyArrayData = ctx.freshName(" keyArrayData" )
250+ val valueArrayData = ctx.freshName(" valueArrayData" )
251+
252+ val baseOffset = Platform .BYTE_ARRAY_OFFSET
253+ val keySize = dataType.keyType.defaultSize
254+ val valueSize = dataType.valueType.defaultSize
255+ val kByteSize = s " UnsafeArrayData.calculateSizeOfUnderlyingByteArray( $numElements, $keySize) "
256+ val vByteSize = s " UnsafeArrayData.calculateSizeOfUnderlyingByteArray( $numElements, $valueSize) "
257+ val keyTypeName = CodeGenerator .primitiveTypeName(dataType.keyType)
258+ val valueTypeName = CodeGenerator .primitiveTypeName(dataType.valueType)
259+
260+ val keyAssignment = (key : String , idx : String ) => s " $keyArrayData.set $keyTypeName( $idx, $key); "
261+ val valueAssignment = (entry : String , idx : String ) => {
262+ val value = CodeGenerator .getValue(entry, dataType.valueType, " 1" )
263+ val valueNullUnsafeAssignment = s " $valueArrayData.set $valueTypeName( $idx, $value); "
264+ if (dataType.valueContainsNull) {
265+ s """
266+ |if ( $entry.isNullAt(1)) {
267+ | $valueArrayData.setNullAt( $idx);
268+ |} else {
269+ | $valueNullUnsafeAssignment
270+ |}
271+ """ .stripMargin
272+ } else {
273+ valueNullUnsafeAssignment
274+ }
275+ }
276+ val assignmentLoop = genCodeForAssignmentLoop(
277+ ctx,
278+ childVariable,
279+ numElements,
280+ keySet,
281+ keyAssignment,
282+ valueAssignment
283+ )
284+
285+ s """
286+ |final long $keySectionSize = $kByteSize;
287+ |final long $valueSectionSize = $vByteSize;
288+ |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize;
289+ |if ( $byteArraySize > ${ByteArrayMethods .MAX_ROUNDED_ARRAY_LENGTH }) {
290+ | ${genCodeForAnyElements(ctx, childVariable, mapData, keySet, numElements)}
291+ |} else {
292+ | final byte[] $data = new byte[(int) $byteArraySize];
293+ | UnsafeMapData $unsafeMapData = new UnsafeMapData();
294+ | Platform.putLong( $data, $baseOffset, $keySectionSize);
295+ | Platform.putLong( $data, ${baseOffset + 8 }, $numElements);
296+ | Platform.putLong( $data, ${baseOffset + 8 } + $keySectionSize, $numElements);
297+ | $unsafeMapData.pointTo( $data, $baseOffset, (int) $byteArraySize);
298+ | ArrayData $keyArrayData = $unsafeMapData.keyArray();
299+ | ArrayData $valueArrayData = $unsafeMapData.valueArray();
300+ | $assignmentLoop
301+ | $mapData = $unsafeMapData;
302+ |}
303+ """ .stripMargin
304+ }
305+
306+ private def genCodeForAnyElements (
307+ ctx : CodegenContext ,
308+ childVariable : String ,
309+ mapData : String ,
310+ keySet : String ,
311+ numElements : String ): String = {
312+ val keys = ctx.freshName(" keys" )
313+ val values = ctx.freshName(" values" )
314+ val mapDataClass = classOf [ArrayBasedMapData ].getName()
315+
316+ val isValuePrimitive = CodeGenerator .isPrimitiveType(dataType.valueType)
317+ val valueAssignment = (entry : String , idx : String ) => {
318+ val value = CodeGenerator .getValue(entry, dataType.valueType, " 1" )
319+ if (dataType.valueContainsNull && isValuePrimitive) {
320+ s " $values[ $idx] = $entry.isNullAt(1) ? null : (Object) $value; "
321+ } else {
322+ s " $values[ $idx] = $value; "
323+ }
324+ }
325+ val keyAssignment = (key : String , idx : String ) => s " $keys[ $idx] = $key; "
326+ val assignmentLoop = genCodeForAssignmentLoop(
327+ ctx,
328+ childVariable,
329+ numElements,
330+ keySet,
331+ keyAssignment,
332+ valueAssignment)
333+
334+ s """
335+ |final Object[] $keys = new Object[ $numElements];
336+ |final Object[] $values = new Object[ $numElements];
337+ | $assignmentLoop
338+ | $mapData = $mapDataClass.apply( $keys, $values);
339+ """ .stripMargin
340+ }
341+
342+ override def prettyName : String = " map_from_entries"
343+ }
344+
345+
121346/**
122347 * Common base class for [[SortArray ]] and [[ArraySort ]].
123348 */
0 commit comments