diff --git a/spark/pom.xml b/spark/pom.xml index 0475515305e..71fb8a69550 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -367,6 +367,14 @@ 1.1 + + + org.scalatest + scalatest_${scala.binary.version} + 2.2.4 + test + + junit junit @@ -760,6 +768,34 @@ + + + org.scala-tools + maven-scala-plugin + + + compile + + compile + + compile + + + test-compile + + testCompile + + test-compile + + + process-resources + + compile + + + + + diff --git a/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java index 935b2a59c19..1c4c5e7c9cd 100644 --- a/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java +++ b/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java @@ -461,6 +461,14 @@ public void open() { intp.interpret("import org.apache.spark.sql.functions._"); } + // Utility functions for display + intp.interpret("import org.apache.zeppelin.spark.utils.DisplayUtils._"); + + // Scala implicit value for spark.maxResult + intp.interpret("import org.apache.zeppelin.spark.utils.SparkMaxResult"); + intp.interpret("implicit val sparkMaxResult = new SparkMaxResult(" + + Integer.parseInt(getProperty("zeppelin.spark.maxResult")) + ")"); + // add jar if (depInterpreter != null) { DependencyContext depc = depInterpreter.getDependencyContext(); diff --git a/spark/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala b/spark/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala new file mode 100644 index 00000000000..81814349c18 --- /dev/null +++ b/spark/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark.utils
+
+import java.lang.StringBuilder
+
+import org.apache.spark.rdd.RDD
+
+import scala.collection.IterableLike
+
+/**
+ * Display helpers: implicit `display` enrichments for RDDs and local collections of
+ * [[scala.Product]] values (tuples, case classes), plus small builders for strings
+ * carrying the `%html` / `%img` prefixes.
+ *
+ * The Spark interpreter imports `DisplayUtils._` when it opens, so these members are
+ * available unqualified in interpreted code.
+ */
+object DisplayUtils {
+
+  /** Enriches an `RDD` of products with the `display` methods of [[DisplayRDDFunctions]]. */
+  implicit def toDisplayRDDFunctions[T <: Product](rdd: RDD[T]): DisplayRDDFunctions[T] =
+    new DisplayRDDFunctions[T](rdd)
+
+  /** Enriches a local collection of products with [[DisplayTraversableFunctions]]. */
+  implicit def toDisplayTraversableFunctions[T <: Product](
+      traversable: Traversable[T]): DisplayTraversableFunctions[T] =
+    new DisplayTraversableFunctions[T](traversable)
+
+  /**
+   * Prefixes the given markup with `%html `.
+   *
+   * @param htmlContent markup to emit; defaults to the empty string
+   * @return `"%html "` followed by `htmlContent`
+   */
+  def html(htmlContent: String = ""): String = s"%html $htmlContent"
+
+  /**
+   * Prefixes the given Base64 payload with `%img `.
+   *
+   * @param base64Content Base64 image payload; defaults to the empty string
+   * @return `"%img "` followed by `base64Content`
+   */
+  def img64(base64Content: String = ""): String = s"%img $base64Content"
+
+  /**
+   * Builds an HTML image tag pointing at `url`.
+   *
+   * NOTE(review): the literal had lost its markup (it read `s""` and ignored `url`);
+   * the tag below is a reconstruction - confirm against the upstream file.
+   * The URL is inserted verbatim, so a value containing a single quote breaks the attribute.
+   *
+   * @param url address of the image
+   * @return an `img` element whose `src` attribute is `url`
+   */
+  def img(url: String): String = s"<img src='$url' />"
+}
+
+/**
+ * Shared rendering logic: prints a collection of products as a `%table` block, one
+ * tab-separated line per element preceded by a tab-separated header line.
+ */
+trait DisplayCollection[T <: Product] {
+
+  /**
+   * Prints `traversable` to standard output as `%table <headers>\n<row>\n<row>\n...`.
+   *
+   * The column count is the largest `productArity` among the rows (at least 1).
+   * Surplus labels are dropped; missing labels are filled in as `Column<n>`, where `n`
+   * is the 1-based column position. When there are no rows there is nothing to size
+   * the table against, so every provided label is kept (or `Column1` if none is given).
+   *
+   * Cell values are written with `toString` and are not escaped: a value containing a
+   * tab or a newline will shift the cells that follow it.
+   *
+   * @param traversable  rows to print; each element contributes one line
+   * @param columnLabels header labels, in column order
+   */
+  def printFormattedData(traversable: Traversable[T], columnLabels: String*): Unit = {
+    // Materialise once so the input is traversed a single time.
+    val rows = traversable.toVector
+    val providedLabelCount = columnLabels.size
+
+    val columnCount =
+      if (rows.isEmpty) math.max(providedLabelCount, 1)
+      else math.max(rows.map(_.productArity).max, 1)
+
+    // The range is empty whenever enough labels were provided.
+    val generatedLabels = ((providedLabelCount + 1) to columnCount).map(index => s"Column$index")
+    val header = (columnLabels.take(columnCount) ++ generatedLabels).mkString("\t")
+
+    val body = rows.map(row => row.productIterator.mkString("\t") + "\n").mkString
+
+    print(s"%table $header\n$body")
+  }
+
+}
+
+/**
+ * `display` methods for an `RDD` of products. Only the first elements of the RDD, up to
+ * the requested limit, are taken and printed.
+ *
+ * @param rdd the RDD to display
+ */
+class DisplayRDDFunctions[T <: Product] (val rdd: RDD[T]) extends DisplayCollection[T] {
+
+  /**
+   * Prints at most `sparkMaxResult.maxResult` elements of the RDD as a table.
+   *
+   * @param columnLabels   header labels, in column order
+   * @param sparkMaxResult row limit, resolved implicitly
+   */
+  def display(columnLabels: String*)(implicit sparkMaxResult: SparkMaxResult): Unit = {
+    printFormattedData(rdd.take(sparkMaxResult.maxResult), columnLabels: _*)
+  }
+
+  /**
+   * Prints at most `sparkMaxResult` elements of the RDD as a table.
+   *
+   * @param sparkMaxResult explicit row limit
+   * @param columnLabels   header labels, in column order
+   */
+  def display(sparkMaxResult:Int, columnLabels: String*): Unit = {
+    printFormattedData(rdd.take(sparkMaxResult), columnLabels: _*)
+  }
+}
+
+/**
+ * `display` method for a local collection of products; every element is printed.
+ *
+ * @param traversable the collection to display
+ */
+class DisplayTraversableFunctions[T <: Product] (val traversable: Traversable[T]) extends DisplayCollection[T] {
+
+  /**
+   * Prints the whole collection as a table.
+   *
+   * @param columnLabels header labels, in column order
+   */
+  def display(columnLabels: String*): Unit = {
+    printFormattedData(traversable, columnLabels: _*)
+  }
+}
+
+/**
+ * Row limit used by the implicit-limit `display` of [[DisplayRDDFunctions]].
+ *
+ * @param maxResult maximum number of RDD elements to take
+ */
+class SparkMaxResult(val maxResult: Int) extends Serializable
diff --git a/spark/src/test/scala/org/apache/zeppelin/spark/utils/DisplayFunctionsTest.scala b/spark/src/test/scala/org/apache/zeppelin/spark/utils/DisplayFunctionsTest.scala
new file mode 100644
index 00000000000..2638f1710e9
--- /dev/null
+++ b/spark/src/test/scala/org/apache/zeppelin/spark/utils/DisplayFunctionsTest.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.spark.utils
+
+import java.io.ByteArrayOutputStream
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.{SparkContext, SparkConf}
+import org.scalatest._
+import org.scalatest.{BeforeAndAfter}
+
+/** Case-class fixture with the same three columns as the tuple fixtures. */
+case class Person(login : String, name: String, age: Int)
+
+/**
+ * Checks the text printed by the `display` helpers and the strings built by `DisplayUtils`.
+ *
+ * A local `SparkContext` and the fixtures are created before each test and the context
+ * is stopped afterwards; each test prints into a fresh `ByteArrayOutputStream`.
+ */
+class DisplayFunctionsTest extends FlatSpec with BeforeAndAfter with BeforeAndAfterEach with Matchers {
+  var sc: SparkContext = null
+  var testTuples:List[(String, String, Int)] = null
+  var testPersons:List[Person] = null
+  var testRDDTuples: RDD[(String,String,Int)] = null
+  var testRDDPersons: RDD[Person] = null
+  var stream: ByteArrayOutputStream = null
+
+  before {
+    val sparkConf: SparkConf = new SparkConf(true)
+      .setAppName("test-DisplayFunctions")
+      .setMaster("local")
+    sc = new SparkContext(sparkConf)
+    testTuples = List(("jdoe", "John DOE", 32), ("hsue", "Helen SUE", 27), ("rsmith", "Richard SMITH", 45))
+    testRDDTuples = sc.parallelize(testTuples)
+    testPersons = List(Person("jdoe", "John DOE", 32), Person("hsue", "Helen SUE", 27), Person("rsmith", "Richard SMITH", 45))
+    testRDDPersons = sc.parallelize(testPersons)
+  }
+
+  override def beforeEach(): Unit = {
+    stream = new java.io.ByteArrayOutputStream()
+    super.beforeEach() // To be stackable, must call super.beforeEach
+  }
+
+
+  "DisplayFunctions" should "generate correct column headers for tuples" in {
+    implicit val sparkMaxResult = new SparkMaxResult(100)
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[(String,String,Int)](testRDDTuples).display("Login","Name","Age")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayFunctions" should "generate correct column headers for case class" in {
+    implicit val sparkMaxResult = new SparkMaxResult(100)
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[Person](testRDDPersons).display("Login","Name","Age")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayFunctions" should "truncate exceeding column headers for tuples" in {
+    implicit val sparkMaxResult = new SparkMaxResult(100)
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[(String,String,Int)](testRDDTuples).display("Login","Name","Age","xxx","yyy")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayFunctions" should "pad missing column headers with ColumnXXX for tuples" in {
+    implicit val sparkMaxResult = new SparkMaxResult(100)
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[(String,String,Int)](testRDDTuples).display("Login")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayUtils" should "restricts RDD to sparkMaxresult with implicit limit" in {
+
+    implicit val sparkMaxResult = new SparkMaxResult(2)
+
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[(String,String,Int)](testRDDTuples).display("Login")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n")
+  }
+
+  "DisplayUtils" should "restricts RDD to sparkMaxresult with explicit limit" in {
+
+    // The implicit limit is in scope on purpose: the explicit Int argument must win.
+    implicit val sparkMaxResult = new SparkMaxResult(2)
+
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[(String,String,Int)](testRDDTuples).display(1,"Login")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" +
+      "jdoe\tJohn DOE\t32\n")
+  }
+
+  "DisplayFunctions" should "display traversable of tuples" in {
+
+    Console.withOut(stream) {
+      new DisplayTraversableFunctions[(String,String,Int)](testTuples).display("Login","Name","Age")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayFunctions" should "display traversable of case class" in {
+
+    Console.withOut(stream) {
+      new DisplayTraversableFunctions[Person](testPersons).display("Login","Name","Age")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayFunctions" should "keep all provided column headers for an empty traversable" in {
+
+    Console.withOut(stream) {
+      new DisplayTraversableFunctions[(String,String,Int)](List.empty[(String,String,Int)])
+        .display("Login","Name","Age")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n")
+  }
+
+  "DisplayUtils" should "display HTML" in {
+    DisplayUtils.html() should be ("%html ")
+    DisplayUtils.html("test") should be ("%html test")
+  }
+
+  "DisplayUtils" should "display img" in {
+    DisplayUtils.img("http://www.google.com") should be ("<img src='http://www.google.com' />")
+    DisplayUtils.img64() should be ("%img ")
+    DisplayUtils.img64("abcde") should be ("%img abcde")
+  }
+
+  override def afterEach(): Unit = {
+    // Drop the buffer even if the stacked afterEach throws.
+    try super.afterEach() // To be stackable, must call super.afterEach
+    finally stream = null
+  }
+
+  after {
+    sc.stop()
+  }
+
+
+}
+
+
 diff --git a/zeppelin-interpreter/pom.xml b/zeppelin-interpreter/pom.xml index 980fe4ac568..c2a67524a95 100644 --- a/zeppelin-interpreter/pom.xml +++ b/zeppelin-interpreter/pom.xml @@ -36,6 +36,7 @@ http://zeppelin.incubator.apache.org + org.apache.thrift libthrift