diff --git a/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java b/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java
index 7bccbac7d52..7465756e14f 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java
@@ -24,6 +24,7 @@
 import java.io.IOException;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
+import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
@@ -55,6 +56,7 @@ public class ZeppelinContext {
   private SparkDependencyResolver dep;
   private InterpreterContext interpreterContext;
   private int maxResult;
+  private List<Class> supportedClasses;
 
   public ZeppelinContext(SparkContext sc, SQLContext sql,
       InterpreterContext interpreterContext,
@@ -65,6 +67,25 @@ public ZeppelinContext(SparkContext sc, SQLContext sql,
     this.interpreterContext = interpreterContext;
     this.dep = dep;
     this.maxResult = maxResult;
+    this.supportedClasses = new ArrayList<>();
+    try {
+      supportedClasses.add(this.getClass().forName("org.apache.spark.sql.Dataset"));
+    } catch (ClassNotFoundException e) {
+    }
+
+    try {
+      supportedClasses.add(this.getClass().forName("org.apache.spark.sql.DataFrame"));
+    } catch (ClassNotFoundException e) {
+    }
+
+    try {
+      supportedClasses.add(this.getClass().forName("org.apache.spark.sql.SchemaRDD"));
+    } catch (ClassNotFoundException e) {
+    }
+
+    if (supportedClasses.isEmpty()) {
+      throw new InterpreterException("Can not load Dataset/DataFrame/SchemaRDD class");
+    }
   }
 
   public SparkContext sc;
@@ -161,33 +182,8 @@ public void show(Object o) {
 
   @ZeppelinApi
   public void show(Object o, int maxResult) {
-    Class cls = null;
     try {
-      cls = this.getClass().forName("org.apache.spark.sql.Dataset");
-    } catch (ClassNotFoundException e) {
-    }
-
-    if (cls == null) {
-      try {
-        cls = this.getClass().forName("org.apache.spark.sql.DataFrame");
-      } catch (ClassNotFoundException e) {
-      }
-    }
-
-    if (cls == null) {
-      try {
-        cls = this.getClass().forName("org.apache.spark.sql.SchemaRDD");
-      } catch (ClassNotFoundException e) {
-      }
-    }
-
-    if (cls == null) {
-      throw new InterpreterException("Can not road Dataset/DataFrame/SchemaRDD class");
-    }
-
-
-    try {
-      if (cls.isInstance(o)) {
+      if (supportedClasses.contains(o.getClass())) {
         interpreterContext.out.write(showDF(sc, interpreterContext, o, maxResult));
       } else {
         interpreterContext.out.write(o.toString());
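Reviewer note: the constructor change above probes for whichever of Dataset/DataFrame/SchemaRDD the running Spark version provides, which is the core of this patch. A minimal self-contained sketch of the pattern may help; the three class names are Spark's real ones, while the `SparkClassProbe` class and `loadAvailable` helper are hypothetical, for illustration only.

```java
import java.util.ArrayList;
import java.util.List;

// Illustration only: probe candidate class names and keep the ones the
// current classpath actually provides, mirroring the constructor above.
public class SparkClassProbe {
  static List<Class<?>> loadAvailable(String... candidates) {
    List<Class<?>> found = new ArrayList<>();
    for (String name : candidates) {
      try {
        found.add(Class.forName(name));
      } catch (ClassNotFoundException e) {
        // Expected: the running Spark version simply lacks this class.
      }
    }
    return found;
  }

  public static void main(String[] args) {
    // SchemaRDD (Spark <= 1.2), DataFrame (1.3 - 1.6), Dataset (2.x);
    // at least one of these should resolve on any supported Spark.
    List<Class<?>> supported = loadAvailable(
        "org.apache.spark.sql.Dataset",
        "org.apache.spark.sql.DataFrame",
        "org.apache.spark.sql.SchemaRDD");
    if (supported.isEmpty()) {
      throw new IllegalStateException("No displayable Spark SQL class found");
    }
    System.out.println(supported);
  }
}
```

Doing the probe once in the constructor, instead of on every `show()` call as before, also lets the interpreter fail fast at startup when no displayable class is available.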
@@ -210,6 +206,12 @@ public static String showDF(SparkContext sc,
     sc.setJobGroup(jobGroup, "Zeppelin", false);
 
     try {
+      // Convert it to a DataFrame if it is a Dataset, as we will iterate over
+      // all the records and assume each of them is of type Row.
+      if (df.getClass().getCanonicalName().equals("org.apache.spark.sql.Dataset")) {
+        Method convertToDFMethod = df.getClass().getMethod("toDF");
+        df = convertToDFMethod.invoke(df);
+      }
       take = df.getClass().getMethod("take", int.class);
       rows = (Object[]) take.invoke(df, maxResult + 1);
     } catch (NoSuchMethodException | SecurityException | IllegalAccessException
diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/remote/RemoteInterpreterServer.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/remote/RemoteInterpreterServer.java
index 7ddb92838f4..8344366e569 100644
--- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/remote/RemoteInterpreterServer.java
+++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/remote/RemoteInterpreterServer.java
@@ -504,11 +504,13 @@ protected InterpreterOutput createInterpreterOutput(final String noteId, final S
     return new InterpreterOutput(new InterpreterOutputListener() {
       @Override
       public void onAppend(InterpreterOutput out, byte[] line) {
+        logger.debug("Output Append: " + new String(line));
         eventClient.onInterpreterOutputAppend(noteId, paragraphId, new String(line));
       }
 
       @Override
       public void onUpdate(InterpreterOutput out, byte[] output) {
+        logger.debug("Output Update: " + new String(output));
         eventClient.onInterpreterOutputUpdate(noteId, paragraphId, new String(output));
       }
     });
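Reviewer note: the reflective round-trip in the `showDF` hunk is easier to follow outside diff markers. Here is a sketch of the same idea, assuming `df` arrives as a plain `Object`; the `toDF` and `take` method names are Spark's real API, while the `TakeRowsSketch` class and `takeRows` helper are hypothetical.

```java
import java.lang.reflect.Method;

public class TakeRowsSketch {
  // Normalize an unknown Spark handle to a DataFrame and fetch rows through
  // reflection, so the code compiles without any Spark version on the classpath.
  static Object[] takeRows(Object df, int maxResult) throws Exception {
    // Dataset<T> may carry an arbitrary T; toDF() yields Dataset<Row>, which is
    // what the row-formatting code in showDF expects.
    if ("org.apache.spark.sql.Dataset".equals(df.getClass().getCanonicalName())) {
      Method toDF = df.getClass().getMethod("toDF");
      df = toDF.invoke(df);
    }
    Method take = df.getClass().getMethod("take", int.class);
    // maxResult + 1, matching the hunk above: fetching one row past the cap
    // presumably lets the caller detect that the result was truncated.
    return (Object[]) take.invoke(df, maxResult + 1);
  }
}
```

Keeping the whole lookup reflective is what allows a single Zeppelin build to run against Spark versions whose result classes differ.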
assertEquals("_1\t_2\nhello\t20\n", p.getResult().message()); + + // test display DataSet + if (sparkVersion >= 20) { + p = note.addParagraph(); + config = p.getConfig(); + config.put("enabled", true); + p.setConfig(config); + p.setText("%spark val ds=spark.createDataset(Seq((\"hello\",20)))\n" + + "z.show(ds)"); + note.run(p.getId()); + waitForFinish(p); + assertEquals(Status.FINISHED, p.getStatus()); + assertEquals(InterpreterResult.Type.TABLE, p.getResult().type()); + assertEquals("_1\t_2\nhello\t20\n", p.getResult().message()); + } + ZeppelinServer.notebook.removeNote(note.getId(), null); + } + } + @Test public void sparkRTest() throws IOException { // create new note @@ -152,6 +205,21 @@ public void pySparkTest() throws IOException { waitForFinish(p); assertEquals(Status.FINISHED, p.getStatus()); assertEquals("[Row(age=20, id=1)]\n", p.getResult().message()); + + // test display Dataframe + p = note.addParagraph(); + config = p.getConfig(); + config.put("enabled", true); + p.setConfig(config); + p.setText("%pyspark from pyspark.sql import Row\n" + + "df=sqlContext.createDataFrame([Row(id=1, age=20)])\n" + + "z.show(df)"); + note.run(p.getId()); + waitForFinish(p); + assertEquals(Status.FINISHED, p.getStatus()); + assertEquals(InterpreterResult.Type.TABLE, p.getResult().type()); + // TODO (zjffdu), one more \n is appended, need to investigate why. + assertEquals("age\tid\n20\t1\n\n", p.getResult().message()); } if (sparkVersion >= 20) { // run SparkSession test