Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
SimplifyCasts,
SimplifyCaseConversionExpressions,
RewriteCorrelatedScalarSubquery,
EliminateSerialization) ::
EliminateSerialization,
RemoveAliasOnlyProject) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Batch("Typed Filter Optimization", fixedPoint,
Expand Down Expand Up @@ -155,6 +156,59 @@ object SamplePushDown extends Rule[LogicalPlan] {
}
}

/**
* Removes the Project only conducting Alias of its child node.
* It is created mainly for removing extra Project added in EliminateSerialization rule,
* but can also benefit other operators.
*/
object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
// Check if projectList in the Project node has the same attribute names and ordering
// as its child node.
private def checkAliasOnly(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: isAliasOnly

projectList: Seq[NamedExpression],
childOutput: Seq[Attribute]): Boolean = {
if (!projectList.forall(_.isInstanceOf[Alias]) || projectList.length != childOutput.length) {
return false
} else {
projectList.map(_.asInstanceOf[Alias]).zip(childOutput).forall { case (a, o) =>
a.child match {
case attr: Attribute
Copy link
Contributor

@cloud-fan cloud-fan May 11, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it just a semanticEquals attr?

if a.name == attr.name && attr.name == o.name && attr.dataType == o.dataType
&& attr.exprId == o.exprId =>
true
case _ => false
}
}
}
}

def apply(plan: LogicalPlan): LogicalPlan = {
val processedPlan = plan.find { p =>
p match {
case Project(pList, child) if checkAliasOnly(pList, child.output) => true
case _ => false
}
}.map { case p: Project =>
val attrMap = p.projectList.map { a =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use AttributeMap here

val alias = a.asInstanceOf[Alias]
val replaceFrom = alias.toAttribute.exprId
val replaceTo = alias.child.asInstanceOf[Attribute]
(replaceFrom, replaceTo)
}.toMap
plan.transformAllExpressions {
case a: Attribute if attrMap.contains(a.exprId) => attrMap(a.exprId)
}.transform {
case op: Project if op == p => op.child
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use eq to compare reference, which is safer.

}
}
if (processedPlan.isDefined) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code style suggestion:

val aliasOnlyProject = ...
if (aliasOnlyProject.isDefined) {
    val p = aliasOnlyProject.get.asInstanceOf[Project]
    ...
} else {
  plan
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Why not processedPlan.getOrElse(plan)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

processedPlan.get
} else {
plan
}
}
}

/**
* Removes cases where we are unnecessarily going between the object and serialized (InternalRow)
* representation of data item. For example back to back map operations.
Expand All @@ -163,15 +217,10 @@ object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case d @ DeserializeToObject(_, _, s: SerializeFromObject)
if d.outputObjectType == s.inputObjectType =>
// A workaround for SPARK-14803. Remove this after it is fixed.
if (d.outputObjectType.isInstanceOf[ObjectType] &&
d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
s.child
} else {
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
Project(objAttr :: Nil, s.child)
}
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
// We will remove it later in RemoveAliasOnlyProject rule.
val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId) to make sure the alias name is same with the attribute name

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. update later.

Project(objAttr :: Nil, s.child)
case a @ AppendColumns(_, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjectType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
Expand Down
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4)
}

test("dataset.rdd with generic case class") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this test for?

Copy link
Member Author

@viirya viirya May 12, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The failed codes in jira SPARK-15094.

val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS
val ds2 = ds.map(g => Generic(g.id, g.value))
ds.rdd.map(r => r.id).count
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's better to use checkDataset to check the answer

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it failed on operations of rdd. I will check counting results instead.

ds2.rdd.map(r => r.id).count

val ds3 = ds.map(g => new java.lang.Long(g.id))
ds3.rdd.map(r => r).count
}

test("runtime null check for RowEncoder") {
val schema = new StructType().add("i", IntegerType, nullable = false)
val df = sqlContext.range(10).map(l => {
Expand All @@ -676,6 +686,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
}
}

case class Generic[T](id: T, value: Double)

case class OtherTuple(_1: String, _2: Int)

case class TupleClass(data: (Int, String))
Expand Down