
Commit 2e3453d

Hive module.
1 parent c589fda commit 2e3453d


13 files changed: 81 additions, 43 deletions

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 20 additions & 2 deletions

@@ -22,9 +22,11 @@ import scala.reflect.ClassTag
 
 import com.fasterxml.jackson.core.JsonFactory
 
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
 import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}

@@ -170,9 +172,13 @@ class DataFrame(
 
   override def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan)
 
-  def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan)
+  override def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan)
 
-  def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan)
+  override def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan)
+
+  override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
+    Sample(fraction, withReplacement, seed, logicalPlan)
+  }
 
   /////////////////////////////////////////////////////////////////////////////
 
@@ -238,6 +244,18 @@ class DataFrame(
     sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
   }
 
+  @Experimental
+  override def saveAsTable(tableName: String): Unit = {
+    sqlContext.executePlan(
+      CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd
+  }
+
+  @Experimental
+  override def insertInto(tableName: String, overwrite: Boolean): Unit = {
+    sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
+      Map.empty, logicalPlan, overwrite)).toRdd
+  }
+
   override def toJSON: RDD[String] = {
     val rowSchema = this.schema
     this.mapPartitions { iter =>

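For orientation, a hedged usage sketch of the methods this file gains; it assumes a Hive-enabled SQLContext named hiveCtx and an existing table "src", both illustrative rather than part of the patch:

  // Sketch only: hiveCtx and the "src" table are assumed, not defined by this commit.
  import org.apache.spark.sql.DataFrame

  val df: DataFrame = hiveCtx.sql("SELECT key, value FROM src")

  // sample/intersect/except now override the declarations in DataFrameSpecificApi,
  // while saveAsTable and insertInto plan a Hive CTAS and an INSERT respectively.
  val sampled = df.sample(withReplacement = false, fraction = 0.1, seed = 42L)
  sampled.saveAsTable("src_sample")               // CreateTableAsSelect under the hood
  df.insertInto("src_sample", overwrite = false)  // InsertIntoTable on an existing table
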
sql/core/src/main/scala/org/apache/spark/sql/api.scala

Lines changed: 18 additions & 1 deletion

@@ -19,9 +19,11 @@ package org.apache.spark.sql
 
 import scala.reflect.ClassTag
 
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
 
 
 trait RDDApi[T] {

@@ -129,6 +131,12 @@ trait DataFrameSpecificApi {
 
   def except(other: DataFrame): DataFrame
 
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame
+
+  def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
+    sample(withReplacement, fraction, Utils.random.nextLong)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Column mutation
   /////////////////////////////////////////////////////////////////////////////

@@ -144,11 +152,20 @@ trait DataFrameSpecificApi {
 
   def rdd: RDD[Row]
 
+  def toJSON: RDD[String]
+
   def registerTempTable(tableName: String): Unit
 
   def saveAsParquetFile(path: String): Unit
 
-  def toJSON: RDD[String]
+  @Experimental
+  def saveAsTable(tableName: String): Unit
+
+  @Experimental
+  def insertInto(tableName: String, overwrite: Boolean): Unit
+
+  @Experimental
+  def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false)
 
   /////////////////////////////////////////////////////////////////////////////
   // Stat functions

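The seedless sample overload above simply delegates to the seeded one with a random seed. The same trait-level default pattern in a self-contained form (type and method names here are illustrative, not the Spark API):

  import scala.util.Random

  trait SampleApi[T] {
    // Abstract seeded variant: implementors decide how sampling actually happens.
    def sample(withReplacement: Boolean, fraction: Double, seed: Long): T

    // Concrete overload: callers who don't need reproducibility get a random seed.
    def sample(withReplacement: Boolean, fraction: Double): T =
      sample(withReplacement, fraction, Random.nextLong())
  }
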
sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala

Lines changed: 1 addition & 1 deletion

@@ -23,6 +23,7 @@ import java.io._
 import java.util.{ArrayList => JArrayList}
 
 import jline.{ConsoleReader, History}
+
 import org.apache.commons.lang.StringUtils
 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.conf.Configuration

@@ -39,7 +40,6 @@ import org.apache.thrift.transport.TSocket
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.hive.HiveShim
-import org.apache.spark.sql.hive.thriftserver.HiveThriftServerShim
 
 private[hive] object SparkSQLCLIDriver {
   private var prompt = "spark-sql"

sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala

Lines changed: 3 additions & 3 deletions

@@ -32,7 +32,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation
 import org.apache.hive.service.cli.session.HiveSession
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow}
+import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow}
 import org.apache.spark.sql.execution.SetCommand
 import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
 import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}

@@ -71,7 +71,7 @@ private[hive] class SparkExecuteStatementOperation(
     sessionToActivePool: SMap[SessionHandle, String])
   extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging {
 
-  private var result: SchemaRDD = _
+  private var result: DataFrame = _
   private var iter: Iterator[SparkRow] = _
   private var dataTypes: Array[DataType] = _
 
@@ -202,7 +202,7 @@ private[hive] class SparkExecuteStatementOperation(
     val useIncrementalCollect =
       hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
     if (useIncrementalCollect) {
-      result.toLocalIterator
+      result.rdd.toLocalIterator
     } else {
       result.collect().iterator
     }

sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala

Lines changed: 3 additions & 3 deletions

@@ -30,7 +30,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation
 import org.apache.hive.service.cli.session.HiveSession
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf}
 import org.apache.spark.sql.execution.SetCommand
 import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
 import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes}

@@ -72,7 +72,7 @@ private[hive] class SparkExecuteStatementOperation(
   // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution
   extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging {
 
-  private var result: SchemaRDD = _
+  private var result: DataFrame = _
   private var iter: Iterator[SparkRow] = _
   private var dataTypes: Array[DataType] = _
 
@@ -173,7 +173,7 @@ private[hive] class SparkExecuteStatementOperation(
     val useIncrementalCollect =
       hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
     if (useIncrementalCollect) {
-      result.toLocalIterator
+      result.rdd.toLocalIterator
     } else {
       result.collect().iterator
     }

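Both shims make the same one-line adjustment: the DataFrame introduced here is no longer itself an RDD, so the local iterator now comes from its underlying RDD. A minimal sketch of the toggle, assuming hiveContext and a populated result: DataFrame are in scope as in the shims above:

  // Sketch only: hiveContext and result are assumed to exist, as in the shims.
  val useIncrementalCollect =
    hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean

  val rows =
    if (useIncrementalCollect) {
      result.rdd.toLocalIterator   // stream one partition at a time through the driver
    } else {
      result.collect().iterator    // materialize the full result set in driver memory
    }
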
sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 5 additions & 4 deletions

@@ -64,15 +64,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
     getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true"
 
   override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
-    new this.QueryExecution { val logical = plan }
+    new this.QueryExecution(plan)
 
-  override def sql(sqlText: String): SchemaRDD = {
+  override def sql(sqlText: String): DataFrame = {
     val substituted = new VariableSubstitution().substitute(hiveconf, sqlText)
     // TODO: Create a framework for registering parsers instead of just hardcoding if statements.
     if (conf.dialect == "sql") {
       super.sql(substituted)
     } else if (conf.dialect == "hiveql") {
-      new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted)))
+      new DataFrame(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted)))
     } else {
       sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'")
     }

@@ -352,7 +352,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   override protected[sql] val planner = hivePlanner
 
   /** Extends QueryExecution with hive specific features. */
-  protected[sql] abstract class QueryExecution extends super.QueryExecution {
+  protected[sql] class QueryExecution(logicalPlan: LogicalPlan)
+    extends super.QueryExecution(logicalPlan) {
 
     /**
      * Returns the result as a hive compatible sequence of strings. For native commands, the

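The QueryExecution change replaces an abstract val that call sites filled in via an anonymous subclass with an ordinary constructor parameter. A stripped-down illustration of the before/after shapes, using a String stand-in for LogicalPlan (these toy classes are not the Spark ones):

  // Before: the plan is an abstract member, so every call site builds an anonymous subclass.
  abstract class OldQueryExecution { val logical: String }
  def oldExecutePlan(plan: String) = new OldQueryExecution { val logical = plan }

  // After: the plan is a constructor argument; subclasses such as the Hive and TestHive
  // variants just forward it with `extends super.QueryExecution(logicalPlan)`.
  class NewQueryExecution(val logical: String)
  def newExecutePlan(plan: String) = new NewQueryExecution(plan)
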
sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala

Lines changed: 8 additions & 9 deletions

@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive
 import scala.collection.JavaConversions._
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy}
+import org.apache.spark.sql.{Column, DataFrame, SQLContext, Strategy}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate

@@ -55,16 +55,15 @@ private[hive] trait HiveStrategies {
    */
   @Experimental
   object ParquetConversion extends Strategy {
-    implicit class LogicalPlanHacks(s: SchemaRDD) {
-      def lowerCase =
-        new SchemaRDD(s.sqlContext, s.logicalPlan)
+    implicit class LogicalPlanHacks(s: DataFrame) {
+      def lowerCase = new DataFrame(s.sqlContext, s.logicalPlan)
 
       def addPartitioningAttributes(attrs: Seq[Attribute]) = {
         // Don't add the partitioning key if its already present in the data.
         if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) {
           s
         } else {
-          new SchemaRDD(
+          new DataFrame(
             s.sqlContext,
             s.logicalPlan transform {
               case p: ParquetRelation => p.copy(partitioningAttributes = attrs)

@@ -97,13 +96,13 @@ private[hive] trait HiveStrategies {
       // We are going to throw the predicates and projection back at the whole optimization
       // sequence so lets unresolve all the attributes, allowing them to be rebound to the
       // matching parquet attributes.
-      val unresolvedOtherPredicates = otherPredicates.map(_ transform {
+      val unresolvedOtherPredicates = new Column(otherPredicates.map(_ transform {
         case a: AttributeReference => UnresolvedAttribute(a.name)
-      }).reduceOption(And).getOrElse(Literal(true))
+      }).reduceOption(And).getOrElse(Literal(true)))
 
-      val unresolvedProjection = projectList.map(_ transform {
+      val unresolvedProjection: Seq[Column] = projectList.map(_ transform {
         case a: AttributeReference => UnresolvedAttribute(a.name)
-      })
+      }).map(new Column(_))
 
       try {
         if (relation.hiveQlTable.isPartitioned) {

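One idiom worth noting in the hunk above: reduceOption(And).getOrElse(Literal(true)) folds zero or more predicate expressions into one, treating an empty list as "accept everything". The same shape in plain Scala, with Booleans standing in for Catalyst expressions (purely illustrative):

  // Stand-in for combining Catalyst predicates with And, defaulting to Literal(true).
  def combine(predicates: Seq[Boolean]): Boolean =
    predicates.reduceOption(_ && _).getOrElse(true)

  combine(Seq(true, false))  // false
  combine(Seq.empty)         // true: no filters means nothing is excluded
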
sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala

Lines changed: 5 additions & 4 deletions

@@ -99,7 +99,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
   override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql))
 
   override def executePlan(plan: LogicalPlan): this.QueryExecution =
-    new this.QueryExecution { val logical = plan }
+    new this.QueryExecution(plan)
 
   /** Fewer partitions to speed up testing. */
   protected[sql] override lazy val conf: SQLConf = new SQLConf {

@@ -150,16 +150,17 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
 
   val describedTable = "DESCRIBE (\\w+)".r
 
-  protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution {
-    lazy val logical = HiveQl.parseSql(hql)
+  protected[hive] class HiveQLQueryExecution(hql: String)
+    extends this.QueryExecution(HiveQl.parseSql(hql)) {
     def hiveExec() = runSqlHive(hql)
     override def toString = hql + "\n" + super.toString
   }
 
   /**
    * Override QueryExecution with special debug workflow.
    */
-  abstract class QueryExecution extends super.QueryExecution {
+  class QueryExecution(logicalPlan: LogicalPlan)
+    extends super.QueryExecution(logicalPlan) {
     override lazy val analyzed = {
       val describedTables = logical match {
         case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil

sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala

Lines changed: 5 additions & 5 deletions

@@ -36,12 +36,12 @@ class QueryTest extends PlanTest {
   /**
    * Runs the plan and makes sure the answer contains all of the keywords, or the
    * none of keywords are listed in the answer
-   * @param rdd the [[SchemaRDD]] to be executed
+   * @param rdd the [[DataFrame]] to be executed
    * @param exists true for make sure the keywords are listed in the output, otherwise
    *               to make sure none of the keyword are not listed in the output
    * @param keywords keyword in string array
    */
-  def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) {
+  def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) {
     val outputs = rdd.collect().map(_.mkString).mkString
     for (key <- keywords) {
       if (exists) {

@@ -54,10 +54,10 @@ class QueryTest extends PlanTest {
 
   /**
    * Runs the plan and makes sure the answer matches the expected result.
-   * @param rdd the [[SchemaRDD]] to be executed
+   * @param rdd the [[DataFrame]] to be executed
    * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
    */
-  protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
+  protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
     val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
     def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
       // Converts data to types that we can do equality comparison using Scala collections.

@@ -101,7 +101,7 @@ class QueryTest extends PlanTest {
     }
   }
 
-  protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+  protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
     checkAnswer(rdd, Seq(expectedAnswer))
   }

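With the helpers above now typed against DataFrame, a test body in this suite would look roughly like the following; the table, query, and expected values are made up for illustration:

  // Hypothetical example inside a class extending QueryTest, using the TestHive "src" table.
  import org.apache.spark.sql.Row
  import org.apache.spark.sql.hive.test.TestHive._

  checkAnswer(
    sql("SELECT key FROM src WHERE key = 0 LIMIT 1"),
    Row(0))

  checkExistence(sql("DESCRIBE src"), true, "key", "value")
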
sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -20,15 +20,15 @@ package org.apache.spark.sql.hive
 import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.{QueryTest, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, QueryTest}
 import org.apache.spark.storage.RDDBlockId
 
 class CachedTableSuite extends QueryTest {
   /**
    * Throws a test failed exception when the number of cached tables differs from the expected
    * number.
    */
-  def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
+  def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
     val planWithCaching = query.queryExecution.withCachedData
     val cachedData = planWithCaching collect {
       case cached: InMemoryRelation => cached
