5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -21,9 +21,10 @@ import java.io.{DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Time, Timestamp}

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.util.collection.Utils

/**
* Utility functions to serialize, deserialize objects to / from R
*/
@@ -236,7 +237,7 @@ private[spark] object SerDe {
val keys = readArray(in, jvmObjectTracker).asInstanceOf[Array[Object]]
val values = readList(in, jvmObjectTracker)

keys.zip(values).toMap.asJava
Utils.toJavaMap(keys, values)
} else {
new java.util.HashMap[Object, Object]()
}
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -74,6 +74,7 @@ import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace}
import org.apache.spark.util.collection.{Utils => CUtils}
import org.apache.spark.util.io.ChunkedByteBufferOutputStream

/** CallSite represents a place in user code. It can have a short and a long form. */
@@ -1718,7 +1719,8 @@ private[spark] object Utils extends Logging {
assert(files.length == fileLengths.length)
val startIndex = math.max(start, 0)
val endIndex = math.min(end, fileLengths.sum)
val fileToLength = files.zip(fileLengths).toMap
val fileToLength = CUtils.toMap(files, fileLengths)

logDebug("Log files: \n" + fileToLength.mkString("\n"))

val stringBuffer = new StringBuffer((endIndex - startIndex).toInt)
29 changes: 29 additions & 0 deletions core/src/main/scala/org/apache/spark/util/collection/Utils.scala
@@ -17,7 +17,10 @@

package org.apache.spark.util.collection

import java.util.Collections

import scala.collection.JavaConverters._
import scala.collection.immutable

import com.google.common.collect.{Iterators => GuavaIterators, Ordering => GuavaOrdering}

@@ -62,4 +65,30 @@ private[spark] object Utils {
*/
def sequenceToOption[T](input: Seq[Option[T]]): Option[Seq[T]] =
if (input.forall(_.isDefined)) Some(input.flatten) else None

/**
* Same function as `keys.zip(values).toMap`, but has perf gain.
*/
def toMap[K, V](keys: Iterable[K], values: Iterable[V]): Map[K, V] = {
val builder = immutable.Map.newBuilder[K, V]
val keyIter = keys.iterator
val valueIter = values.iterator
while (keyIter.hasNext && valueIter.hasNext) {
builder += (keyIter.next(), valueIter.next()).asInstanceOf[(K, V)]
}
builder.result()
}

/**
* Same function as `keys.zip(values).toMap.asJava`, but has perf gain.
*/
def toJavaMap[K, V](keys: Iterable[K], values: Iterable[V]): java.util.Map[K, V] = {
val map = new java.util.HashMap[K, V]()
val keyIter = keys.iterator
val valueIter = values.iterator
while (keyIter.hasNext && valueIter.hasNext) {
map.put(keyIter.next(), valueIter.next())
}
Collections.unmodifiableMap(map)
}
}
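
A minimal usage sketch, not part of the diff: the ToMapExample object and the sample keys/values below are made up, but they show how the new helpers are intended to replace the zip-based idiom.

    // Illustrative only; placed under org.apache.spark because collection.Utils is private[spark].
    package org.apache.spark.util.collection

    object ToMapExample {
      def main(args: Array[String]): Unit = {
        val keys = Seq("a", "b", "c")
        val values = Seq(1, 2, 3)

        // Old idiom: zip materializes an intermediate collection of tuples before toMap builds the map.
        val oldStyle: Map[String, Int] = keys.zip(values).toMap
        // New helper: builds the Map in a single pass over both iterators.
        val newStyle: Map[String, Int] = Utils.toMap(keys, values)
        assert(oldStyle == newStyle)

        // Java variant: filled in one pass and returned wrapped in Collections.unmodifiableMap.
        val javaMap: java.util.Map[String, Int] = Utils.toJavaMap(keys, values)
        assert(javaMap.get("b") == 2)
      }
    }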
@@ -26,6 +26,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.util.collection.Utils

/**
* Evaluator for ranking algorithms.
@@ -155,7 +156,7 @@ class RankingMetrics[T: ClassTag] @Since("1.2.0") (predictionAndLabels: RDD[_ <:
rdd.map { case (pred, lab, rel) =>
val useBinary = rel.isEmpty
val labSet = lab.toSet
val relMap = lab.zip(rel).toMap
val relMap = Utils.toMap(lab, rel)
if (useBinary && lab.size != rel.size) {
logWarning(
"# of ground truth set and # of relevance value set should be equal, " +
@@ -35,6 +35,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType._
import org.apache.spark.sql.types.YearMonthIntervalType._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.Utils

/**
* Functions to convert Scala types to Catalyst types and vice versa.
@@ -229,7 +230,7 @@ object CatalystTypeConverters {
val convertedValues =
if (isPrimitive(valueType)) values else values.map(valueConverter.toScala)

convertedKeys.zip(convertedValues).toMap
Utils.toMap(convertedKeys, convertedValues)
}
}

@@ -58,6 +58,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.DAY
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{Utils => CUtils}

/**
* A trivial [[Analyzer]] with a dummy [[SessionCatalog]], [[EmptyFunctionRegistry]] and
@@ -3457,7 +3458,7 @@ class Analyzer(override val catalogManager: CatalogManager)
throw QueryCompilationErrors.writeTableWithMismatchedColumnsError(
cols.size, query.output.size, query)
}
val nameToQueryExpr = cols.zip(query.output).toMap
val nameToQueryExpr = CUtils.toMap(cols, query.output)
// Static partition columns in the table output should not appear in the column list
// they will be handled in another rule ResolveInsertInto
val reordered = tableOutput.flatMap { nameToQueryExpr.get(_).orElse(None) }
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.util.collection.Utils

/**
* Decorrelate the inner query by eliminating outer references and create domain joins.
@@ -346,7 +347,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
val domains = attributes.map(_.newInstance())
// A placeholder to be rewritten into domain join.
val domainJoin = DomainJoin(domains, plan)
val outerReferenceMap = attributes.zip(domains).toMap
val outerReferenceMap = Utils.toMap(attributes, domains)
// Build join conditions between domain attributes and outer references.
// EqualNullSafe is used to make sure null key can be joined together. Note
// outer referenced attributes can be changed during the outer query optimization.
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.CTE
import org.apache.spark.util.collection.Utils

/**
* Infer predicates and column pruning for [[CTERelationDef]] from its reference points, and push
@@ -71,7 +72,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] {

case PhysicalOperation(projects, predicates, ref: CTERelationRef) =>
val (cteDef, precedence, preds, attrs) = cteMap(ref.cteId)
val attrMapping = ref.output.zip(cteDef.output).map{ case (r, d) => r -> d }.toMap
val attrMapping = Utils.toMap(ref.output, cteDef.output)
val newPredicates = if (isTruePredicate(preds)) {
preds
} else {
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPl
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.collection.Utils

/**
* This rule rewrites an aggregate query with distinct aggregations into an expanded double
@@ -265,7 +266,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

// Setup expand & aggregate operators for distinct aggregate expressions.
val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
val distinctAggFilterAttrLookup = distinctAggFilters.zip(maxConds.map(_.toAttribute)).toMap
val distinctAggFilterAttrLookup = Utils.toMap(distinctAggFilters, maxConds.map(_.toAttribute))
val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
case ((group, expressions), i) =>
val id = Literal(i + 1)
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util

import java.util.{Map => JavaMap}

import org.apache.spark.util.collection.Utils

/**
* A simple `MapData` implementation which is backed by 2 arrays.
*
@@ -129,20 +131,19 @@ object ArrayBasedMapData {
def toScalaMap(map: ArrayBasedMapData): Map[Any, Any] = {
val keys = map.keyArray.asInstanceOf[GenericArrayData].array
val values = map.valueArray.asInstanceOf[GenericArrayData].array
keys.zip(values).toMap
Utils.toMap(keys, values)
}

def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = {
keys.zip(values).toMap
Utils.toMap(keys, values)
}

def toScalaMap(keys: scala.collection.Seq[Any],
values: scala.collection.Seq[Any]): Map[Any, Any] = {
keys.zip(values).toMap
Utils.toMap(keys, values)
}

def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = {
import scala.collection.JavaConverters._
keys.zip(values).toMap.asJava
Utils.toJavaMap(keys, values)
}
@mridulm (Contributor) commented on Sep 14, 2022:

Do we need this method anymore? Why not replace it with Utils.toJavaMap entirely (in JavaTypeInference)? Any issues with that?

@LuciferYang (Contributor, Author) commented on Sep 15, 2022:

ArrayBasedMapData#toJavaMap looks like an unused method. I think we can delete it, but we need to confirm that the MiMa check passes first.

EDIT: ArrayBasedMapData#toJavaMap is not an unused method; it is used by JavaTypeInference. Sorry for missing what @mridulm said.

@LuciferYang (Contributor, Author):

Let me check this later.

Contributor:

ArrayBasedMapData is not a public API and shouldn't be tracked by MiMa.

@LuciferYang (Contributor, Author):

Removed.

@LuciferYang (Contributor, Author) commented on Sep 15, 2022:

@mridulm @cloud-fan

If

    StaticInvoke(
      ArrayBasedMapData.getClass,
      ObjectType(classOf[JMap[_, _]]),
      "toJavaMap",
      keyData :: valueData :: Nil,
      returnNullable = false)

is changed to

    StaticInvoke(
      Utils.getClass,
      ObjectType(classOf[JMap[_, _]]),
      "toJavaMap",
      keyData :: valueData :: Nil,
      returnNullable = false)

then the signature of toJavaMap in collection.Utils needs to change from

    def toJavaMap[K, V](keys: Iterable[K], values: Iterable[V]): java.util.Map[K, V]

to

    def toJavaMap[K, V](keys: Array[K], values: Array[V]): java.util.Map[K, V]

Otherwise the relevant tests fail with:

    16:20:35.587 ERROR org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 375, Column 50: No applicable constructor/method found for actual parameters "java.lang.Object[], java.lang.Object[]"; candidates are: "public static java.util.Map org.apache.spark.util.collection.Utils.toJavaMap(scala.collection.Iterable, scala.collection.Iterable)"

A signature of def toJavaMap[K, V](keys: Array[K], values: Array[V]): java.util.Map[K, V] would limit where this method can be used, so I prefer to retain the ArrayBasedMapData#toJavaMap method.
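
A rough sketch of the Array-based toJavaMap overload being discussed, illustrative only and not part of this PR (the UtilsArraySketch name is made up):

    // Hypothetical sketch: generated code passes the key/value data as Object[] arrays,
    // which do not implement scala.collection.Iterable, so a direct StaticInvoke on
    // collection.Utils would need an Array-based signature along these lines.
    package org.apache.spark.util.collection

    import java.util.Collections

    object UtilsArraySketch {
      def toJavaMap[K, V](keys: Array[K], values: Array[V]): java.util.Map[K, V] = {
        val map = new java.util.HashMap[K, V](keys.length)
        var i = 0
        // Stop at the shorter array, mirroring the Iterable-based version.
        while (i < keys.length && i < values.length) {
          map.put(keys(i), values(i))
          i += 1
        }
        Collections.unmodifiableMap(map)
      }
    }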

@LuciferYang (Contributor, Author):

Is it acceptable to retain this method?

}
@@ -33,6 +33,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType._
import org.apache.spark.sql.types.YearMonthIntervalType.YEAR
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.collection.Utils
/**
* Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random
* values; instead, they're biased to return "interesting" values (such as maximum / minimum values)
@@ -340,7 +341,7 @@ object RandomDataGenerator {
count += 1
}
val values = Seq.fill(keys.size)(valueGenerator())
keys.zip(values).toMap
Utils.toMap(keys, values)
}
}
case StructType(fields) =>
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.collection.Utils

object ExternalRDD {

@@ -106,7 +107,7 @@ case class LogicalRDD(
session :: originStats :: originConstraints :: Nil

override def newInstance(): LogicalRDD.this.type = {
val rewrite = output.zip(output.map(_.newInstance())).toMap
val rewrite = Utils.toMap(output, output.map(_.newInstance()))

val rewrittenPartitioning = outputPartitioning match {
case p: Expression =>
@@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{Utils => CUtils}

/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -218,7 +219,7 @@ object AggUtils {
}

// 3. Create an Aggregate operator for partial aggregation (for distinct)
val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap
val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions, distinctAttributes)
val rewrittenDistinctFunctions = functionsWithDistinct.map {
// Children of an AggregateFunction with DISTINCT keyword has already
// been evaluated. At here, we need to replace original children
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.util.PartitioningUtils
import org.apache.spark.util.collection.Utils

/**
* Analyzes a given set of partitions to generate per-partition statistics, which will be used in
@@ -147,7 +148,7 @@ case class AnalyzePartitionCommand(
r.get(i).toString
}
}
val spec = tableMeta.partitionColumnNames.zip(partitionColumnValues).toMap
val spec = Utils.toMap(tableMeta.partitionColumnNames, partitionColumnValues)
val count = BigInt(r.getLong(partitionColumns.size))
(spec, count)
}.toMap
@@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.Utils

object PushDownUtils {
/**
@@ -203,7 +204,7 @@ object PushDownUtils {
def toOutputAttrs(
schema: StructType,
relation: DataSourceV2Relation): Seq[AttributeReference] = {
val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
val nameToAttr = Utils.toMap(relation.output.map(_.name), relation.output)
val cleaned = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)
cleaned.toAttributes.map {
// we have to keep the attribute id during transformation
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.UnaryExecNode
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.Utils

/**
* Holds common logic for window operators
@@ -69,7 +70,7 @@ trait WindowExecBase extends UnaryExecNode {
// Results of window expressions will be on the right side of child's output
BoundReference(child.output.size + i, e.dataType, e.nullable)
}
val unboundToRefMap = expressions.zip(references).toMap
val unboundToRefMap = Utils.toMap(expressions, references)
val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
UnsafeProjection.create(
child.output ++ patchedWindowExpression,
@@ -22,6 +22,7 @@ import scala.collection.mutable

import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation}
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.util.collection.Utils

/**
* A class that tries to schedule receivers with evenly distributed. There are two phases for
@@ -135,7 +136,7 @@ private[streaming] class ReceiverSchedulingPolicy {
leastScheduledExecutors += executor
}

receivers.map(_.streamId).zip(scheduledLocations.map(_.toSeq)).toMap
Utils.toMap(receivers.map(_.streamId), scheduledLocations.map(_.toSeq))
}

/**
Expand Down