54 changes: 28 additions & 26 deletions spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java
Expand Up @@ -24,6 +24,7 @@
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
Expand Down Expand Up @@ -55,6 +56,7 @@ public class ZeppelinContext {
private SparkDependencyResolver dep;
private InterpreterContext interpreterContext;
private int maxResult;
private List<Class> supportedClasses;

public ZeppelinContext(SparkContext sc, SQLContext sql,
InterpreterContext interpreterContext,
Expand All @@ -65,6 +67,25 @@ public ZeppelinContext(SparkContext sc, SQLContext sql,
this.interpreterContext = interpreterContext;
this.dep = dep;
this.maxResult = maxResult;
this.supportedClasses = new ArrayList<>();
try {
supportedClasses.add(this.getClass().forName("org.apache.spark.sql.Dataset"));
} catch (ClassNotFoundException e) {
}

try {
supportedClasses.add(this.getClass().forName("org.apache.spark.sql.DataFrame"));
} catch (ClassNotFoundException e) {
}

try {
supportedClasses.add(this.getClass().forName("org.apache.spark.sql.SchemaRDD"));
} catch (ClassNotFoundException e) {
}

if (supportedClasses.isEmpty()) {
throw new InterpreterException("Can not road Dataset/DataFrame/SchemaRDD class");
}
}

public SparkContext sc;
Expand Down Expand Up @@ -161,33 +182,8 @@ public void show(Object o) {

@ZeppelinApi
public void show(Object o, int maxResult) {
Class cls = null;
try {
cls = this.getClass().forName("org.apache.spark.sql.Dataset");
} catch (ClassNotFoundException e) {
}

if (cls == null) {
try {
cls = this.getClass().forName("org.apache.spark.sql.DataFrame");
} catch (ClassNotFoundException e) {
}
}

if (cls == null) {
try {
cls = this.getClass().forName("org.apache.spark.sql.SchemaRDD");
} catch (ClassNotFoundException e) {
}
}

if (cls == null) {
throw new InterpreterException("Can not road Dataset/DataFrame/SchemaRDD class");
}


try {
if (cls.isInstance(o)) {
if (supportedClasses.contains(o.getClass())) {
interpreterContext.out.write(showDF(sc, interpreterContext, o, maxResult));
} else {
interpreterContext.out.write(o.toString());
Expand All @@ -210,6 +206,12 @@ public static String showDF(SparkContext sc,
sc.setJobGroup(jobGroup, "Zeppelin", false);

try {
// Convert a Dataset to a DataFrame first, since all the records are iterated below
// and each one is assumed to be of type Row.
if (df.getClass().getCanonicalName().equals("org.apache.spark.sql.Dataset")) {
Method convertToDFMethod = df.getClass().getMethod("toDF");
df = convertToDFMethod.invoke(df);
}
take = df.getClass().getMethod("take", int.class);
rows = (Object[]) take.invoke(df, maxResult + 1);
} catch (NoSuchMethodException | SecurityException | IllegalAccessException
Expand Down
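For context on the showDF change above, here is a minimal, self-contained sketch of what the reflective calls resolve to at runtime. The class and method are looked up by name so the interpreter compiles without a fixed Spark dependency; the helper name takeRows is illustrative only and is not part of the patch.

import java.lang.reflect.Method;

public class ShowDatasetSketch {
  // Mirrors the reflective logic above: if the object is a Spark 2.x Dataset,
  // convert it to a DataFrame via toDF() so the rows returned by take() can be
  // treated as org.apache.spark.sql.Row, then fetch at most maxResult + 1 rows
  // (the extra row lets the caller detect that the output was truncated).
  static Object[] takeRows(Object df, int maxResult) throws Exception {
    if (df.getClass().getCanonicalName().equals("org.apache.spark.sql.Dataset")) {
      Method toDF = df.getClass().getMethod("toDF");
      df = toDF.invoke(df);
    }
    Method take = df.getClass().getMethod("take", int.class);
    return (Object[]) take.invoke(df, maxResult + 1);
  }
}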
Expand Up @@ -504,11 +504,13 @@ protected InterpreterOutput createInterpreterOutput(final String noteId, final S
return new InterpreterOutput(new InterpreterOutputListener() {
@Override
public void onAppend(InterpreterOutput out, byte[] line) {
logger.debug("Output Append:" + new String(line));
eventClient.onInterpreterOutputAppend(noteId, paragraphId, new String(line));
}

@Override
public void onUpdate(InterpreterOutput out, byte[] output) {
logger.debug("Output Update:" + new String(output));
eventClient.onInterpreterOutputUpdate(noteId, paragraphId, new String(output));
}
});
Expand Down
Expand Up @@ -17,13 +17,15 @@
package org.apache.zeppelin.rest;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;

import org.apache.commons.io.FileUtils;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterSetting;
import org.apache.zeppelin.notebook.Note;
import org.apache.zeppelin.notebook.Paragraph;
Expand Down Expand Up @@ -82,6 +84,57 @@ public void basicRDDTransformationAndActionTest() throws IOException {
ZeppelinServer.notebook.removeNote(note.getId(), null);
}

@Test
public void sparkSQLTest() throws IOException {
// create new note
Note note = ZeppelinServer.notebook.createNote(null);
int sparkVersion = getSparkVersionNumber(note);
// DataFrame API is available from Spark 1.3
if (sparkVersion >= 13) {
// test basic dataframe api
Paragraph p = note.addParagraph();
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark val df=sqlContext.createDataFrame(Seq((\"hello\",20)))\n" +
"df.collect()");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertTrue(p.getResult().message().contains(
"Array[org.apache.spark.sql.Row] = Array([hello,20])"));

// test display DataFrame
p = note.addParagraph();
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark val df=sqlContext.createDataFrame(Seq((\"hello\",20)))\n" +
"z.show(df)");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getResult().type());
assertEquals("_1\t_2\nhello\t20\n", p.getResult().message());

// test display Dataset
if (sparkVersion >= 20) {
p = note.addParagraph();
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%spark val ds=spark.createDataset(Seq((\"hello\",20)))\n" +
"z.show(ds)");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getResult().type());
assertEquals("_1\t_2\nhello\t20\n", p.getResult().message());
}
ZeppelinServer.notebook.removeNote(note.getId(), null);
}
}

@Test
public void sparkRTest() throws IOException {
// create new note
Expand Down Expand Up @@ -152,6 +205,21 @@ public void pySparkTest() throws IOException {
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals("[Row(age=20, id=1)]\n", p.getResult().message());

// test display DataFrame
p = note.addParagraph();
config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%pyspark from pyspark.sql import Row\n" +
"df=sqlContext.createDataFrame([Row(id=1, age=20)])\n" +
"z.show(df)");
note.run(p.getId());
waitForFinish(p);
assertEquals(Status.FINISHED, p.getStatus());
assertEquals(InterpreterResult.Type.TABLE, p.getResult().type());
// TODO (zjffdu), one more \n is appended, need to investigate why.
assertEquals("age\tid\n20\t1\n\n", p.getResult().message());
}
if (sparkVersion >= 20) {
// run SparkSession test
Expand Down
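The new assertions above expect z.show to return an InterpreterResult of type TABLE whose message is tab-separated text: a header row of column names followed by one line per row. A minimal sketch of that expected shape, using only the values asserted in the tests (the helper toTsv is illustrative; the real rendering is done by ZeppelinContext.showDF and may differ in detail):

import java.util.Arrays;
import java.util.List;

public class TableFormatSketch {
  // Builds the tab-separated table the tests expect, e.g. "_1\t_2\nhello\t20\n"
  // for a single-row DataFrame with columns _1 and _2.
  static String toTsv(List<String> columns, List<List<Object>> rows) {
    StringBuilder sb = new StringBuilder();
    sb.append(String.join("\t", columns)).append('\n');
    for (List<Object> row : rows) {
      for (int i = 0; i < row.size(); i++) {
        if (i > 0) {
          sb.append('\t');
        }
        sb.append(row.get(i));
      }
      sb.append('\n');
    }
    return sb.toString();
  }

  public static void main(String[] args) {
    List<List<Object>> rows = Arrays.asList(Arrays.<Object>asList("hello", 20));
    System.out.print(toTsv(Arrays.asList("_1", "_2"), rows)); // prints _1\t_2\nhello\t20
  }
}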