diff --git a/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java
index 9a54912a35c..d465890dcee 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java
@@ -118,7 +118,7 @@ public class SparkInterpreter extends Interpreter {
   private Map binder;
   private SparkVersion sparkVersion;
-  private File outputDir;          // class outputdir for scala 2.11
+  private static File outputDir;   // class outputdir for scala 2.11
   private Object classServer;      // classserver for scala 2.11
@@ -572,8 +572,11 @@ public void open() {
       sparkReplClassDir = System.getProperty("java.io.tmpdir");
     }

-    outputDir = createTempDir(sparkReplClassDir);
-
+    synchronized (sharedInterpreterLock) {
+      if (outputDir == null) {
+        outputDir = createTempDir(sparkReplClassDir);
+      }
+    }
     argList.add("-Yrepl-class-based");
     argList.add("-Yrepl-outdir");
     argList.add(outputDir.getAbsolutePath());
@@ -1276,7 +1279,12 @@ public void close() {
     logger.info("Close interpreter");

     if (numReferenceOfSparkContext.decrementAndGet() == 0) {
-      sc.stop();
+      if (sparkSession != null) {
+        Utils.invokeMethod(sparkSession, "stop");
+      } else if (sc != null) {
+        sc.stop();
+      }
+      sparkSession = null;
       sc = null;
       if (classServer != null) {
         Utils.invokeMethod(classServer, "stop");
diff --git a/spark/src/main/resources/python/zeppelin_pyspark.py b/spark/src/main/resources/python/zeppelin_pyspark.py
index 3e6535fa4f9..53465c2cd80 100644
--- a/spark/src/main/resources/python/zeppelin_pyspark.py
+++ b/spark/src/main/resources/python/zeppelin_pyspark.py
@@ -218,7 +218,10 @@ def getCompletion(self, text_value):
 jconf = intp.getSparkConf()
 conf = SparkConf(_jvm = gateway.jvm, _jconf = jconf)
 sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
-sqlc = SQLContext(sc, intp.getSQLContext())
+if sparkVersion.isSpark2():
+  sqlc = SQLContext(sparkContext=sc, jsqlContext=intp.getSQLContext())
+else:
+  sqlc = SQLContext(sparkContext=sc, sqlContext=intp.getSQLContext())
 sqlContext = sqlc

 if sparkVersion.isSpark2():
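The `outputDir` change above makes the Scala 2.11 REPL class output directory shared across all SparkInterpreter instances in the same JVM, guarded by a lock so that only the first `open()` creates it. A minimal standalone sketch of that guarded lazy init follows; the class and method names here are illustrative, not part of the patch, while `sharedInterpreterLock`, `outputDir`, and the temp-dir creation mirror the patched code:

```java
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;

public class SharedOutputDir {
  private static final Object sharedInterpreterLock = new Object();
  private static File outputDir; // shared by every instance, hence static

  // The first caller creates the directory; later callers reuse it, so all
  // REPLs in the JVM compile their generated classes into the same place.
  static File getOrCreate(String baseDir) throws IOException {
    synchronized (sharedInterpreterLock) {
      if (outputDir == null) {
        outputDir = Files.createTempDirectory(
            new File(baseDir).toPath(), "spark").toFile();
      }
      return outputDir;
    }
  }
}
```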
"2"); // set spark home for pyspark sparkIntpSetting.getProperties().setProperty("spark.home", sparkHome); + sparkIntpSetting.getProperties().setProperty("zeppelin.spark.useHiveContext", "false"); pySpark = true; sparkR = true; } @@ -194,7 +200,11 @@ private static String getHostname() { } private static String getSparkHome() { - String sparkHome = getSparkHomeRecursively(new File(System.getProperty(ZeppelinConfiguration.ConfVars.ZEPPELIN_HOME.getVarName()))); + String sparkHome = System.getenv("SPARK_HOME"); + if (sparkHome != null) { + return sparkHome; + } + sparkHome = getSparkHomeRecursively(new File(System.getProperty(ZeppelinConfiguration.ConfVars.ZEPPELIN_HOME.getVarName()))); System.out.println("SPARK HOME detected " + sparkHome); return sparkHome; } diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java index 1250f9ce58b..4e516dbc664 100644 --- a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java @@ -135,11 +135,38 @@ public void pySparkTest() throws IOException { config.put("enabled", true); p.setConfig(config); p.setText("%pyspark print(sc.parallelize(range(1, 11)).reduce(lambda a, b: a + b))"); -// p.getRepl("org.apache.zeppelin.spark.SparkInterpreter").open(); note.run(p.getId()); waitForFinish(p); assertEquals(Status.FINISHED, p.getStatus()); assertEquals("55\n", p.getResult().message()); + if (sparkVersion >= 13) { + // run sqlContext test + p = note.addParagraph(); + config = p.getConfig(); + config.put("enabled", true); + p.setConfig(config); + p.setText("%pyspark from pyspark.sql import Row\n" + + "df=sqlContext.createDataFrame([Row(id=1, age=20)])\n" + + "df.collect()"); + note.run(p.getId()); + waitForFinish(p); + assertEquals(Status.FINISHED, p.getStatus()); + assertEquals("[Row(age=20, id=1)]\n", p.getResult().message()); + } + if (sparkVersion >= 20) { + // run SparkSession test + p = note.addParagraph(); + config = p.getConfig(); + config.put("enabled", true); + p.setConfig(config); + p.setText("%pyspark from pyspark.sql import Row\n" + + "df=sqlContext.createDataFrame([Row(id=1, age=20)])\n" + + "df.collect()"); + note.run(p.getId()); + waitForFinish(p); + assertEquals(Status.FINISHED, p.getStatus()); + assertEquals("[Row(age=20, id=1)]\n", p.getResult().message()); + } } ZeppelinServer.notebook.removeNote(note.getId(), null); } @@ -166,7 +193,6 @@ public void pySparkAutoConvertOptionTest() throws IOException { p.setText("%pyspark\nfrom pyspark.sql.functions import *\n" + "print(" + sqlContextName + ".range(0, 10).withColumn('uniform', rand(seed=10) * 3.14).count())"); -// p.getRepl("org.apache.zeppelin.spark.SparkInterpreter").open(); note.run(p.getId()); waitForFinish(p); assertEquals(Status.FINISHED, p.getStatus()); @@ -257,6 +283,7 @@ public void pySparkDepLoaderTest() throws IOException { assertEquals(Status.FINISHED, p1.getStatus()); assertEquals("2\n", p1.getResult().message()); } + ZeppelinServer.notebook.removeNote(note.getId(), null); } /** @@ -270,7 +297,6 @@ private int getSparkVersionNumber(Note note) { config.put("enabled", true); p.setConfig(config); p.setText("%spark print(sc.version)"); -// p.getRepl("org.apache.zeppelin.spark.SparkInterpreter").open(); note.run(p.getId()); waitForFinish(p); assertEquals(Status.FINISHED, p.getStatus()); diff --git 
diff --git a/zeppelin-server/src/test/resources/log4j.properties b/zeppelin-server/src/test/resources/log4j.properties
index 376ce00d9b7..500739064a8 100644
--- a/zeppelin-server/src/test/resources/log4j.properties
+++ b/zeppelin-server/src/test/resources/log4j.properties
@@ -43,4 +43,4 @@ log4j.logger.DataNucleus.Datastore=ERROR

 # Log all JDBC parameters
 log4j.logger.org.hibernate.type=ALL
-
+log4j.logger.org.apache.zeppelin.interpreter.remote.RemoteInterpreter=DEBUG
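The log4j override turns on DEBUG output for the remote-interpreter client, which is the code path the new cluster tests exercise. Purely as an illustration (this demo class is not part of the patch), any message logged through that category now reaches the test logs:

```java
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RemoteInterpreterLogDemo {
  // Same logger category the log4j.properties line above sets to DEBUG.
  private static final Logger logger =
      LoggerFactory.getLogger("org.apache.zeppelin.interpreter.remote.RemoteInterpreter");

  public static void main(String[] args) {
    logger.debug("visible once the category is at DEBUG"); // filtered out before
  }
}
```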