Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkRBackend;
import org.apache.zeppelin.interpreter.*;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
Expand Down Expand Up @@ -70,11 +71,16 @@ public void open() {
int port = SparkRBackend.port();

SparkInterpreter sparkInterpreter = getSparkInterpreter();
ZeppelinRContext.setSparkContext(sparkInterpreter.getSparkContext());
SparkContext sc = sparkInterpreter.getSparkContext();
SparkVersion sparkVersion = new SparkVersion(sc.version());
ZeppelinRContext.setSparkContext(sc);
if (Utils.isSpark2()) {
ZeppelinRContext.setSparkSession(sparkInterpreter.getSparkSession());
}
ZeppelinRContext.setSqlContext(sparkInterpreter.getSQLContext());
ZeppelinRContext.setZepplinContext(sparkInterpreter.getZeppelinContext());

zeppelinR = new ZeppelinR(rCmdPath, sparkRLibPath, port);
zeppelinR = new ZeppelinR(rCmdPath, sparkRLibPath, port, sparkVersion);
try {
zeppelinR.open();
} catch (IOException e) {
Expand Down
6 changes: 5 additions & 1 deletion spark/src/main/java/org/apache/zeppelin/spark/ZeppelinR.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
public class ZeppelinR implements ExecuteResultHandler {
Logger logger = LoggerFactory.getLogger(ZeppelinR.class);
private final String rCmdPath;
private final SparkVersion sparkVersion;
private DefaultExecutor executor;
private SparkOutputStream outputStream;
private PipedOutputStream input;
Expand Down Expand Up @@ -107,9 +108,11 @@ public Object getValue() {
* @param rCmdPath R repl commandline path
* @param libPath sparkr library path
*/
public ZeppelinR(String rCmdPath, String libPath, int sparkRBackendPort) {
public ZeppelinR(String rCmdPath, String libPath, int sparkRBackendPort,
SparkVersion sparkVersion) {
this.rCmdPath = rCmdPath;
this.libPath = libPath;
this.sparkVersion = sparkVersion;
this.port = sparkRBackendPort;
try {
File scriptFile = File.createTempFile("zeppelin_sparkr-", ".R");
Expand Down Expand Up @@ -137,6 +140,7 @@ public void open() throws IOException {
cmd.addArgument(Integer.toString(hashCode()));
cmd.addArgument(Integer.toString(port));
cmd.addArgument(libPath);
cmd.addArgument(Integer.toString(sparkVersion.toNumber()));

executor = new DefaultExecutor();
outputStream = new SparkOutputStream();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class ZeppelinRContext {
private static SparkContext sparkContext;
private static SQLContext sqlContext;
private static ZeppelinContext zeppelinContext;
private static Object sparkSession;

public static void setSparkContext(SparkContext sparkContext) {
ZeppelinRContext.sparkContext = sparkContext;
Expand All @@ -40,6 +41,10 @@ public static void setSqlContext(SQLContext sqlContext) {
ZeppelinRContext.sqlContext = sqlContext;
}

public static void setSparkSession(Object sparkSession) {
ZeppelinRContext.sparkSession = sparkSession;
}

public static SparkContext getSparkContext() {
return sparkContext;
}
Expand All @@ -52,4 +57,7 @@ public static ZeppelinContext getZeppelinContext() {
return zeppelinContext;
}

public static Object getSparkSession() {
return sparkSession;
}
}
5 changes: 5 additions & 0 deletions spark/src/main/resources/R/zeppelin_sparkr.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ args <- commandArgs(trailingOnly = TRUE)
hashCode <- as.integer(args[1])
port <- as.integer(args[2])
libPath <- args[3]
version <- as.integer(args[4])
rm(args)

print(paste("Port ", toString(port)))
Expand All @@ -41,6 +42,10 @@ assign(".scStartTime", as.integer(Sys.time()), envir = SparkR:::.sparkREnv)
# setup spark env
assign(".sc", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getSparkContext"), envir = SparkR:::.sparkREnv)
assign("sc", get(".sc", envir = SparkR:::.sparkREnv), envir=.GlobalEnv)
if (version >= 200) {
assign(".sparkRsession", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getSparkSession"), envir = SparkR:::.sparkREnv)
assign("spark", get(".sparkRsession", envir = SparkR:::.sparkREnv), envir = .GlobalEnv)
}
assign(".sqlc", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getSqlContext"), envir = SparkR:::.sparkREnv)
assign("sqlContext", get(".sqlc", envir = SparkR:::.sparkREnv), envir = .GlobalEnv)
assign(".zeppelinContext", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getZeppelinContext"), envir = .GlobalEnv)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,16 @@ public void sparkRTest() throws IOException {
}
}

// run markdown paragraph, again
String sqlContextName = "sqlContext";
if (sparkVersion >= 20) {
sqlContextName = "spark";
}
Paragraph p = note.addParagraph();
Map config = p.getConfig();
config.put("enabled", true);
p.setConfig(config);
p.setText("%r localDF <- data.frame(name=c(\"a\", \"b\", \"c\"), age=c(19, 23, 18))\n" +
"df <- createDataFrame(sqlContext, localDF)\n" +
"df <- createDataFrame(" + sqlContextName + ", localDF)\n" +
"count(df)"
);
note.run(p.getId());
Expand Down