Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ sorttable = {

hasInputs = (typeof node.getElementsByTagName == 'function') &&
node.getElementsByTagName('input').length;
if (node.getAttribute("sorttable_customkey") != null) {

if (node.nodeType == 1 && node.getAttribute("sorttable_customkey") != null) {
return node.getAttribute("sorttable_customkey");
}
else if (typeof node.textContent != 'undefined' && !hasInputs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")

<div class="row-fluid"> <!-- Executors -->
<div class="span12">
<h4> Executor Summary </h4>
<h4> Executor Summary ({allExecutors.length}) </h4>
{executorsTable}
{
if (removedExecutors.nonEmpty) {
<h4> Removed Executors </h4> ++
<h4> Removed Executors ({removedExecutors.length}) </h4> ++
removedExecutorsTable
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {

<div class="row-fluid">
<div class="span12">
<h4> Workers </h4>
<h4> Workers ({workers.length}) </h4>
{workerTable}
</div>
</div>

<div class="row-fluid">
<div class="span12">
<h4 id="running-app"> Running Applications </h4>
<h4 id="running-app"> Running Applications ({activeApps.length}) </h4>
{activeAppsTable}
</div>
</div>
Expand All @@ -144,7 +144,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
{if (hasDrivers) {
<div class="row-fluid">
<div class="span12">
<h4> Running Drivers </h4>
<h4> Running Drivers ({activeDrivers.length}) </h4>
{activeDriversTable}
</div>
</div>
Expand All @@ -154,7 +154,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {

<div class="row-fluid">
<div class="span12">
<h4 id="completed-app"> Completed Applications </h4>
<h4 id="completed-app"> Completed Applications ({completedApps.length}) </h4>
{completedAppsTable}
</div>
</div>
Expand All @@ -164,7 +164,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
if (hasDrivers) {
<div class="row-fluid">
<div class="span12">
<h4> Completed Drivers </h4>
<h4> Completed Drivers ({completedDrivers.length}) </h4>
{completedDriversTable}
</div>
</div>
Expand Down
13 changes: 8 additions & 5 deletions core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
* Returns the name of this accumulator, can only be called after registration.
*/
final def name: Option[String] = {
assertMetadataNotNull()

if (atDriverSide) {
AccumulatorContext.get(id).flatMap(_.metadata.name)
metadata.name.orElse(AccumulatorContext.get(id).flatMap(_.metadata.name))
} else {
assertMetadataNotNull()
metadata.name
}
}
Expand Down Expand Up @@ -165,13 +166,15 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
}
val copyAcc = copyAndReset()
assert(copyAcc.isZero, "copyAndReset must return a zero value copy")
val isInternalAcc =
(name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)) ||
getClass.getSimpleName == "SQLMetric"
val isInternalAcc = name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)
if (isInternalAcc) {
// Do not serialize the name of internal accumulator and send it to executor.
copyAcc.metadata = metadata.copy(name = None)
} else {
// For non-internal accumulators, we still need to send the name because users may need to
// access the accumulator name at executor side, or they may keep the accumulators sent from
// executors and access the name when the registered accumulator is already garbage
// collected(e.g. SQLMetrics).
copyAcc.metadata = metadata
}
copyAcc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
final val family: Param[String] = new Param(this, "family",
"The name of family which is a description of the label distribution to be used in the " +
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
ParamValidators.inArray[String](supportedFamilyNames))
(value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT)))

/** @group getParam */
@Since("2.1.0")
Expand Down Expand Up @@ -526,7 +526,7 @@ class LogisticRegression @Since("1.2.0") (
case None => histogram.length
}

val isMultinomial = $(family) match {
val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match {
case "binomial" =>
require(numClasses == 1 || numClasses == 2, s"Binomial family only supports 1 or 2 " +
s"outcome classes but found $numClasses.")
Expand Down
30 changes: 15 additions & 15 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
@Since("1.6.0")
final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" +
" algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "),
(o: String) =>
ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase(Locale.ROOT)))
(value: String) => supportedOptimizers.contains(value.toLowerCase(Locale.ROOT)))

/** @group getParam */
@Since("1.6.0")
Expand Down Expand Up @@ -325,7 +324,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
s" ${getDocConcentration.length}, but k = $getK. docConcentration must be an array of" +
s" length either 1 (scalar) or k (num topics).")
}
getOptimizer match {
getOptimizer.toLowerCase(Locale.ROOT) match {
case "online" =>
require(getDocConcentration.forall(_ >= 0),
"For Online LDA optimizer, docConcentration values must be >= 0. Found values: " +
Expand All @@ -337,7 +336,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
}
}
if (isSet(topicConcentration)) {
getOptimizer match {
getOptimizer.toLowerCase(Locale.ROOT) match {
case "online" =>
require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" +
s" must be >= 0. Found value: $getTopicConcentration")
Expand All @@ -350,17 +349,18 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}

private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match {
case "online" =>
new OldOnlineLDAOptimizer()
.setTau0($(learningOffset))
.setKappa($(learningDecay))
.setMiniBatchFraction($(subsamplingRate))
.setOptimizeDocConcentration($(optimizeDocConcentration))
case "em" =>
new OldEMLDAOptimizer()
.setKeepLastCheckpoint($(keepLastCheckpoint))
}
private[clustering] def getOldOptimizer: OldLDAOptimizer =
getOptimizer.toLowerCase(Locale.ROOT) match {
case "online" =>
new OldOnlineLDAOptimizer()
.setTau0($(learningOffset))
.setKappa($(learningDecay))
.setMiniBatchFraction($(subsamplingRate))
.setOptimizeDocConcentration($(optimizeDocConcentration))
case "em" =>
new OldEMLDAOptimizer()
.setKeepLastCheckpoint($(keepLastCheckpoint))
}
}

private object LDAParams {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2582,6 +2582,17 @@ class LogisticRegressionSuite
assert(expected.coefficients.toArray === actual.coefficients.toArray)
}
}

test("string params should be case-insensitive") {
val lr = new LogisticRegression()
Seq(("AuTo", smallBinaryDataset), ("biNoMial", smallBinaryDataset),
("mulTinomIAl", smallMultinomialDataset)).foreach { case (family, data) =>
lr.setFamily(family)
assert(lr.getFamily === family)
val model = lr.fit(data)
assert(model.getFamily === family)
}
}
}

object LogisticRegressionSuite {
Expand Down
10 changes: 10 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -313,4 +313,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead

assert(model.getCheckpointFiles.isEmpty)
}

test("string params should be case-insensitive") {
val lda = new LDA()
Seq("eM", "oNLinE").foreach { optimizer =>
lda.setOptimizer(optimizer)
assert(lda.getOptimizer === optimizer)
val model = lda.fit(dataset)
assert(model.getOptimizer === optimizer)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ trait TimeZoneAwareExpression extends Expression {
/** Returns a copy of this expression with the specified timeZoneId. */
def withTimeZone(timeZoneId: String): TimeZoneAwareExpression

@transient lazy val timeZone: TimeZone = TimeZone.getTimeZone(timeZoneId.get)
@transient lazy val timeZone: TimeZone = DateTimeUtils.getTimeZone(timeZoneId.get)
}

/**
Expand Down Expand Up @@ -416,7 +416,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
override def dataType: DataType = IntegerType

@transient private lazy val c = {
val c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
val c = Calendar.getInstance(DateTimeUtils.getTimeZone("UTC"))
c.setFirstDayOfWeek(Calendar.MONDAY)
c.setMinimalDaysInFirstWeek(4)
c
Expand All @@ -431,9 +431,10 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
nullSafeCodeGen(ctx, ev, time => {
val cal = classOf[Calendar].getName
val c = ctx.freshName("cal")
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
ctx.addMutableState(cal, c,
s"""
$c = $cal.getInstance(java.util.TimeZone.getTimeZone("UTC"));
$c = $cal.getInstance($dtu.getTimeZone("UTC"));
$c.setFirstDayOfWeek($cal.MONDAY);
$c.setMinimalDaysInFirstWeek(4);
""")
Expand Down Expand Up @@ -954,8 +955,9 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
val tzTerm = ctx.freshName("tz")
val utcTerm = ctx.freshName("utc")
val tzClass = classOf[TimeZone].getName
ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""")
ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $tzClass.getTimeZone("UTC");""")
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""")
ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
ev.copy(code = s"""
|${eval.code}
Expand Down Expand Up @@ -1125,8 +1127,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
val tzTerm = ctx.freshName("tz")
val utcTerm = ctx.freshName("utc")
val tzClass = classOf[TimeZone].getName
ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""")
ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $tzClass.getTimeZone("UTC");""")
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""")
ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
ev.copy(code = s"""
|${eval.code}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ private[sql] class JSONOptions(
val columnNameOfCorruptRecord =
parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)

val timeZone: TimeZone = TimeZone.getTimeZone(
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))

// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
RemoveRedundantProject,
SimplifyCreateStructOps,
SimplifyCreateArrayOps,
SimplifyCreateMapOps) ++
SimplifyCreateMapOps,
CombineConcats) ++
extendedOperatorOptimizationRules: _*) ::
Batch("Check Cartesian Products", Once,
CheckCartesianProducts(conf)) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import scala.collection.mutable.{ArrayBuffer, Stack}

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -543,3 +544,28 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
}
}
}

/**
* Combine nested [[Concat]] expressions.
*/
object CombineConcats extends Rule[LogicalPlan] {

private def flattenConcats(concat: Concat): Concat = {
val stack = Stack[Expression](concat)
val flattened = ArrayBuffer.empty[Expression]
while (stack.nonEmpty) {
stack.pop() match {
case Concat(children) =>
stack.pushAll(children.reverse)
case child =>
flattened += child
}
}
Concat(flattened)
}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown {
case concat: Concat if concat.children.exists(_.isInstanceOf[Concat]) =>
flattenConcats(concat)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.optimizer

import java.util.TimeZone

import scala.collection.mutable

import org.apache.spark.sql.catalyst.catalog.SessionCatalog
Expand Down Expand Up @@ -55,7 +53,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
case CurrentDate(Some(timeZoneId)) =>
currentDates.getOrElseUpdate(timeZoneId, {
Literal.create(
DateTimeUtils.millisToDays(timestamp / 1000L, TimeZone.getTimeZone(timeZoneId)),
DateTimeUtils.millisToDays(timestamp / 1000L, DateTimeUtils.getTimeZone(timeZoneId)),
DateType)
})
case CurrentTimestamp() => currentTime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.util
import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import java.util.{Calendar, Locale, TimeZone}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JFunction}
import javax.xml.bind.DatatypeConverter

import scala.annotation.tailrec
Expand Down Expand Up @@ -98,6 +100,15 @@ object DateTimeUtils {
sdf
}

private val computedTimeZones = new ConcurrentHashMap[String, TimeZone]
private val computeTimeZone = new JFunction[String, TimeZone] {
override def apply(timeZoneId: String): TimeZone = TimeZone.getTimeZone(timeZoneId)
}

def getTimeZone(timeZoneId: String): TimeZone = {
computedTimeZones.computeIfAbsent(timeZoneId, computeTimeZone)
}

def newDateFormat(formatString: String, timeZone: TimeZone): DateFormat = {
val sdf = new SimpleDateFormat(formatString, Locale.US)
sdf.setTimeZone(timeZone)
Expand Down Expand Up @@ -407,7 +418,7 @@ object DateTimeUtils {
Calendar.getInstance(timeZone)
} else {
Calendar.getInstance(
TimeZone.getTimeZone(f"GMT${tz.get.toChar}${segments(7)}%02d:${segments(8)}%02d"))
getTimeZone(f"GMT${tz.get.toChar}${segments(7)}%02d:${segments(8)}%02d"))
}
c.set(Calendar.MILLISECOND, 0)

Expand Down Expand Up @@ -1027,15 +1038,15 @@ object DateTimeUtils {
* representation in their timezone.
*/
def fromUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = {
convertTz(time, TimeZoneGMT, TimeZone.getTimeZone(timeZone))
convertTz(time, TimeZoneGMT, getTimeZone(timeZone))
}

/**
* Returns a utc timestamp from a given timestamp from a given timezone, with the same
* string representation in their timezone.
*/
def toUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = {
convertTz(time, TimeZone.getTimeZone(timeZone), TimeZoneGMT)
convertTz(time, getTimeZone(timeZone), TimeZoneGMT)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
val catalog = newBasicCatalog()
val tbl1 = catalog.getTable("db2", "tbl1")
val newSchema = StructType(Seq(
StructField("new_field_1", IntegerType),
StructField("col1", IntegerType),
StructField("new_field_2", StringType),
StructField("a", IntegerType),
StructField("b", StringType)))
Expand Down
Loading