14 changes: 14 additions & 0 deletions core/src/main/scala/org/apache/spark/util/collection/Utils.scala
@@ -79,6 +79,20 @@ private[spark] object Utils {
builder.result()
}

+ /**
+  * Same function as `keys.zipWithIndex.toMap`, but has perf gain.
+  */
+ def toMapWithIndex[K](keys: Iterable[K]): Map[K, Int] = {
+   val builder = immutable.Map.newBuilder[K, Int]
+   val keyIter = keys.iterator
+   var idx = 0
+   while (keyIter.hasNext) {
+     builder += (keyIter.next(), idx).asInstanceOf[(K, Int)]
+     idx = idx + 1
+   }
+   builder.result()
+ }

/**
* Same function as `keys.zip(values).toMap.asJava`, but has perf gain.
*/
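For reference, the new helper produces exactly the map that the `zipWithIndex.toMap` idiom would, but in one pass over the keys and without materializing the intermediate collection of `(key, index)` tuples. A minimal sketch of the behaviour — the helper body is copied inline only because `org.apache.spark.util.collection.Utils` is `private[spark]`, and `demoKeys` is an illustrative input, not something from the PR:

```scala
import scala.collection.immutable

// Inline copy of the helper introduced above, so the sketch compiles outside Spark's source tree.
def toMapWithIndex[K](keys: Iterable[K]): Map[K, Int] = {
  val builder = immutable.Map.newBuilder[K, Int]
  val keyIter = keys.iterator
  var idx = 0
  while (keyIter.hasNext) {
    builder += (keyIter.next(), idx).asInstanceOf[(K, Int)]
    idx = idx + 1
  }
  builder.result()
}

val demoKeys = Seq("cat", "dog", "fish")   // illustrative input
val viaHelper = toMapWithIndex(demoKeys)   // Map(cat -> 0, dog -> 1, fish -> 2)
val viaZip = demoKeys.zipWithIndex.toMap   // the idiom being replaced
assert(viaHelper == viaZip)
```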
@@ -20,6 +20,7 @@ package org.apache.spark.ml.attribute
import scala.annotation.varargs

import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, NumericType, StructField}
+ import org.apache.spark.util.collection.Utils

/**
* Abstract class for ML attributes.
@@ -338,7 +339,7 @@ class NominalAttribute private[ml] (
override def isNominal: Boolean = true

private lazy val valueToIndex: Map[String, Int] = {
- values.map(_.zipWithIndex.toMap).getOrElse(Map.empty)
+ values.map(Utils.toMapWithIndex(_)).getOrElse(Map.empty)
}

/** Index of a specific value. */
@@ -30,7 +30,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
- import org.apache.spark.util.collection.OpenHashMap
+ import org.apache.spark.util.collection.{OpenHashMap, Utils}

/**
* Params for [[CountVectorizer]] and [[CountVectorizerModel]].
@@ -305,7 +305,7 @@ class CountVectorizerModel(
override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema, logging = true)
if (broadcastDict.isEmpty) {
- val dict = vocabulary.zipWithIndex.toMap
+ val dict = Utils.toMapWithIndex(vocabulary)
broadcastDict = Some(dataset.sparkSession.sparkContext.broadcast(dict))
}
val dictBr = broadcastDict.get
@@ -35,7 +35,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StructField, StructType}
- import org.apache.spark.util.collection.OpenHashSet
+ import org.apache.spark.util.collection.{OpenHashSet, Utils}

/** Private trait for params for VectorIndexer and VectorIndexerModel */
private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol
@@ -235,7 +235,7 @@ object VectorIndexer extends DefaultParamsReadable[VectorIndexer] {
if (zeroExists) {
sortedFeatureValues = 0.0 +: sortedFeatureValues
}
- val categoryMap: Map[Double, Int] = sortedFeatureValues.zipWithIndex.toMap
+ val categoryMap: Map[Double, Int] = Utils.toMapWithIndex(sortedFeatureValues)
(featureIndex, categoryMap)
}.toMap
}
@@ -39,6 +39,7 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd._
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.Utils
+ import org.apache.spark.util.collection.{Utils => CUtils}
import org.apache.spark.util.random.XORShiftRandom

/**
@@ -470,7 +471,7 @@ class Word2Vec extends Serializable with Logging {
newSentences.unpersist()

val wordArray = vocab.map(_.word)
- new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
+ new Word2VecModel(CUtils.toMapWithIndex(wordArray), syn0Global)
}
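One wrinkle here: `org.apache.spark.util.Utils` is already imported in the `Word2Vec` file, so the diff brings the collection helper in under the alias `CUtils` instead of shadowing it. A small illustration of the renaming-import pattern — this only compiles inside Spark's own source tree, since both objects are `private[spark]`, and the `Seq(...)` argument is just for show:

```scala
import org.apache.spark.util.Utils                          // pre-existing general-purpose helpers
import org.apache.spark.util.collection.{Utils => CUtils}   // the new map builder, renamed to avoid a clash

// Call sites can then use both objects unambiguously:
val wordIndex: Map[String, Int] = CUtils.toMapWithIndex(Seq("spark", "word2vec"))
```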

/**
@@ -639,7 +640,7 @@
object Word2VecModel extends Loader[Word2VecModel] {

private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = {
- model.keys.zipWithIndex.toMap
+ CUtils.toMapWithIndex(model.keys)
}

private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = {
@@ -41,6 +41,7 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
+ import org.apache.spark.util.collection.Utils

/**
* Model trained by [[FPGrowth]], which holds frequent itemsets.
@@ -269,7 +270,7 @@ class FPGrowth private[spark] (
minCount: Long,
freqItems: Array[Item],
partitioner: Partitioner): RDD[FreqItemset[Item]] = {
- val itemToRank = freqItems.zipWithIndex.toMap
+ val itemToRank = Utils.toMapWithIndex(freqItems)
data.flatMap { transaction =>
genCondTransactions(transaction, itemToRank, partitioner)
}.aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
@@ -40,6 +40,7 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
+ import org.apache.spark.util.collection.Utils

/**
* A parallel PrefixSpan algorithm to mine frequent sequential patterns.
@@ -147,7 +148,7 @@ class PrefixSpan private (
logInfo(s"number of frequent items: ${freqItems.length}")

// Keep only frequent items from input sequences and convert them to internal storage.
- val itemToInt = freqItems.zipWithIndex.toMap
+ val itemToInt = Utils.toMapWithIndex(freqItems)
val dataInternalRepr = toDatabaseInternalRepr(data, itemToInt)
.persist(StorageLevel.MEMORY_AND_DISK)

@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
- import org.apache.spark.util.collection.OpenHashMap
+ import org.apache.spark.util.collection.{OpenHashMap, Utils}

/**
* Conduct the chi-squared test for the input RDDs using the specified method.
@@ -181,14 +181,14 @@ private[spark] object ChiSqTest extends Logging {
counts: Map[(Double, Double), Long],
methodName: String,
col: Int): ChiSqTestResult = {
- val label2Index = counts.iterator.map(_._1._1).toArray.distinct.sorted.zipWithIndex.toMap
+ val label2Index = Utils.toMapWithIndex(counts.iterator.map(_._1._1).toArray.distinct.sorted)
val numLabels = label2Index.size
if (numLabels > maxCategories) {
throw new SparkException(s"Chi-square test expect factors (categorical values) but "
+ s"found more than $maxCategories distinct label values.")
}

- val value2Index = counts.iterator.map(_._1._2).toArray.distinct.sorted.zipWithIndex.toMap
+ val value2Index = Utils.toMapWithIndex(counts.iterator.map(_._1._2).toArray.distinct.sorted)
val numValues = value2Index.size
if (numValues > maxCategories) {
throw new SparkException(s"Chi-square test expect factors (categorical values) but "
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+ import org.apache.spark.util.collection.Utils
import org.apache.spark.util.random.RandomSampler

/**
@@ -1235,7 +1236,7 @@ object Expand {
groupByAttrs: Seq[Attribute],
gid: Attribute,
child: LogicalPlan): Expand = {
- val attrMap = groupByAttrs.zipWithIndex.toMap
+ val attrMap = Utils.toMapWithIndex(groupByAttrs)

val hasDuplicateGroupingSets = groupingSetsAttrs.size !=
groupingSetsAttrs.map(_.map(_.exprId).toSet).distinct.size
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
+ import org.apache.spark.util.collection.Utils

/**
* A [[StructType]] object can be constructed by
@@ -117,7 +118,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru

private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
- private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
+ private lazy val nameToIndex: Map[String, Int] = Utils.toMapWithIndex(fieldNames)

override def equals(that: Any): Boolean = {
that match {
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
+ import org.apache.spark.util.collection.Utils

/**
* A [[ParentContainerUpdater]] is used by a Parquet converter to set converted values to some
@@ -207,7 +208,7 @@ private[parquet] class ParquetRowConverter(
private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = {
// (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false
// to prevent throwing IllegalArgumentException when searching catalyst type's field index
- def nameToIndex: Map[String, Int] = catalystType.fieldNames.zipWithIndex.toMap
+ def nameToIndex: Map[String, Int] = Utils.toMapWithIndex(catalystType.fieldNames)

val catalystFieldIdxByName = if (SQLConf.get.caseSensitiveAnalysis) {
nameToIndex
@@ -30,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
+ import org.apache.spark.util.collection.Utils

object StatFunctions extends Logging {

@@ -198,7 +199,7 @@ object StatFunctions extends Logging {
}
// get the distinct sorted values of column 2, so that we can make them the column names
val distinctCol2: Map[Any, Int] =
- counts.map(e => cleanElement(e.get(1))).distinct.sorted.zipWithIndex.toMap
+ Utils.toMapWithIndex(counts.map(e => cleanElement(e.get(1))).distinct.sorted)
val columnSize = distinctCol2.size
require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
s"exceed 1e4. Currently $columnSize")
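Every call site above follows the same mechanical pattern: a `keys.zipWithIndex.toMap`, which first materializes a collection of `(key, index)` tuples, is replaced by a single builder-driven pass. No benchmark appears in the hunks shown here, so the following is only a rough, self-contained timing sketch of the two constructions; numbers depend on JVM, warm-up, and input size, and a serious comparison would use JMH:

```scala
import scala.collection.immutable

// Crude single-shot timer; a real comparison would use JMH with proper warm-up.
def time[A](label: String)(body: => A): A = {
  val start = System.nanoTime()
  val result = body
  println(f"$label%-22s ${(System.nanoTime() - start) / 1e6}%8.1f ms")
  result
}

val keys: Seq[String] = (0 until 1000000).map(i => s"key$i")   // illustrative input size

// The idiom being replaced: builds the intermediate (String, Int) tuples before the map.
val viaZip = time("zipWithIndex.toMap") { keys.zipWithIndex.toMap }

// The construction used by the new helper: feed (key, index) pairs straight into the builder.
val viaBuilder = time("Map.newBuilder loop") {
  val builder = immutable.Map.newBuilder[String, Int]
  val it = keys.iterator
  var idx = 0
  while (it.hasNext) {
    builder += it.next() -> idx
    idx += 1
  }
  builder.result()
}

assert(viaZip == viaBuilder)   // same map either way; only the intermediate allocations differ
```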