diff --git a/cassandra/pom.xml b/cassandra/pom.xml
index e12c75012d7..1019ad7acc2 100644
--- a/cassandra/pom.xml
+++ b/cassandra/pom.xml
@@ -35,8 +35,8 @@
     2.1.7.1
-    <scala.version>2.11.7</scala.version>
-    <scala.binary.version>2.11</scala.binary.version>
+    <scala.version>2.10.4</scala.version>
+    <scala.binary.version>2.10</scala.binary.version>
     3.3.2
     1.7.1
diff --git a/spark/pom.xml b/spark/pom.xml
index 9b82acbedd5..5f470742c67 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -48,7 +48,7 @@
     org.spark-project.akka
     2.3.4-spark
-
+    <mockito.version>1.9.5</mockito.version>

     http://www.apache.org/dist/spark/spark-${spark.version}/spark-${spark.version}.tgz
@@ -494,6 +494,13 @@
       junit
       test
+
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <version>${mockito.version}</version>
+      <scope>test</scope>
+    </dependency>
@@ -1002,7 +1009,19 @@
-
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+        <version>1.0</version>
+        <executions>
+          <execution>
+            <id>test</id>
+            <goals>
+              <goal>test</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
diff --git a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
index 852dd335183..f2261ebc435 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
@@ -42,6 +42,7 @@
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.SQLContext;
+import org.apache.zeppelin.context.ZeppelinContext;
 import org.apache.zeppelin.interpreter.Interpreter;
 import org.apache.zeppelin.interpreter.InterpreterContext;
 import org.apache.zeppelin.interpreter.InterpreterException;
diff --git a/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java
index aec6d16d55a..65bae227571 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java
@@ -33,6 +33,7 @@
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkContext;
 import org.apache.spark.SparkEnv;
+import org.apache.zeppelin.spark.display.SparkDisplayFunctionsHelper$;
 import org.apache.spark.repl.SparkCommandLine;
 import org.apache.spark.repl.SparkILoop;
 import org.apache.spark.repl.SparkIMain;
@@ -43,6 +44,7 @@
 import org.apache.spark.scheduler.Stage;
 import org.apache.spark.sql.SQLContext;
 import org.apache.spark.ui.jobs.JobProgressListener;
+import org.apache.zeppelin.context.ZeppelinContext;
 import org.apache.zeppelin.interpreter.Interpreter;
 import org.apache.zeppelin.interpreter.InterpreterContext;
 import org.apache.zeppelin.interpreter.InterpreterException;
@@ -335,14 +337,19 @@ public void open() {
     // https://groups.google.com/forum/#!topic/scala-user/MlVwo2xCCI0

     /*
-     * > val env = new nsc.Settings(errLogger) > env.usejavacp.value = true > val p = new
-     * Interpreter(env) > p.setContextClassLoader > Alternatively you can set the class path through
-     * nsc.Settings.classpath.
+     * > val env = new nsc.Settings(errLogger)
+     * > env.usejavacp.value = true
+     * > val p = new Interpreter(env)
+     * > p.setContextClassLoader
+     * > Alternatively you can set the class path through nsc.Settings.classpath.
      *
-     * >> val settings = new Settings() >> settings.usejavacp.value = true >>
-     * settings.classpath.value += File.pathSeparator + >> System.getProperty("java.class.path") >>
-     * val in = new Interpreter(settings) { >> override protected def parentClassLoader =
-     * getClass.getClassLoader >> } >> in.setContextClassLoader()
+     * >> val settings = new Settings()
+     * >> settings.usejavacp.value = true
+     * >> settings.classpath.value += File.pathSeparator + System.getProperty("java.class.path")
+     * >> val in = new Interpreter(settings) {
+     * >>   override protected def parentClassLoader = getClass.getClassLoader
+     * >> }
+     * >> in.setContextClassLoader()
      */
     Settings settings = new Settings();
     if (getProperty("args") != null) {
@@ -433,18 +440,24 @@ public void open() {
     dep = getDependencyResolver();

-    z = new ZeppelinContext(sc, sqlc, null, dep, printStream,
-        Integer.parseInt(getProperty("zeppelin.spark.maxResult")));
+    final int defaultSparkMaxResult = Integer.parseInt(getProperty("zeppelin.spark.maxResult"));
+    z = new ZeppelinContext(defaultSparkMaxResult);
+
+    SparkDisplayFunctionsHelper$.MODULE$.registerDisplayFunctions(sc, z);

     intp.interpret("@transient var _binder = new java.util.HashMap[String, Object]()");
     binder = (Map<String, Object>) getValue("_binder");
     binder.put("sc", sc);
     binder.put("sqlc", sqlc);
     binder.put("z", z);
+    binder.put("dep", dep);
     binder.put("out", printStream);

     intp.interpret("@transient val z = "
-        + "_binder.get(\"z\").asInstanceOf[org.apache.zeppelin.spark.ZeppelinContext]");
+        + "_binder.get(\"z\").asInstanceOf[org.apache.zeppelin.context.ZeppelinContext]");
+    intp.interpret("@transient val dep = "
+        + "_binder.get(\"dep\").asInstanceOf"
+        + "[org.apache.zeppelin.spark.dep.DependencyResolver]");
     intp.interpret("@transient val sc = "
         + "_binder.get(\"sc\").asInstanceOf[org.apache.spark.SparkContext]");
     intp.interpret("@transient val sqlc = "
@@ -467,17 +480,6 @@ public void open() {
       intp.interpret("import org.apache.spark.sql.functions._");
     }

-    /* Temporary disabling DisplayUtils. see https://issues.apache.org/jira/browse/ZEPPELIN-127
-     *
-    // Utility functions for display
-    intp.interpret("import org.apache.zeppelin.spark.utils.DisplayUtils._");
-
-    // Scala implicit value for spark.maxResult
-    intp.interpret("import org.apache.zeppelin.spark.utils.SparkMaxResult");
-    intp.interpret("implicit val sparkMaxResult = new SparkMaxResult(" +
-        Integer.parseInt(getProperty("zeppelin.spark.maxResult")) + ")");
-    */
-
     try {
       if (sc.version().startsWith("1.1") || sc.version().startsWith("1.2")) {
         Method loadFiles = this.interpreter.getClass().getMethod("loadFiles", Settings.class);
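Taken together, the open() changes shrink ZeppelinContext to a display/GUI facade and re-expose dependency loading through a dedicated `dep` binding. A minimal sketch of what a notebook paragraph looks like after this change — the `dep.load(artifact, addToSparkContext)` call is the one exercised by the updated SparkInterpreterTest further down, and `z` is now the shared org.apache.zeppelin.context.ZeppelinContext:

    // Scala paragraph in a note, relying on the _binder variables created in open() above
    dep.load("org.apache.commons:commons-csv:1.1", true) // resolve the artifact and ship it to the cluster
    import org.apache.commons.csv.CSVFormat              // the class is now resolvable on the driver
    val lines = sc.textFile("data.csv")                  // sc and sqlc are still bound as before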
diff --git a/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
index e60ff2bc6bf..5ac90e45c12 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
@@ -17,8 +17,11 @@
 package org.apache.zeppelin.spark;

+import java.io.ByteArrayOutputStream;
+import java.io.PrintStream;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Properties;
 import java.util.Set;
@@ -30,6 +33,8 @@
 import org.apache.spark.scheduler.Stage;
 import org.apache.spark.sql.SQLContext;
 import org.apache.spark.ui.jobs.JobProgressListener;
+import org.apache.zeppelin.context.ZeppelinContext;
+import org.apache.zeppelin.display.DisplayParams;
 import org.apache.zeppelin.interpreter.Interpreter;
 import org.apache.zeppelin.interpreter.InterpreterContext;
 import org.apache.zeppelin.interpreter.InterpreterException;
@@ -120,7 +125,8 @@ public void close() {}
   public InterpreterResult interpret(String st, InterpreterContext context) {
     SQLContext sqlc = null;
-    sqlc = getSparkInterpreter().getSQLContext();
+    final SparkInterpreter sparkInterpreter = getSparkInterpreter();
+    sqlc = sparkInterpreter.getSQLContext();
     SparkContext sc = sqlc.sparkContext();
     if (concurrentSQL()) {
@@ -131,8 +137,13 @@ public InterpreterResult interpret(String st, InterpreterContext context) {
     Object rdd = sqlc.sql(st);

-    String msg = ZeppelinContext.showRDD(sc, context, rdd, maxResult);
-    return new InterpreterResult(Code.SUCCESS, msg);
+    final ZeppelinContext zeppelinContext = sparkInterpreter.getZeppelinContext();
+    zeppelinContext.setInterpreterContext(context);
+    final ByteArrayOutputStream out = new ByteArrayOutputStream();
+    PrintStream stream = new PrintStream(out);
+    zeppelinContext.display(rdd, new DisplayParams(maxResult, stream, context,
+        new java.util.ArrayList<String>()));
+    return new InterpreterResult(Code.SUCCESS, out.toString());
   }

   @Override
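The SQL interpreter no longer formats results itself: it renders through the registered display functions into a caller-supplied PrintStream and returns the buffered text. The same capture pattern works for any caller of the new API; a hedged Scala sketch mirroring the Java above (the null InterpreterContext follows the convention used by the new display tests below):

    import java.io.{ByteArrayOutputStream, PrintStream}
    import org.apache.zeppelin.context.ZeppelinContext
    import org.apache.zeppelin.display.DisplayParams

    // Render any displayable object to a String through the registered display functions
    def renderToString(z: ZeppelinContext, any: Any, maxResult: Int): String = {
      val buffer = new ByteArrayOutputStream()
      z.display(any, DisplayParams(maxResult, new PrintStream(buffer), null,
        new java.util.ArrayList[String]()))
      buffer.toString("UTF-8")
    }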
diff --git a/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java b/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java
deleted file mode 100644
index 6cb94d9e927..00000000000
--- a/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java
+++ /dev/null
@@ -1,751 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.zeppelin.spark;
-
-import static scala.collection.JavaConversions.asJavaCollection;
-import static scala.collection.JavaConversions.asJavaIterable;
-import static scala.collection.JavaConversions.collectionAsScalaIterable;
-
-import java.io.PrintStream;
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.LinkedList;
-import java.util.List;
-
-import org.apache.spark.SparkContext;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.SQLContext.QueryExecution;
-import org.apache.spark.sql.catalyst.expressions.Attribute;
-import org.apache.spark.sql.hive.HiveContext;
-import org.apache.zeppelin.display.AngularObject;
-import org.apache.zeppelin.display.AngularObjectRegistry;
-import org.apache.zeppelin.display.AngularObjectWatcher;
-import org.apache.zeppelin.display.GUI;
-import org.apache.zeppelin.display.Input.ParamOption;
-import org.apache.zeppelin.interpreter.InterpreterContext;
-import org.apache.zeppelin.interpreter.InterpreterContextRunner;
-import org.apache.zeppelin.interpreter.InterpreterException;
-import org.apache.zeppelin.spark.dep.DependencyResolver;
-
-import scala.Tuple2;
-import scala.Unit;
-import scala.collection.Iterable;
-
-/**
- * Spark context for zeppelin.
- *
- * @author Leemoonsoo
- *
- */
-public class ZeppelinContext extends HashMap<String, Object> {
-  private DependencyResolver dep;
-  private PrintStream out;
-  private InterpreterContext interpreterContext;
-  private int maxResult;
-
-  public ZeppelinContext(SparkContext sc, SQLContext sql,
-      InterpreterContext interpreterContext,
-      DependencyResolver dep, PrintStream printStream,
-      int maxResult) {
-    this.sc = sc;
-    this.sqlContext = sql;
-    this.interpreterContext = interpreterContext;
-    this.dep = dep;
-    this.out = printStream;
-    this.maxResult = maxResult;
-  }
-
-  public SparkContext sc;
-  public SQLContext sqlContext;
-  public HiveContext hiveContext;
-  private GUI gui;
-
-  /**
-   * Load dependency for interpreter and runtime (driver).
-   * And distribute them to spark cluster (sc.add())
-   *
-   * @param artifact "group:artifact:version" or file path like "/somepath/your.jar"
-   * @return
-   * @throws Exception
-   */
-  public Iterable<String> load(String artifact) throws Exception {
-    return collectionAsScalaIterable(dep.load(artifact, true));
-  }
-
-  /**
-   * Load dependency and it's transitive dependencies for interpreter and runtime (driver).
-   * And distribute them to spark cluster (sc.add())
-   *
-   * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar"
-   * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string.
-   * @return
-   * @throws Exception
-   */
-  public Iterable<String> load(String artifact, scala.collection.Iterable<String> excludes)
-      throws Exception {
-    return collectionAsScalaIterable(
-        dep.load(artifact,
-        asJavaCollection(excludes),
-        true));
-  }
-
-  /**
-   * Load dependency and it's transitive dependencies for interpreter and runtime (driver).
-   * And distribute them to spark cluster (sc.add())
-   *
-   * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar"
-   * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string.
-   * @return
-   * @throws Exception
-   */
-  public Iterable<String> load(String artifact, Collection<String> excludes) throws Exception {
-    return collectionAsScalaIterable(dep.load(artifact, excludes, true));
-  }
-
-  /**
-   * Load dependency for interpreter and runtime, and then add to sparkContext.
-   * But not adding them to spark cluster
-   *
-   * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar"
-   * @return
-   * @throws Exception
-   */
-  public Iterable<String> loadLocal(String artifact) throws Exception {
-    return collectionAsScalaIterable(dep.load(artifact, false));
-  }
-
-
-  /**
-   * Load dependency and it's transitive dependencies and then add to sparkContext.
-   * But not adding them to spark cluster
-   *
-   * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar"
-   * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string.
-   * @return
-   * @throws Exception
-   */
-  public Iterable<String> loadLocal(String artifact,
-      scala.collection.Iterable<String> excludes) throws Exception {
-    return collectionAsScalaIterable(dep.load(artifact,
-        asJavaCollection(excludes), false));
-  }
-
-  /**
-   * Load dependency and it's transitive dependencies and then add to sparkContext.
-   * But not adding them to spark cluster
-   *
-   * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar"
-   * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string.
-   * @return
-   * @throws Exception
-   */
-  public Iterable<String> loadLocal(String artifact, Collection<String> excludes)
-      throws Exception {
-    return collectionAsScalaIterable(dep.load(artifact, excludes, false));
-  }
-
-
-  /**
-   * Add maven repository
-   *
-   * @param id id of repository ex) oss, local, snapshot
-   * @param url url of repository. supported protocol : file, http, https
-   */
-  public void addRepo(String id, String url) {
-    addRepo(id, url, false);
-  }
-
-  /**
-   * Add maven repository
-   *
-   * @param id id of repository
-   * @param url url of repository. supported protocol : file, http, https
-   * @param snapshot true if it is snapshot repository
-   */
-  public void addRepo(String id, String url, boolean snapshot) {
-    dep.addRepo(id, url, snapshot);
-  }
-
-  /**
-   * Remove maven repository by id
-   * @param id id of repository
-   */
-  public void removeRepo(String id){
-    dep.delRepo(id);
-  }
-
-  /**
-   * Load dependency only interpreter.
-   *
-   * @param name
-   * @return
-   */
-
-  public Object input(String name) {
-    return input(name, "");
-  }
-
-  public Object input(String name, Object defaultValue) {
-    return gui.input(name, defaultValue);
-  }
-
-  public Object select(String name, scala.collection.Iterable<Tuple2<Object, String>> options) {
-    return select(name, "", options);
-  }
-
-  public Object select(String name, Object defaultValue,
-      scala.collection.Iterable<Tuple2<Object, String>> options) {
-    int n = options.size();
-    ParamOption[] paramOptions = new ParamOption[n];
-    Iterator<Tuple2<Object, String>> it = asJavaIterable(options).iterator();
-
-    int i = 0;
-    while (it.hasNext()) {
-      Tuple2<Object, String> valueAndDisplayValue = it.next();
-      paramOptions[i++] = new ParamOption(valueAndDisplayValue._1(), valueAndDisplayValue._2());
-    }
-
-    return gui.select(name, "", paramOptions);
-  }
-
-  public void setGui(GUI o) {
-    this.gui = o;
-  }
-
-  private void restartInterpreter() {
-  }
-
-  public InterpreterContext getInterpreterContext() {
-    return interpreterContext;
-  }
-
-  public void setInterpreterContext(InterpreterContext interpreterContext) {
-    this.interpreterContext = interpreterContext;
-  }
-
-  public void setMaxResult(int maxResult) {
-    this.maxResult = maxResult;
-  }
-
-  /**
-   * show DataFrame or SchemaRDD
-   * @param o DataFrame or SchemaRDD object
-   */
-  public void show(Object o) {
-    show(o, maxResult);
-  }
-
-  /**
-   * show DataFrame or SchemaRDD
-   * @param o DataFrame or SchemaRDD object
-   * @param maxResult maximum number of rows to display
-   */
-  public void show(Object o, int maxResult) {
-    Class cls = null;
-    try {
-      cls = this.getClass().forName("org.apache.spark.sql.DataFrame");
-    } catch (ClassNotFoundException e) {
-    }
-
-    if (cls == null) {
-      try {
-        cls = this.getClass().forName("org.apache.spark.sql.SchemaRDD");
-      } catch (ClassNotFoundException e) {
-      }
-    }
-
-    if (cls == null) {
-      throw new InterpreterException("Can not road DataFrame/SchemaRDD class");
-    }
-
-    if (cls.isInstance(o)) {
-      out.print(showRDD(sc, interpreterContext, o, maxResult));
-    } else {
-      out.print(o.toString());
-    }
-  }
-
-  public static String showRDD(SparkContext sc,
-      InterpreterContext interpreterContext,
-      Object rdd, int maxResult) {
-    Object[] rows = null;
-    Method take;
-    String jobGroup = "zeppelin-" + interpreterContext.getParagraphId();
-    sc.setJobGroup(jobGroup, "Zeppelin", false);
-
-    try {
-      take = rdd.getClass().getMethod("take", int.class);
-      rows = (Object[]) take.invoke(rdd, maxResult + 1);
-
-    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
-        | IllegalArgumentException | InvocationTargetException e) {
-      sc.clearJobGroup();
-      throw new InterpreterException(e);
-    }
-
-    String msg = null;
-
-    // get field names
-    Method queryExecution;
-    QueryExecution qe;
-    try {
-      queryExecution = rdd.getClass().getMethod("queryExecution");
-      qe = (QueryExecution) queryExecution.invoke(rdd);
-    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
-        | IllegalArgumentException | InvocationTargetException e) {
-      throw new InterpreterException(e);
-    }
-
-    List<Attribute> columns =
-        scala.collection.JavaConverters.asJavaListConverter(
-            qe.analyzed().output()).asJava();
-
-    for (Attribute col : columns) {
-      if (msg == null) {
-        msg = col.name();
-      } else {
-        msg += "\t" + col.name();
-      }
-    }
-
-    msg += "\n";
-
-    // ArrayType, BinaryType, BooleanType, ByteType, DecimalType, DoubleType, DynamicType,
-    // FloatType, FractionalType, IntegerType, IntegralType, LongType, MapType, NativeType,
-    // NullType, NumericType, ShortType, StringType, StructType
-
-    try {
-      for (int r = 0; r < maxResult && r < rows.length; r++) {
-        Object row = rows[r];
-        Method isNullAt = row.getClass().getMethod("isNullAt", int.class);
-        Method apply = row.getClass().getMethod("apply", int.class);
-
-        for (int i = 0; i < columns.size(); i++) {
-          if (!(Boolean) isNullAt.invoke(row, i)) {
-            msg += apply.invoke(row, i).toString();
-          } else {
-            msg += "null";
-          }
-          if (i != columns.size() - 1) {
-            msg += "\t";
-          }
-        }
-        msg += "\n";
-      }
-    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
-        | IllegalArgumentException | InvocationTargetException e) {
-      throw new InterpreterException(e);
-    }
-
-    if (rows.length > maxResult) {
-      msg += "\nResults are limited by " + maxResult + ".";
-    }
-    sc.clearJobGroup();
-    return "%table " + msg;
-  }
-
-  /**
-   * Run paragraph by id
-   * @param id
-   */
-  public void run(String id) {
-    run(id, interpreterContext);
-  }
-
-  /**
-   * Run paragraph by id
-   * @param id
-   * @param context
-   */
-  public void run(String id, InterpreterContext context) {
-    if (id.equals(context.getParagraphId())) {
-      throw new InterpreterException("Can not run current Paragraph");
-    }
-
-    for (InterpreterContextRunner r : context.getRunners()) {
-      if (id.equals(r.getParagraphId())) {
-        r.run();
-        return;
-      }
-    }
-
-    throw new InterpreterException("Paragraph " + id + " not found");
-  }
-
-  /**
-   * Run paragraph at idx
-   * @param idx
-   */
-  public void run(int idx) {
-    run(idx, interpreterContext);
-  }
-
-  /**
-   * Run paragraph at index
-   * @param idx index starting from 0
-   * @param context interpreter context
-   */
-  public void run(int idx, InterpreterContext context) {
-    if (idx >= context.getRunners().size()) {
-      throw new InterpreterException("Index out of bound");
-    }
-
-    InterpreterContextRunner runner = context.getRunners().get(idx);
-    if (runner.getParagraphId().equals(context.getParagraphId())) {
-      throw new InterpreterException("Can not run current Paragraph");
-    }
-
-    runner.run();
-  }
-
-  public void run(List<Object> paragraphIdOrIdx) {
-    run(paragraphIdOrIdx, interpreterContext);
-  }
-
-  /**
-   * Run paragraphs
-   * @param paragraphIdOrIdxs list of paragraph id or idx
-   */
-  public void run(List<Object> paragraphIdOrIdx, InterpreterContext context) {
-    for (Object idOrIdx : paragraphIdOrIdx) {
-      if (idOrIdx instanceof String) {
-        String id = (String) idOrIdx;
-        run(id, context);
-      } else if (idOrIdx instanceof Integer) {
-        Integer idx = (Integer) idOrIdx;
-        run(idx, context);
-      } else {
-        throw new InterpreterException("Paragraph " + idOrIdx + " not found");
-      }
-    }
-  }
-
-  public void runAll() {
-    runAll(interpreterContext);
-  }
-
-  /**
-   * Run all paragraphs. except this.
-   */
-  public void runAll(InterpreterContext context) {
-    for (InterpreterContextRunner r : context.getRunners()) {
-      if (r.getParagraphId().equals(context.getParagraphId())) {
-        // skip itself
-        continue;
-      }
-      r.run();
-    }
-  }
-
-  public List<String> listParagraphs() {
-    List<String> paragraphs = new LinkedList<String>();
-
-    for (InterpreterContextRunner r : interpreterContext.getRunners()) {
-      paragraphs.add(r.getParagraphId());
-    }
-
-    return paragraphs;
-  }
-
-
-  private AngularObject getAngularObject(String name, InterpreterContext interpreterContext) {
-    AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry();
-    String noteId = interpreterContext.getNoteId();
-    // try get local object
-    AngularObject ao = registry.get(name, interpreterContext.getNoteId());
-    if (ao == null) {
-      // then global object
-      ao = registry.get(name, null);
-    }
-    return ao;
-  }
-
-
-  /**
-   * Get angular object. Look up local registry first and then global registry
-   * @param name variable name
-   * @return value
-   */
-  public Object angular(String name) {
-    AngularObject ao = getAngularObject(name, interpreterContext);
-    if (ao == null) {
-      return null;
-    } else {
-      return ao.get();
-    }
-  }
-
-  /**
-   * Get angular object. Look up global registry
-   * @param name variable name
-   * @return value
-   */
-  public Object angularGlobal(String name) {
-    AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry();
-    AngularObject ao = registry.get(name, null);
-    if (ao == null) {
-      return null;
-    } else {
-      return ao.get();
-    }
-  }
-
-  /**
-   * Create angular variable in local registry and bind with front end Angular display system.
-   * If variable exists, it'll be overwritten.
-   * @param name name of the variable
-   * @param o value
-   */
-  public void angularBind(String name, Object o) {
-    angularBind(name, o, interpreterContext.getNoteId());
-  }
-
-  /**
-   * Create angular variable in global registry and bind with front end Angular display system.
-   * If variable exists, it'll be overwritten.
-   * @param name name of the variable
-   * @param o value
-   */
-  public void angularBindGlobal(String name, Object o) {
-    angularBind(name, o, (String) null);
-  }
-
-  /**
-   * Create angular variable in local registry and bind with front end Angular display system.
-   * If variable exists, value will be overwritten and watcher will be added.
-   * @param name name of variable
-   * @param o value
-   * @param watcher watcher of the variable
-   */
-  public void angularBind(String name, Object o, AngularObjectWatcher watcher) {
-    angularBind(name, o, interpreterContext.getNoteId(), watcher);
-  }
-
-  /**
-   * Create angular variable in global registry and bind with front end Angular display system.
-   * If variable exists, value will be overwritten and watcher will be added.
-   * @param name name of variable
-   * @param o value
-   * @param watcher watcher of the variable
-   */
-  public void angularBindGlobal(String name, Object o, AngularObjectWatcher watcher) {
-    angularBind(name, o, null, watcher);
-  }
-
-  /**
-   * Add watcher into angular variable (local registry)
-   * @param name name of the variable
-   * @param watcher watcher
-   */
-  public void angularWatch(String name, AngularObjectWatcher watcher) {
-    angularWatch(name, interpreterContext.getNoteId(), watcher);
-  }
-
-  /**
-   * Add watcher into angular variable (global registry)
-   * @param name name of the variable
-   * @param watcher watcher
-   */
-  public void angularWatchGlobal(String name, AngularObjectWatcher watcher) {
-    angularWatch(name, null, watcher);
-  }
-
-
-  public void angularWatch(String name,
-      final scala.Function2<Object, Object, Unit> func) {
-    angularWatch(name, interpreterContext.getNoteId(), func);
-  }
-
-  public void angularWatchGlobal(String name,
-      final scala.Function2<Object, Object, Unit> func) {
-    angularWatch(name, null, func);
-  }
-
-  public void angularWatch(
-      String name,
-      final scala.Function3<Object, Object, InterpreterContext, Unit> func) {
-    angularWatch(name, interpreterContext.getNoteId(), func);
-  }
-
-  public void angularWatchGlobal(
-      String name,
-      final scala.Function3<Object, Object, InterpreterContext, Unit> func) {
-    angularWatch(name, null, func);
-  }
-
-  /**
-   * Remove watcher from angular variable (local)
-   * @param name
-   * @param watcher
-   */
-  public void angularUnwatch(String name, AngularObjectWatcher watcher) {
-    angularUnwatch(name, interpreterContext.getNoteId(), watcher);
-  }
-
-  /**
-   * Remove watcher from angular variable (global)
-   * @param name
-   * @param watcher
-   */
-  public void angularUnwatchGlobal(String name, AngularObjectWatcher watcher) {
-    angularUnwatch(name, null, watcher);
-  }
-
-
-  /**
-   * Remove all watchers for the angular variable (local)
-   * @param name
-   */
-  public void angularUnwatch(String name) {
-    angularUnwatch(name, interpreterContext.getNoteId());
-  }
-
-  /**
-   * Remove all watchers for the angular variable (global)
-   * @param name
-   */
-  public void angularUnwatchGlobal(String name) {
-    angularUnwatch(name, (String) null);
-  }
-
-  /**
-   * Remove angular variable and all the watchers.
-   * @param name
-   */
-  public void angularUnbind(String name) {
-    String noteId = interpreterContext.getNoteId();
-    angularUnbind(name, noteId);
-  }
-
-  /**
-   * Remove angular variable and all the watchers.
-   * @param name
-   */
-  public void angularUnbindGlobal(String name) {
-    angularUnbind(name, null);
-  }
-
-  /**
-   * Create angular variable in local registry and bind with front end Angular display system.
-   * If variable exists, it'll be overwritten.
-   * @param name name of the variable
-   * @param o value
-   */
-  private void angularBind(String name, Object o, String noteId) {
-    AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry();
-
-    if (registry.get(name, noteId) == null) {
-      registry.add(name, o, noteId);
-    } else {
-      registry.get(name, noteId).set(o);
-    }
-  }
-
-  /**
-   * Create angular variable in local registry and bind with front end Angular display system.
-   * If variable exists, value will be overwritten and watcher will be added.
-   * @param name name of variable
-   * @param o value
-   * @param watcher watcher of the variable
-   */
-  private void angularBind(String name, Object o, String noteId, AngularObjectWatcher watcher) {
-    AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry();
-
-    if (registry.get(name, noteId) == null) {
-      registry.add(name, o, noteId);
-    } else {
-      registry.get(name, noteId).set(o);
-    }
-    angularWatch(name, watcher);
-  }
-
-  /**
-   * Add watcher into angular binding variable
-   * @param name name of the variable
-   * @param watcher watcher
-   */
-  private void angularWatch(String name, String noteId, AngularObjectWatcher watcher) {
-    AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry();
-
-    if (registry.get(name, noteId) != null) {
-      registry.get(name, noteId).addWatcher(watcher);
-    }
-  }
-
-
-  private void angularWatch(String name, String noteId,
-      final scala.Function2<Object, Object, Unit> func) {
-    AngularObjectWatcher w = new AngularObjectWatcher(getInterpreterContext()) {
-      @Override
-      public void watch(Object oldObject, Object newObject,
-          InterpreterContext context) {
-        func.apply(newObject, newObject);
-      }
-    };
-    angularWatch(name, noteId, w);
-  }
-
-  private void angularWatch(
-      String name,
-      String noteId,
-      final scala.Function3<Object, Object, InterpreterContext, Unit> func) {
-    AngularObjectWatcher w = new AngularObjectWatcher(getInterpreterContext()) {
-      @Override
-      public void watch(Object oldObject, Object newObject,
-          InterpreterContext context) {
-        func.apply(oldObject, newObject, context);
-      }
-    };
-    angularWatch(name, noteId, w);
-  }
-
-  /**
-   * Remove watcher
-   * @param name
-   * @param watcher
-   */
-  private void angularUnwatch(String name, String noteId, AngularObjectWatcher watcher) {
-    AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry();
-    if (registry.get(name, noteId) != null) {
-      registry.get(name, noteId).removeWatcher(watcher);
-    }
-  }
-
-  /**
-   * Remove all watchers for the angular variable
-   * @param name
-   */
-  private void angularUnwatch(String name, String noteId) {
-    AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry();
-    if (registry.get(name, noteId) != null) {
-      registry.get(name, noteId).clearAllWatchers();
-    }
-  }
-
-  /**
-   * Remove angular variable and all the watchers.
-   * @param name
-   */
-  private void angularUnbind(String name, String noteId) {
-    AngularObjectRegistry registry = interpreterContext.getAngularObjectRegistry();
-    registry.remove(name, noteId);
-  }
-}
diff --git a/spark/src/main/scala/org/apache/spark/sql/QueryExecutionHelper.scala b/spark/src/main/scala/org/apache/spark/sql/QueryExecutionHelper.scala
new file mode 100644
index 00000000000..25bc583980c
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/QueryExecutionHelper.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+class QueryExecutionHelper(@transient val sc: SparkContext) extends SQLContext(sc) {
+
+  def schemaAttributes(rdd: Any): Seq[Attribute] = {
+    rdd.getClass().getMethod("queryExecution").invoke(rdd)
+      .asInstanceOf[SQLContext#QueryExecution].analyzed.output
+  }
+}
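QueryExecutionHelper extends SQLContext purely to get at the path-dependent SQLContext#QueryExecution type; it is never used as a working SQLContext, and the reflective `queryExecution` call keeps the module compilable whether the concrete result class is DataFrame (Spark 1.3+) or SchemaRDD (Spark 1.2). A hedged usage sketch, where `sc` and `df` are assumed to be a live SparkContext and a DataFrame or SchemaRDD:

    import org.apache.spark.sql.QueryExecutionHelper

    val helper = new QueryExecutionHelper(sc)
    // Column names extracted without compiling against DataFrame or SchemaRDD
    val columnNames: Seq[String] = helper.schemaAttributes(df).map(_.name)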
diff --git a/spark/src/main/scala/org/apache/zeppelin/spark/display/AbstractDisplay.scala b/spark/src/main/scala/org/apache/zeppelin/spark/display/AbstractDisplay.scala
new file mode 100644
index 00000000000..6e61a90ef31
--- /dev/null
+++ b/spark/src/main/scala/org/apache/zeppelin/spark/display/AbstractDisplay.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.spark.display
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{QueryExecutionHelper, Row}
+import org.apache.zeppelin.interpreter.{InterpreterContext, InterpreterException}
+import org.slf4j.{LoggerFactory, Logger}
+
+/**
+ * Trait to display Scala traversables
+ * and SchemaRDDs or DataFrames
+ */
+trait AbstractDisplay {
+
+  private val logger: Logger = LoggerFactory.getLogger(classOf[AbstractDisplay])
+
+  def printFormattedData[T](traversable: Traversable[T], columnLabels: String*): String = {
+
+    if (logger.isDebugEnabled()) logger.debug(s"Print collection $traversable with columns label $columnLabels")
+
+    val providedLabelCount: Int = columnLabels.size
+    var maxColumnCount: Int = 1
+    val headers = new StringBuilder("%table ")
+
+    val data = new StringBuilder("")
+
+    traversable.foreach(instance => {
+      require(instance.isInstanceOf[Product],
+        throw new InterpreterException(s"$instance should be an instance of scala.Product (case class or tuple)"))
+
+      val product = instance.asInstanceOf[Product]
+      maxColumnCount = math.max(maxColumnCount, product.productArity)
+      data.append(product.productIterator.mkString("\t")).append("\n")
+    })
+
+    if (providedLabelCount > maxColumnCount) {
+      headers.append(columnLabels.take(maxColumnCount).mkString("\t")).append("\n")
+    } else if (providedLabelCount < maxColumnCount) {
+      val missingColumnHeaders = ((providedLabelCount + 1) to maxColumnCount).foldLeft[String]("") {
+        (stringAccumulator, index) => if (index == 1) s"Column$index" else s"$stringAccumulator\tColumn$index"
+      }
+
+      headers.append(columnLabels.mkString("\t")).append(missingColumnHeaders).append("\n")
+    } else {
+      headers.append(columnLabels.mkString("\t")).append("\n")
+    }
+
+    headers.append(data)
+    headers.toString
+  }
+
+  def printDFOrSchemaRDD(sc: SparkContext, interpreterContext: InterpreterContext, df: Any, maxResult: Int): String = {
+
+    if (logger.isDebugEnabled()) logger.debug(s"Print DataFrame/SchemaRDD $df limiting to $maxResult elements")
+
+    sc.setJobGroup("zeppelin-" + interpreterContext.getParagraphId, "Zeppelin", false)
+
+    val queryExecutionHelper: QueryExecutionHelper = new QueryExecutionHelper(sc)
+    try {
+      val rows: Array[Row] = df.getClass().getMethod("take", classOf[Int])
+        .invoke(df, new Integer(maxResult)).asInstanceOf[Array[Row]]
+      val attributes: Seq[String] = queryExecutionHelper.schemaAttributes(df).map(_.name)
+      val msg = new StringBuilder("")
+      try {
+        val headerCount = attributes.size
+        msg.append("%table ").append(attributes.mkString("", "\t", "\n"))
+
+        rows.foreach(row => {
+          val tableRow: String = (0 until headerCount).map(index =>
+            if (row.isNullAt(index)) "null" else row(index).toString
+          ).mkString("", "\t", "\n")
+          msg.append(tableRow)
+        })
+      } catch {
+        case e: Throwable => {
+          sc.clearJobGroup()
+          throw new InterpreterException(e)
+        }
+      }
+      msg.toString
+    } catch {
+      case e: Throwable => {
+        sc.clearJobGroup()
+        throw new InterpreterException(e)
+      }
+    }
+  }
+}
\ No newline at end of file
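For a concrete sense of the %table payload: with fewer labels than columns, printFormattedData pads the header with ColumnN names, puts one tab between cells and one newline per row. A worked example, matching the expectations in DisplayRDDTest/DisplayTraversableTest further down:

    // Anonymous instance of the trait; the padded-header case with one label for three columns
    val display = new AbstractDisplay {}
    val table = display.printFormattedData(Seq(("jdoe", "John DOE", 32)), "Login")
    // table == "%table Login\tColumn2\tColumn3\njdoe\tJohn DOE\t32\n"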
" + + "Are you sure you have Spark 1.3 or above ?") + case e:Throwable => throw new InterpreterException(e) + } + + override def canDisplay(anyObject: Any): Boolean = { + anyObject != null && cls.isInstance(anyObject) + } + + override def display(anyObject: Any, displayParams: DisplayParams): Unit = { + if(logger.isDebugEnabled()) logger.debug(s"Display $anyObject with params $displayParams") + + require(anyObject != null, "Cannot display null DataFrame") + val output = printDFOrSchemaRDD(sc, displayParams.context, anyObject, displayParams.maxResult) + displayParams.out.print(output) + } + +} diff --git a/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplayRDD.scala b/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplayRDD.scala new file mode 100644 index 00000000000..8d0aff00ae3 --- /dev/null +++ b/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplayRDD.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.spark.display + +import org.apache.spark.rdd.RDD +import org.apache.zeppelin.display.{DisplayParams, DisplayFunction} +import org.slf4j.{LoggerFactory, Logger} +import scala.collection.JavaConverters._ + +/** + * Display function for RDD of Product + */ +class DisplayRDD extends DisplayFunction with AbstractDisplay { + + val logger:Logger = LoggerFactory.getLogger(classOf[DisplayRDD]) + + override def canDisplay(anyObject: Any): Boolean = { + anyObject != null && classOf[RDD[Product]].isAssignableFrom(anyObject.getClass) + } + + override def display(anyObject: Any, displayParams: DisplayParams): Unit = { + if(logger.isDebugEnabled()) logger.debug(s"Display $anyObject with params $displayParams") + + require(anyObject != null, "Cannot display null RDD") + val rdd: RDD[_] = anyObject.asInstanceOf[RDD[_]] + val nullSafeList = Option(displayParams.columnsLabel).getOrElse(List[String]().asJava) + val newParams = displayParams.copy(columnsLabel = nullSafeList) + val output = printFormattedData(rdd.take(newParams.maxResult), newParams.columnsLabel.asScala.toArray: _*) + newParams.out.print(output) + } +} + diff --git a/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplaySchemaRDD.scala b/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplaySchemaRDD.scala new file mode 100644 index 00000000000..501ea732e1d --- /dev/null +++ b/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplaySchemaRDD.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
diff --git a/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplaySchemaRDD.scala b/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplaySchemaRDD.scala
new file mode 100644
index 00000000000..501ea732e1d
--- /dev/null
+++ b/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplaySchemaRDD.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.spark.display
+
+import org.apache.spark.SparkContext
+import org.apache.zeppelin.display.{DisplayParams, DisplayFunction}
+import org.apache.zeppelin.interpreter.{InterpreterContext, InterpreterException}
+import org.slf4j.{LoggerFactory, Logger}
+
+/**
+ * Display function for SchemaRDD
+ * @param sc the current SparkContext
+ */
+class DisplaySchemaRDD(val sc: SparkContext) extends DisplayFunction with AbstractDisplay {
+
+  val logger: Logger = LoggerFactory.getLogger(classOf[DisplaySchemaRDD])
+
+  val cls = try {
+    Class.forName("org.apache.spark.sql.SchemaRDD")
+  } catch {
+    case cnfe: ClassNotFoundException =>
+      throw new InterpreterException("Cannot instantiate 'org.apache.spark.sql.SchemaRDD'. "
+        + "Are you sure you have Spark 1.2 or above?")
+    case e: Throwable => throw new InterpreterException(e)
+  }
+
+  override def canDisplay(anyObject: Any): Boolean = {
+    anyObject != null && cls.isInstance(anyObject)
+  }
+
+  override def display(anyObject: Any, displayParams: DisplayParams): Unit = {
+    if (logger.isDebugEnabled()) logger.debug(s"Display $anyObject with params $displayParams")
+
+    require(anyObject != null, "Cannot display null SchemaRDD")
+    val output = printDFOrSchemaRDD(sc, displayParams.context, anyObject, displayParams.maxResult)
+    displayParams.out.print(output)
+  }
+}
diff --git a/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplayTraversable.scala b/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplayTraversable.scala
new file mode 100644
index 00000000000..8c63df47c94
--- /dev/null
+++ b/spark/src/main/scala/org/apache/zeppelin/spark/display/DisplayTraversable.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.spark.display
+
+import org.apache.zeppelin.display.{DisplayParams, DisplayFunction}
+import org.apache.zeppelin.interpreter.InterpreterException
+import org.slf4j.{LoggerFactory, Logger}
+import scala.collection.JavaConverters._
+
+/**
+ * Display function for Scala traversables
+ */
+class DisplayTraversable extends DisplayFunction with AbstractDisplay {
+
+  val logger: Logger = LoggerFactory.getLogger(classOf[DisplayTraversable])
+
+  override def canDisplay(anyObject: Any): Boolean = {
+    anyObject != null && classOf[Traversable[Product]].isAssignableFrom(anyObject.getClass)
+  }
+
+  override def display(anyObject: Any, displayParams: DisplayParams): Unit = {
+    if (logger.isDebugEnabled()) logger.debug(s"Display $anyObject with params $displayParams")
+
+    require(anyObject != null, "Cannot display null Scala collection")
+    val collection: Traversable[_] = anyObject.asInstanceOf[Traversable[_]]
+    val nullSafeList = Option(displayParams.columnsLabel).getOrElse(List[String]().asJava)
+    val newParams = displayParams.copy(columnsLabel = nullSafeList)
+    val output = printFormattedData(collection.take(newParams.maxResult), newParams.columnsLabel.asScala.toArray: _*)
+    newParams.out.print(output)
+  }
+}
\ No newline at end of file
diff --git a/spark/src/main/scala/org/apache/zeppelin/spark/display/SparkDisplayFunctionsHelper.scala b/spark/src/main/scala/org/apache/zeppelin/spark/display/SparkDisplayFunctionsHelper.scala
new file mode 100644
index 00000000000..82a08a08244
--- /dev/null
+++ b/spark/src/main/scala/org/apache/zeppelin/spark/display/SparkDisplayFunctionsHelper.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.spark.display
+
+import org.apache.spark.SparkContext
+import org.apache.zeppelin.context.ZeppelinContext
+import org.apache.zeppelin.interpreter.InterpreterException
+import org.slf4j.{LoggerFactory, Logger}
+
+/**
+ * Helper singleton to register the appropriate
+ * display functions depending on the Spark version
+ */
+object SparkDisplayFunctionsHelper {
+
+  val logger: Logger = LoggerFactory.getLogger(classOf[SparkDisplayFunctionsHelper])
+
+  def registerDisplayFunctions(sc: SparkContext, z: ZeppelinContext): Unit = {
+
+    val dfClass = try {
+      Option(Class.forName("org.apache.spark.sql.DataFrame"))
+    } catch {
+      case e: Throwable => None
+    }
+
+    val schemaRDDClass = try {
+      Option(Class.forName("org.apache.spark.sql.SchemaRDD"))
+    } catch {
+      case e: Throwable => None
+    }
+
+    dfClass match {
+      case Some(_) =>
+        logger.info("Registering DisplayDataFrame function")
+        z.registerDisplayFunction(new DisplayDataFrame(sc))
+      case None => schemaRDDClass match {
+        case Some(_) =>
+          logger.info("Registering DisplaySchemaRDD function")
+          z.registerDisplayFunction(new DisplaySchemaRDD(sc))
+        case None => throw new InterpreterException("Cannot load DataFrame/SchemaRDD class")
+      }
+    }
+
+    logger.info("Registering DisplayRDD function")
+    z.registerDisplayFunction(new DisplayRDD)
+
+    logger.info("Registering DisplayTraversable function")
+    z.registerDisplayFunction(new DisplayTraversable)
+
+  }
+}
+
+class SparkDisplayFunctionsHelper
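Registration order matters here: ZeppelinContext is assumed to probe the registered functions in order via canDisplay, and since a SchemaRDD is itself an RDD, registering DisplaySchemaRDD before DisplayRDD lets the schema-aware renderer win over the generic Product renderer. A sketch of the bootstrap the Spark interpreter performs in open():

    val z = new ZeppelinContext(1000)  // 1000 stands in for zeppelin.spark.maxResult
    SparkDisplayFunctionsHelper.registerDisplayFunctions(sc, z)
    // z.display(...) now dispatches: DataFrame/SchemaRDD first, then generic RDD, then Traversable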
diff --git a/spark/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala b/spark/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala
deleted file mode 100644
index 81814349c18..00000000000
--- a/spark/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala
+++ /dev/null
@@ -1,90 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.zeppelin.spark.utils
-
-import java.lang.StringBuilder
-
-import org.apache.spark.rdd.RDD
-
-import scala.collection.IterableLike
-
-object DisplayUtils {
-
-  implicit def toDisplayRDDFunctions[T <: Product](rdd: RDD[T]): DisplayRDDFunctions[T] = new DisplayRDDFunctions[T](rdd)
-
-  implicit def toDisplayTraversableFunctions[T <: Product](traversable: Traversable[T]): DisplayTraversableFunctions[T] = new DisplayTraversableFunctions[T](traversable)
-
-  def html(htmlContent: String = "") = s"%html $htmlContent"
-
-  def img64(base64Content: String = "") = s"%img $base64Content"
-
-  def img(url: String) = s"<img src='$url' />"
-}
-
-trait DisplayCollection[T <: Product] {
-
-  def printFormattedData(traversable: Traversable[T], columnLabels: String*): Unit = {
-    val providedLabelCount: Int = columnLabels.size
-    var maxColumnCount:Int = 1
-    val headers = new StringBuilder("%table ")
-
-    val data = new StringBuilder("")
-
-    traversable.foreach(tuple => {
-      maxColumnCount = math.max(maxColumnCount,tuple.productArity)
-      data.append(tuple.productIterator.mkString("\t")).append("\n")
-    })
-
-    if (providedLabelCount > maxColumnCount) {
-      headers.append(columnLabels.take(maxColumnCount).mkString("\t")).append("\n")
-    } else if (providedLabelCount < maxColumnCount) {
-      val missingColumnHeaders = ((providedLabelCount+1) to maxColumnCount).foldLeft[String](""){
-        (stringAccumulator,index) => if (index==1) s"Column$index" else s"$stringAccumulator\tColumn$index"
-      }
-
-      headers.append(columnLabels.mkString("\t")).append(missingColumnHeaders).append("\n")
-    } else {
-      headers.append(columnLabels.mkString("\t")).append("\n")
-    }
-
-    headers.append(data)
-
-    print(headers.toString)
-  }
-
-}
-
-class DisplayRDDFunctions[T <: Product] (val rdd: RDD[T]) extends DisplayCollection[T] {
-
-  def display(columnLabels: String*)(implicit sparkMaxResult: SparkMaxResult): Unit = {
-    printFormattedData(rdd.take(sparkMaxResult.maxResult), columnLabels: _*)
-  }
-
-  def display(sparkMaxResult:Int, columnLabels: String*): Unit = {
-    printFormattedData(rdd.take(sparkMaxResult), columnLabels: _*)
-  }
-}
-
-class DisplayTraversableFunctions[T <: Product] (val traversable: Traversable[T]) extends DisplayCollection[T] {
-
-  def display(columnLabels: String*): Unit = {
-    printFormattedData(traversable, columnLabels: _*)
-  }
-}
-
-class SparkMaxResult(val maxResult: Int) extends Serializable
diff --git a/spark/src/test/java/org/apache/zeppelin/spark/SparkInterpreterTest.java b/spark/src/test/java/org/apache/zeppelin/spark/SparkInterpreterTest.java
index daa7eeda81d..208d58f65ec 100644
--- a/spark/src/test/java/org/apache/zeppelin/spark/SparkInterpreterTest.java
+++ b/spark/src/test/java/org/apache/zeppelin/spark/SparkInterpreterTest.java
@@ -155,7 +155,7 @@ public void testZContextDependencyLoading() {
     assertEquals(InterpreterResult.Code.ERROR,
         repl.interpret("import org.apache.commons.csv.CSVFormat", context).code());

     // load library from maven repository and try to import again
-    repl.interpret("z.load(\"org.apache.commons:commons-csv:1.1\")", context);
+    repl.interpret("dep.load(\"org.apache.commons:commons-csv:1.1\", true)", context);
     assertEquals(InterpreterResult.Code.SUCCESS,
         repl.interpret("import org.apache.commons.csv.CSVFormat", context).code());
   }
diff --git a/spark/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java b/spark/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java
index bb818fd2c56..04c08b438c6 100644
--- a/spark/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java
+++ b/spark/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java
@@ -137,8 +137,8 @@ public void test_null_value_in_row() {
     repl.interpret(
         "val raw = csv.map(_.split(\",\")).map(p => Row(p(0),toInt(p(1)),p(2)))",
         context);
-    repl.interpret("val people = z.sqlContext.applySchema(raw, schema)",
-        context);
+    repl.interpret("val people = sqlContext.applySchema(raw, schema)", context);
+
     if (isDataFrameSupported()) {
       repl.interpret("people.toDF.registerTempTable(\"people\")", context);
     } else {
diff --git a/spark/src/test/scala/org/apache/zeppelin/spark/display/DisplayRDDTest.scala b/spark/src/test/scala/org/apache/zeppelin/spark/display/DisplayRDDTest.scala
new file mode 100644
index 00000000000..9e98eee1448
--- /dev/null
+++ b/spark/src/test/scala/org/apache/zeppelin/spark/display/DisplayRDDTest.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.spark.display
+
+import java.io.{PrintStream, ByteArrayOutputStream}
+import java.util.Arrays._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.zeppelin.context.ZeppelinContext
+import org.apache.zeppelin.display.DisplayParams
+import org.apache.zeppelin.interpreter.InterpreterException
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.{BeforeAndAfterEach, BeforeAndAfter, FlatSpec, Matchers}
+
+case class Person(login: String, name: String, age: Int)
+
+class DisplayRDDTest extends FlatSpec
+  with BeforeAndAfter
+  with BeforeAndAfterEach
+  with Matchers
+  with MockitoSugar {
+
+  var sc: SparkContext = null
+  var testRDDTuples: RDD[(String, String, Int)] = null
+  var testRDDPersons: RDD[Person] = null
+  var z: ZeppelinContext = new ZeppelinContext(2)
+  var stream: ByteArrayOutputStream = null
+  var printStream: PrintStream = null
+
+  before {
+    val sparkConf: SparkConf = new SparkConf(true)
+      .setAppName("test-DisplayRDD")
+      .setMaster("local")
+    sc = new SparkContext(sparkConf)
+
+    testRDDTuples = sc.parallelize(
+      List(
+        ("jdoe", "John DOE", 32),
+        ("hsue", "Helen SUE", 27),
+        ("rsmith", "Richard SMITH", 45)))
+
+    testRDDPersons = sc.parallelize(
+      List(
+        Person("jdoe", "John DOE", 32),
+        Person("hsue", "Helen SUE", 27),
+        Person("rsmith", "Richard SMITH", 45)))
+
+    z.registerDisplayFunction(new DisplayRDD)
+  }
+
+  override def beforeEach(): Unit = {
+    stream = new java.io.ByteArrayOutputStream()
+    printStream = new PrintStream(stream)
+  }
+
+  "DisplayRDD" should "generate correct column headers for tuples" in {
+    z.display(testRDDTuples, DisplayParams(100, printStream, null, asList("Login", "Name", "Age")))
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayRDD" should "generate correct column headers for case class" in {
+    z.display(testRDDPersons, DisplayParams(100, printStream, null, asList("Login", "Name", "Age")))
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayRDD" should "truncate exceeding column headers for tuples" in {
+    z.display(testRDDTuples, DisplayParams(100, printStream, null, asList("Login", "Name", "Age", "xxx", "yyy")))
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayRDD" should "pad missing column headers with ColumnXXX for tuples" in {
+    z.display(testRDDTuples, DisplayParams(100, printStream, null, asList("Login")))
+
+    stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayRDD" should "display RDD with limit" in {
+
+    z.display(testRDDTuples, DisplayParams(2, printStream, null, asList("Login")))
+
+    stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n")
+  }
+
+  "DisplayRDD" should "only display RDDs" in {
+
+    val exception = intercept[InterpreterException] {
+      z.display(List("a", "b"), DisplayParams(100, printStream, null, asList("Login")))
+    }
+
+    exception.getMessage should be("Cannot find any suitable display function for object List(a, b)")
+  }
+
+  "DisplayRDD" should "exception when displaying non Product RDD" in {
+
+    val exception = intercept[InterpreterException] {
+      val rdd: RDD[String] = sc.parallelize(List[String]("a", "b"))
+      z.display(rdd, DisplayParams(100, printStream, null, asList("Login")))
+    }
+
+    exception.getMessage should be("a should be an instance of scala.Product (case class or tuple)")
+  }
+
+  after {
+    sc.stop()
+  }
+
+}
+ */ +package org.apache.zeppelin.spark.display + +import java.io.{PrintStream, ByteArrayOutputStream} +import java.util.Arrays.asList + +import org.apache.zeppelin.context.ZeppelinContext +import org.apache.zeppelin.display.DisplayParams +import org.apache.zeppelin.interpreter.InterpreterException +import org.scalatest._ +import org.scalatest.mock.MockitoSugar + + +class DisplayTraversableTest extends FlatSpec + with BeforeAndAfter + with BeforeAndAfterEach + with Matchers + with MockitoSugar { + + var testTuples: List[(String, String, Int)] = null + var testPersons: List[Person] = null + var z: ZeppelinContext = new ZeppelinContext(2) + var stream: ByteArrayOutputStream = null + var printStream: PrintStream = null + + before { + testTuples = List(("jdoe", "John DOE", 32), ("hsue", "Helen SUE", 27), ("rsmith", "Richard SMITH", 45)) + testPersons = List(Person("jdoe", "John DOE", 32), Person("hsue", "Helen SUE", 27), Person("rsmith", "Richard SMITH", 45)) + z.registerDisplayFunction(new DisplayTraversable) + } + + override def beforeEach(): Unit = { + stream = new java.io.ByteArrayOutputStream() + printStream = new PrintStream(stream) + } + + "DisplayTraversable" should "generate correct column headers for tuples" in { + z.display(testTuples, DisplayParams(100, printStream, null, asList("Login","Name","Age"))) + + stream.toString("UTF-8") should be("%table Login\tName\tAge\n" + + "jdoe\tJohn DOE\t32\n" + + "hsue\tHelen SUE\t27\n" + + "rsmith\tRichard SMITH\t45\n") + } + + "DisplayTraversable" should "generate correct column headers for case class" in { + z.display(testPersons, DisplayParams(100, printStream, null, asList("Login","Name","Age"))) + + stream.toString("UTF-8") should be("%table Login\tName\tAge\n" + + "jdoe\tJohn DOE\t32\n" + + "hsue\tHelen SUE\t27\n" + + "rsmith\tRichard SMITH\t45\n") + } + + "DisplayTraversable" should "truncate exceeding column headers for tuples" in { + z.display(testTuples, DisplayParams(100, printStream, null, asList("Login","Name","Age","xxx","yyy"))) + + stream.toString("UTF-8") should be("%table Login\tName\tAge\n" + + "jdoe\tJohn DOE\t32\n" + + "hsue\tHelen SUE\t27\n" + + "rsmith\tRichard SMITH\t45\n") + } + + "DisplayTraversable" should "pad missing column headers with ColumnXXX for tuples" in { + z.display(testTuples, DisplayParams(100, printStream, null, asList("Login"))) + + stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" + + "jdoe\tJohn DOE\t32\n" + + "hsue\tHelen SUE\t27\n" + + "rsmith\tRichard SMITH\t45\n") + } + + "DisplayTraversable" should "display tuples with limit" in { + z.display(testTuples, DisplayParams(2, printStream, null, asList("Login","Name","Age"))) + + stream.toString("UTF-8") should be("%table Login\tName\tAge\n" + + "jdoe\tJohn DOE\t32\n" + + "hsue\tHelen SUE\t27\n") + } + + + "DisplayTraversable" should "refuse to display a non-Product collection" in { + val exception = intercept[InterpreterException] { + z.display(List[String]("a","b", "c"), DisplayParams(100, printStream, null, asList("Value"))) + } + exception.getMessage should be ("a should be an instance of scala.Product (case class or tuple)") + } + + "DisplayTraversable" should "throw an exception when displaying a mixed collection of Product and non-Product elements" in { + val exception = intercept[InterpreterException] { + z.display(List(("a","b"), 1, "c"), DisplayParams(100, printStream, null, asList("Value"))) + } + exception.getMessage should be ("1 should be an instance of scala.Product (case class or tuple)") + } +} \ No newline at end of file diff --git 
a/spark/src/test/scala/org/apache/zeppelin/spark/utils/DisplayFunctionsTest.scala b/spark/src/test/scala/org/apache/zeppelin/spark/utils/DisplayFunctionsTest.scala deleted file mode 100644 index 2638f1710e9..00000000000 --- a/spark/src/test/scala/org/apache/zeppelin/spark/utils/DisplayFunctionsTest.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.zeppelin.spark.utils - -import java.io.ByteArrayOutputStream - -import org.apache.spark.rdd.RDD -import org.apache.spark.{SparkContext, SparkConf} -import org.scalatest._ -import org.scalatest.{BeforeAndAfter} - -case class Person(login : String, name: String, age: Int) - -class DisplayFunctionsTest extends FlatSpec with BeforeAndAfter with BeforeAndAfterEach with Matchers { - var sc: SparkContext = null - var testTuples:List[(String, String, Int)] = null - var testPersons:List[Person] = null - var testRDDTuples: RDD[(String,String,Int)] = null - var testRDDPersons: RDD[Person] = null - var stream: ByteArrayOutputStream = null - - before { - val sparkConf: SparkConf = new SparkConf(true) - .setAppName("test-DisplayFunctions") - .setMaster("local") - sc = new SparkContext(sparkConf) - testTuples = List(("jdoe", "John DOE", 32), ("hsue", "Helen SUE", 27), ("rsmith", "Richard SMITH", 45)) - testRDDTuples = sc.parallelize(testTuples) - testPersons = List(Person("jdoe", "John DOE", 32), Person("hsue", "Helen SUE", 27), Person("rsmith", "Richard SMITH", 45)) - testRDDPersons = sc.parallelize(testPersons) - } - - override def beforeEach() { - stream = new java.io.ByteArrayOutputStream() - super.beforeEach() // To be stackable, must call super.beforeEach - } - - - "DisplayFunctions" should "generate correct column headers for tuples" in { - implicit val sparkMaxResult = new SparkMaxResult(100) - Console.withOut(stream) { - new DisplayRDDFunctions[(String,String,Int)](testRDDTuples).display("Login","Name","Age") - } - - stream.toString("UTF-8") should be("%table Login\tName\tAge\n" + - "jdoe\tJohn DOE\t32\n" + - "hsue\tHelen SUE\t27\n" + - "rsmith\tRichard SMITH\t45\n") - } - - "DisplayFunctions" should "generate correct column headers for case class" in { - implicit val sparkMaxResult = new SparkMaxResult(100) - Console.withOut(stream) { - new DisplayRDDFunctions[Person](testRDDPersons).display("Login","Name","Age") - } - - stream.toString("UTF-8") should be("%table Login\tName\tAge\n" + - "jdoe\tJohn DOE\t32\n" + - "hsue\tHelen SUE\t27\n" + - "rsmith\tRichard SMITH\t45\n") - } - - "DisplayFunctions" should "truncate exceeding column headers for tuples" in { - implicit val sparkMaxResult = new SparkMaxResult(100) - Console.withOut(stream) { - new DisplayRDDFunctions[(String,String,Int)](testRDDTuples).display("Login","Name","Age","xxx","yyy") - } - - 
stream.toString("UTF-8") should be("%table Login\tName\tAge\n" + - "jdoe\tJohn DOE\t32\n" + - "hsue\tHelen SUE\t27\n" + - "rsmith\tRichard SMITH\t45\n") - } - - "DisplayFunctions" should "pad missing column headers with ColumnXXX for tuples" in { - implicit val sparkMaxResult = new SparkMaxResult(100) - Console.withOut(stream) { - new DisplayRDDFunctions[(String,String,Int)](testRDDTuples).display("Login") - } - - stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" + - "jdoe\tJohn DOE\t32\n" + - "hsue\tHelen SUE\t27\n" + - "rsmith\tRichard SMITH\t45\n") - } - - "DisplayUtils" should "restricts RDD to sparkMaxresult with implicit limit" in { - - implicit val sparkMaxResult = new SparkMaxResult(2) - - Console.withOut(stream) { - new DisplayRDDFunctions[(String,String,Int)](testRDDTuples).display("Login") - } - - stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" + - "jdoe\tJohn DOE\t32\n" + - "hsue\tHelen SUE\t27\n") - } - - "DisplayUtils" should "restricts RDD to sparkMaxresult with explicit limit" in { - - implicit val sparkMaxResult = new SparkMaxResult(2) - - Console.withOut(stream) { - new DisplayRDDFunctions[(String,String,Int)](testRDDTuples).display(1,"Login") - } - - stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" + - "jdoe\tJohn DOE\t32\n") - } - - "DisplayFunctions" should "display traversable of tuples" in { - - Console.withOut(stream) { - new DisplayTraversableFunctions[(String,String,Int)](testTuples).display("Login","Name","Age") - } - - stream.toString("UTF-8") should be("%table Login\tName\tAge\n" + - "jdoe\tJohn DOE\t32\n" + - "hsue\tHelen SUE\t27\n" + - "rsmith\tRichard SMITH\t45\n") - } - - "DisplayFunctions" should "display traversable of case class" in { - - Console.withOut(stream) { - new DisplayTraversableFunctions[Person](testPersons).display("Login","Name","Age") - } - - stream.toString("UTF-8") should be("%table Login\tName\tAge\n" + - "jdoe\tJohn DOE\t32\n" + - "hsue\tHelen SUE\t27\n" + - "rsmith\tRichard SMITH\t45\n") - } - - "DisplayUtils" should "display HTML" in { - DisplayUtils.html() should be ("%html ") - DisplayUtils.html("test") should be ("%html test") - } - - "DisplayUtils" should "display img" in { - DisplayUtils.img("http://www.google.com") should be ("") - DisplayUtils.img64() should be ("%img ") - DisplayUtils.img64("abcde") should be ("%img abcde") - } - - override def afterEach() { - try super.afterEach() // To be stackable, must call super.afterEach - stream = null - } - - after { - sc.stop() - } - - -} - - diff --git a/zeppelin-interpreter/pom.xml b/zeppelin-interpreter/pom.xml index 8251055395b..c630dad6f5a 100644 --- a/zeppelin-interpreter/pom.xml +++ b/zeppelin-interpreter/pom.xml @@ -15,7 +15,6 @@ ~ See the License for the specific language governing permissions and ~ limitations under the License. 
--> - @@ -35,8 +34,19 @@ Zeppelin Interpreter http://zeppelin.incubator.apache.org + + 2.10.4 + 2.10 + + + + org.scala-lang + scala-library + ${scala.version} + + org.apache.thrift libthrift @@ -81,5 +91,60 @@ 1.9.0 test + + + + org.scalatest + scalatest_${scala.binary.version} + 2.2.4 + test + + + + + + + + org.scala-tools + maven-scala-plugin + + + compile + + compile + + compile + + + test-compile + + testCompile + + test-compile + + + process-resources + + compile + + + + + + + org.scalatest + scalatest-maven-plugin + 1.0 + + + test + + test + + + + + + diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/display/DisplayFunction.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/display/DisplayFunction.java new file mode 100644 index 00000000000..dbfffdeb16f --- /dev/null +++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/display/DisplayFunction.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.display; + +import org.apache.zeppelin.interpreter.InterpreterContext; + +/** + * The DisplayFunction interface + * + * Any interpreter that wants to use zeppelinContext.display(...) + * should register at least one implementation of this interface + * with the ZeppelinContext + */ +public interface DisplayFunction { + + /** + * Whether this display function can handle the given + * object + * @param anyObject target object + * @return true if this function can display the object, false otherwise + */ + boolean canDisplay(Object anyObject); + + /** + * Generate the display output for the given input object + * Warning: this method should not be invoked + * without first invoking canDisplay(...) 
+ * + * The formatted output is written to the output stream provided in displayParams. + * + * @param anyObject input object + * @param displayParams display parameters + */ + void display(Object anyObject, DisplayParams displayParams); + +} diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/display/Input.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/display/Input.java index 2f7858ca03c..2872548096d 100644 --- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/display/Input.java +++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/display/Input.java @@ -23,6 +23,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -65,6 +66,19 @@ public void setDisplayName(String displayName) { this.displayName = displayName; } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ParamOption that = (ParamOption) o; + return Objects.equals(value, that.value) && + Objects.equals(displayName, that.displayName); + } + + @Override + public int hashCode() { + return Objects.hash(value, displayName); + } } String name; diff --git a/zeppelin-interpreter/src/main/scala/org/apache/zeppelin/context/ZeppelinContext.scala b/zeppelin-interpreter/src/main/scala/org/apache/zeppelin/context/ZeppelinContext.scala new file mode 100644 index 00000000000..2c895fa0139 --- /dev/null +++ b/zeppelin-interpreter/src/main/scala/org/apache/zeppelin/context/ZeppelinContext.scala @@ -0,0 +1,576 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.context + +import java.util + +import org.apache.zeppelin.display.Input.ParamOption +import org.apache.zeppelin.display._ +import org.apache.zeppelin.interpreter.{InterpreterContext, InterpreterContextRunner, InterpreterException} +import org.slf4j.{Logger, LoggerFactory} + +import scala.collection.JavaConversions.{collectionAsScalaIterable} +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer} +import scala.collection.JavaConverters._ + + +/** + * ZeppelinContext + * @param defaultMaxResult the default maximum number of results to display + */ +class ZeppelinContext(private var defaultMaxResult: Int) extends mutable.HashMap[String,Any] { + + val logger:Logger = LoggerFactory.getLogger(classOf[ZeppelinContext]) + + private var gui: GUI = null + private var interpreterContext: InterpreterContext = null + val displayFunctionRegistry = mutable.ArrayBuffer.empty[DisplayFunction] + + /** + * Register a new display function + * @param displayFunction the display function to register + */ + def registerDisplayFunction(displayFunction: DisplayFunction): Unit = { + if (logger.isDebugEnabled()) logger.debug(s"Registering display function $displayFunction") + + if (!displayFunctionRegistry.contains(displayFunction)) { + displayFunctionRegistry.append(displayFunction) + } else { + logger.warn(s"Display function $displayFunction is already registered!") + } + } + + /** + * Display the current object to the standard console output + * @param obj current object to display + */ + def display(obj: AnyRef): Unit = { + display(obj, DisplayParams(defaultMaxResult, Console.out, interpreterContext, List[String]().asJava)) + } + + /** + * Display the current object to the standard console output, + * restricted to the first 'maxResult' items + * @param obj current object to display + * @param maxResult max results + */ + def display(obj: AnyRef, maxResult: Int): Unit = { + display(obj, DisplayParams(maxResult, Console.out, interpreterContext, List[String]().asJava)) + } + + /** + * Display the current object to the standard console output, + * optionally with the provided column labels + * @param obj current object to display + * @param firstColumnLabel label of the first column + * @param remainingColumnsLabel labels of the remaining columns + */ + def display(obj: AnyRef, firstColumnLabel: String, remainingColumnsLabel: String*): Unit = { + val orElse = Option(remainingColumnsLabel).getOrElse(Seq[String]()) + val columnsLabel: List[String] = List(firstColumnLabel) ::: orElse.toList + display(obj, DisplayParams(defaultMaxResult, Console.out, interpreterContext, columnsLabel.asJava)) + } + + /** + * Display the current object to the standard console output, + * optionally with the provided column labels + * @param obj current object to display + * @param maxResult max results + * @param columnsLabel column labels + */ + def display(obj: AnyRef, maxResult: Int, columnsLabel: String*): Unit = { + val safeList = Option(columnsLabel).getOrElse(Seq[String]()) + display(obj, DisplayParams(maxResult, Console.out, interpreterContext, safeList.asJava)) + } + + /** + * Display the current object using the given display parameters + * @param obj current object to display + * @param displayParams display parameters + */ + def display(obj: AnyRef, displayParams: DisplayParams): Unit = { + + if (logger.isDebugEnabled()) logger.debug(s"Attempting to display $obj with params $displayParams") + + 
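// Validate the target object and fill in defaults for any missing display parameters, + // then dispatch to the first registered DisplayFunction that accepts the object. +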
require(obj != null, "Cannot display null object") + + val max = Option(displayParams.maxResult).getOrElse(defaultMaxResult) + val columnsLabel = Option(displayParams.columnsLabel).getOrElse(List[String]().asJava) + val context = Option(displayParams.context).getOrElse(interpreterContext) + val newParams = displayParams.copy(maxResult = max, columnsLabel = columnsLabel, context = context) + + val matchedDisplayedFunctions: ArrayBuffer[DisplayFunction] = displayFunctionRegistry + .filter(_.canDisplay(obj)) + + if (logger.isDebugEnabled()) + logger.debug(s"""Matched display function(s) found for $obj: ${matchedDisplayedFunctions.mkString(",")}""") + + val displayFunction: Option[DisplayFunction] = matchedDisplayedFunctions.headOption + + matchedDisplayedFunctions.size match { + case 0 => throw new InterpreterException(s"Cannot find any suitable display function for object ${obj.toString}") + case 1 => displayFunction.get.display(obj, newParams) + case _ => { + logger.warn(s"More than one display function found for type ${obj.getClass}. Will use the first one: ${displayFunction.get}") + displayFunction.get.display(obj, newParams) + } + } + } + + /** + * Define and retrieve the current value of an input from the GUI + * @param name name of the GUI input + * @return current input value + */ + def input(name: String): Any = { + input(name, "") + } + + /** + * Define a default value and retrieve the current value of an input from the GUI + * @param name name of the GUI input + * @param defaultValue default value for the input + * @return current input value + */ + def input(name: String, defaultValue: Any): Any = { + gui.input(name, defaultValue) + } + + /** + * Define an HTML <select> and retrieve the current value from the GUI + * @param name name of the HTML <select> + * @param options list of (value,displayedValue) + * @return current selected element + */ + def select(name: String, options: Iterable[(Any, String)]): Any = { + select(name, "", options) + } + + /** + * Define an HTML <select> with a default value and retrieve the current value from the GUI + * @param name name of the HTML <select> + * @param defaultValue default selected value + * @param options list of (value,displayedValue) + * @return current selected element + */ + def select(name: String, defaultValue: Any, options: Iterable[(Any, String)]): Any = { + val paramOptions = options.map{case(value,label) => new ParamOption(value,label)}.toArray + gui.select(name, defaultValue, paramOptions) + } + + def setGui(o: GUI) { + this.gui = o + } + + /** + * Set max result for display + * @param maxResult max result for display + */ + def setMaxResult(maxResult: Int):Unit = { + this.defaultMaxResult = maxResult + } + + def getMaxResult(): Int = { + this.defaultMaxResult + } + + def getInterpreterContext: InterpreterContext = interpreterContext + + def setInterpreterContext(interpreterContext: InterpreterContext):Unit = { + this.interpreterContext = interpreterContext + } + + /** + * Run paragraph by id + * @param id paragraph id + */ + def run(id: String):Unit = { + run(id, this.interpreterContext) + } + + /** + * Run paragraph by id + * @param id paragraph id + * @param context interpreter context + */ + def run(id: String, context: InterpreterContext):Unit = { + if (id == context.getParagraphId) { + throw new InterpreterException("Can not run current Paragraph") + } + + context.getRunners.filter(r => r.getParagraphId == id).foreach(r => { + r.run + return + }) + throw new InterpreterException("Paragraph " + id + " not found") + } + + /** 
* Run paragraph at idx + * @param idx paragraph index to run + */ + def run(idx: Int):Unit = { + run(idx, this.interpreterContext) + } + + /** + * Run paragraph at index + * @param idx index starting from 0 + * @param context interpreter context + */ + def run(idx: Int, context: InterpreterContext):Unit = { + val runners: util.List[InterpreterContextRunner] = context.getRunners + if (idx >= runners.size) { + throw new InterpreterException("Index out of bounds") + } + val runner: InterpreterContextRunner = runners.get(idx) + if (runner.getParagraphId == context.getParagraphId) { + throw new InterpreterException("Can not run current Paragraph") + } + runner.run + } + + /** + * Run paragraphs using their index or id + * @param paragraphIdOrIdx list of paragraph id or idx + */ + def run(paragraphIdOrIdx: List[Any]):Unit = { + run(paragraphIdOrIdx, this.interpreterContext) + } + + /** + * Run paragraphs using their index or id + * @param paragraphIdOrIdx list of paragraph id or idx + * @param context interpreter context + */ + def run(paragraphIdOrIdx: List[Any], context: InterpreterContext):Unit = { + + val unknownIds: List[Any] = paragraphIdOrIdx.filter(x => (!x.isInstanceOf[String] && !x.isInstanceOf[Int])) + if (unknownIds.size > 0) { + throw new InterpreterException(s"""Paragraphs ${unknownIds.mkString(",")} not found""") + } + paragraphIdOrIdx + .filter(_.isInstanceOf[String]) + .foreach(id => run(id.asInstanceOf[String], context)) + + paragraphIdOrIdx + .filter(_.isInstanceOf[Int]) + .foreach(index => run(index.asInstanceOf[Int], context)) + + } + + /** + * Run all paragraphs, except the current + */ + def runAll:Unit = { + runAll(interpreterContext) + } + + /** + * Run all paragraphs, except the current + */ + def runAll(context: InterpreterContext):Unit = { + context.getRunners.foreach(r => { + if (r.getParagraphId != context.getParagraphId) r.run + }) + } + + /** + * List all paragraphs + * @return paragraph ids as list + */ + def listParagraphs: List[String] = { + interpreterContext.getRunners.map(_.getParagraphId).toList + } + + /** + * Retrieve an Angular variable by name + * @param name variable name + * @return variable value + */ + def angular(name: String): Any = { + Option(getAngularObject(name, interpreterContext)) + .map(_.get) + .getOrElse(null) + } + + /** + * Get angular object. Look up global registry + * @param name variable name + * @return value + */ + def angularGlobal(name: String): AnyRef = { + val registry: AngularObjectRegistry = interpreterContext.getAngularObjectRegistry + Option(registry.get(name, null)) + .map(_.get) + .getOrElse(null) + } + + /** + * Create angular variable in local registry and bind with front end Angular display system. + * If variable exists, it'll be overwritten. + * @param name name of the variable + * @param o value + */ + def angularBind(name: String, o: AnyRef) { + angularBindWithNoteId(name, o, interpreterContext.getNoteId) + } + + /** + * Create angular variable in global registry and bind with front end Angular display system. + * If variable exists, it'll be overwritten. + * @param name name of the variable + * @param o value + */ + def angularBindGlobal(name: String, o: AnyRef) { + angularBindWithNoteId(name, o, null) + } + + /** + * Create angular variable in local registry and bind with front end Angular display system. + * If variable exists, value will be overwritten and watcher will be added. 
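+ * For example, after angularBind("progress", "50%", watcher), any later rebind of "progress" notifies the watcher.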
+ * @param name name of variable + * @param o value + * @param watcher watcher of the variable + */ + def angularBind(name: String, o: AnyRef, watcher: AngularObjectWatcher) { + angularBindWithWatcher(name, o, interpreterContext.getNoteId, watcher) + } + + /** + * Create angular variable in global registry and bind with front end Angular display system. + * If variable exists, value will be overwritten and watcher will be added. + * @param name name of variable + * @param o value + * @param watcher watcher of the variable + */ + def angularBindGlobal(name: String, o: AnyRef, watcher: AngularObjectWatcher) { + angularBindWithWatcher(name, o, null, watcher) + } + + /** + * Add watcher into angular variable (local registry) + * @param name name of the variable + * @param watcher watcher + */ + def angularWatch(name: String, watcher: AngularObjectWatcher) { + angularWatch(name, interpreterContext.getNoteId, watcher) + } + + /** + * Add watcher into angular variable (global registry) + * @param name name of the variable + * @param watcher watcher + */ + def angularWatchGlobal(name: String, watcher: AngularObjectWatcher) { + angularWatch(name, null, watcher) + } + + def angularWatch(name: String, func: (AnyRef, AnyRef) => Unit) { + angularWatch(name, interpreterContext.getNoteId, func) + } + + def angularWatchGlobal(name: String, func: (AnyRef, AnyRef) => Unit) { + angularWatch(name, null, func) + } + + def angularWatch(name: String, func: (AnyRef, AnyRef, InterpreterContext) => Unit) { + angularWatch(name, interpreterContext.getNoteId, func) + } + + def angularWatchGlobal(name: String, func: (AnyRef, AnyRef, InterpreterContext) => Unit) { + angularWatch(name, null, func) + } + + /** + * Remove watcher from angular variable (local) + * @param name + * @param watcher + */ + def angularUnwatch(name: String, watcher: AngularObjectWatcher) { + angularUnwatchFromWatcher(name, interpreterContext.getNoteId, watcher) + } + + /** + * Remove watcher from angular variable (global) + * @param name + * @param watcher + */ + def angularUnwatchGlobal(name: String, watcher: AngularObjectWatcher) { + angularUnwatchFromWatcher(name, null, watcher) + } + + /** + * Remove all watchers for the angular variable (local) + * @param name + */ + def angularUnwatch(name: String) { + angularUnwatchFromNoteId(name, interpreterContext.getNoteId) + } + + /** + * Remove all watchers for the angular variable (global) + * @param name + */ + def angularUnwatchGlobal(name: String) { + angularUnwatchFromNoteId(name, null) + } + + /** + * Remove angular variable and all the watchers. + * @param name + */ + def angularUnbind(name: String) { + val noteId: String = interpreterContext.getNoteId + angularUnbind(name, noteId) + } + + /** + * Remove angular variable and all the watchers. + * @param name + */ + def angularUnbindGlobal(name: String) { + angularUnbind(name, null) + } + + private def getAngularObject(name: String, interpreterContext: InterpreterContext): AngularObject[_] = { + val registry: AngularObjectRegistry = interpreterContext.getAngularObjectRegistry + Option(registry.get(name, interpreterContext.getNoteId)) + .getOrElse(registry.get(name, null)) + } + + /** + * Create angular variable in local registry and bind with front end Angular display system. + * If variable exists, it'll be overwritten. 
+ * @param name name of the variable + * @param o value + */ + private def angularBindWithNoteId(name: String, o: AnyRef, noteId: String) { + val registry: AngularObjectRegistry = interpreterContext.getAngularObjectRegistry + if (registry.get(name, noteId) == null) { + registry.add(name, o, noteId) + } + else { + registry.get(name, noteId) + .asInstanceOf[AngularObject[AnyRef]] + .set(o) + } + } + + /** + * Create angular variable in local registry and bind with front end Angular display system. + * If variable exists, value will be overwritten and watcher will be added. + * @param name name of variable + * @param o value + * @param watcher watcher of the variable + */ + private def angularBindWithWatcher(name: String, o: AnyRef, noteId: String, watcher: AngularObjectWatcher) { + val registry: AngularObjectRegistry = interpreterContext.getAngularObjectRegistry + if (registry.get(name, noteId) == null) { + registry.add(name, o, noteId) + } + else { + registry.get(name, noteId) + .asInstanceOf[AngularObject[AnyRef]] + .set(o) + } + angularWatch(name, watcher) + } + + /** + * Add watcher into angular binding variable + * @param name name of the variable + * @param watcher watcher + */ + private def angularWatch(name: String, noteId: String, watcher: AngularObjectWatcher) { + Option(interpreterContext.getAngularObjectRegistry) + .foreach(_.get(name, noteId).addWatcher(watcher)) + } + + private def angularWatch(name: String, noteId: String, func: (AnyRef, AnyRef) => Unit) { + val w: AngularObjectWatcher = new AngularObjectWatcher(getInterpreterContext) { + def watch(oldObject: AnyRef, newObject: AnyRef, context: InterpreterContext) { + func.apply(oldObject, newObject) + } + } + angularWatch(name, noteId, w) + } + + private def angularWatch(name: String, noteId: String, func: (AnyRef, AnyRef, InterpreterContext) => Unit) { + val w: AngularObjectWatcher = new AngularObjectWatcher(getInterpreterContext) { + def watch(oldObject: AnyRef, newObject: AnyRef, context: InterpreterContext) { + func.apply(oldObject, newObject, context) + } + } + angularWatch(name, noteId, w) + } + + /** + * Remove watcher + * @param name + * @param watcher + */ + private def angularUnwatchFromWatcher(name: String, noteId: String, watcher: AngularObjectWatcher) { + Option(interpreterContext.getAngularObjectRegistry) + .foreach(_.get(name, noteId).removeWatcher(watcher)) + } + + /** + * Remove all watchers for the angular variable + * @param name + */ + private def angularUnwatchFromNoteId(name: String, noteId: String) { + Option(interpreterContext.getAngularObjectRegistry) + .foreach(_.get(name, noteId).clearAllWatchers) + } + + /** + * Remove angular variable and all the watchers. 
+ * @param name + */ + private def angularUnbind(name: String, noteId: String) { + interpreterContext.getAngularObjectRegistry.remove(name, noteId) + } + + /** + * Display HTML code + * @param htmlContent unescaped HTML content + * @return HTML content prefixed by the magic %html + */ + def html(htmlContent: String = ""):String = s"%html $htmlContent" + + /** + * Display image using base 64 content + * @param base64Content base64 content + * @return base64Content prefixed by the magic %img + */ + def img64(base64Content: String = "") = s"%img $base64Content" + + /** + * Display image using URL + * @param url image URL + * @return an HTML <img> tag with src set to the given URL + */ + def img(url: String) = s"<img src='$url' />" +} diff --git a/zeppelin-interpreter/src/main/scala/org/apache/zeppelin/display/DisplayParams.scala b/zeppelin-interpreter/src/main/scala/org/apache/zeppelin/display/DisplayParams.scala new file mode 100644 index 00000000000..afe62ff5a1f --- /dev/null +++ b/zeppelin-interpreter/src/main/scala/org/apache/zeppelin/display/DisplayParams.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.display + +import java.io.PrintStream +import org.apache.zeppelin.interpreter.InterpreterContext + +/** + * Parameters for display function + * @param maxResult restrict display to 'maxResult' items + * @param out output stream. If not provided, it will be Console output + * @param context interpreter context, to provide extra info + * @param columnsLabel columns label + */ +case class DisplayParams(maxResult: Int, out: PrintStream, context: InterpreterContext, columnsLabel: java.util.List[String]) diff --git a/zeppelin-interpreter/src/test/scala/org/apache/zeppelin/context/ZeppelinContextTest.scala b/zeppelin-interpreter/src/test/scala/org/apache/zeppelin/context/ZeppelinContextTest.scala new file mode 100644 index 00000000000..7bc9deabc10 --- /dev/null +++ b/zeppelin-interpreter/src/test/scala/org/apache/zeppelin/context/ZeppelinContextTest.scala @@ -0,0 +1,386 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.context + +import java.io.{PrintStream, ByteArrayOutputStream} +import java.util.concurrent.atomic.AtomicBoolean + + +import collection.JavaConversions._ +import org.apache.zeppelin.display._ +import org.apache.zeppelin.display.Input.ParamOption +import org.apache.zeppelin.interpreter.{InterpreterException, InterpreterContextRunner, InterpreterContext} +import org.mockito.Mockito.when +import org.scalatest._ +import org.scalatest.mock.MockitoSugar + +class ZeppelinContextTest extends FlatSpec + with BeforeAndAfter + with BeforeAndAfterEach + with Matchers + with MockitoSugar { + + var stream: ByteArrayOutputStream = null + var registry: AngularObjectRegistry = null + val interpreterContext: InterpreterContext = mock[InterpreterContext] + val gui: GUI = mock[GUI] + val out: PrintStream = mock[PrintStream] + val noteId: String = "myNoteId" + var zeppelinContext: ZeppelinContext = null + + override def beforeEach() { + stream = new java.io.ByteArrayOutputStream() + registry = new AngularObjectRegistry("int1", new AngularListener) + when(interpreterContext.getAngularObjectRegistry).thenReturn(registry:AngularObjectRegistry) + when(interpreterContext.getNoteId).thenReturn(noteId) + zeppelinContext = new ZeppelinContext(1000) + zeppelinContext.setInterpreterContext(interpreterContext) + super.beforeEach() // To be stackable, must call super.beforeEach + } + + "ZeppelinContext" should "display HTML" in { + zeppelinContext.html() should be ("%html ") + zeppelinContext.html("test") should be ("%html test") + } + + "ZeppelinContext" should "display img" in { + zeppelinContext.img("http://www.google.com") should be ("<img src='http://www.google.com' />") + zeppelinContext.img64() should be ("%img ") + zeppelinContext.img64("abcde") should be ("%img abcde") + } + + "ZeppelinContext" should "add input to GUI with default value" in { + //Given + zeppelinContext.setGui(gui) + when(gui.input("test input", "default")).thenReturn("defaultVal", Nil:_*) + + //When + val actual: Any = zeppelinContext.input("test input", "default") + + //Then + actual should be("defaultVal") + + } + + "ZeppelinContext" should "add select to GUI with default value" in { + //Given + zeppelinContext.setGui(gui) + val paramOptions: Array[ParamOption] = Seq(new ParamOption(1,"1"),new ParamOption(2,"2")).toArray + when(gui.select("test select", "1",paramOptions)).thenReturn("1", Nil:_*) + + //When + val seq: Seq[(Any, String)] = Seq((1, "1"), (2, "2")) + val actual: Any = zeppelinContext.select("test select", "1",seq) + + //Then + actual should be("1") + } + + "ZeppelinContext" should "run paragraph by id" in { + //Given + val hasRun1 = new AtomicBoolean(false) + val hasRun2 = new AtomicBoolean(false) + val runner1:InterpreterContextRunner = new InterpreterContextRunner("1", "par1") { + override def run(): Unit = {hasRun1.getAndSet(true) } + } + val runner2:InterpreterContextRunner = new InterpreterContextRunner("2", "par2") { + override def run(): Unit = {hasRun2.getAndSet(true)} + } + when(interpreterContext.getRunners()).thenReturn(Seq(runner1,runner2),Nil:_*) + + //When + zeppelinContext.run("par1", interpreterContext) + + //Then + hasRun1.get should be(true) + hasRun2.get should be(false) + } + + "ZeppelinContext" should "run paragraph by index" in { + //Given + val hasRun1 = new AtomicBoolean(false) + val hasRun2 = new AtomicBoolean(false) + val runner1:InterpreterContextRunner = new InterpreterContextRunner("1", "par1") 
{ + override def run(): Unit = {hasRun1.getAndSet(true) } + } + val runner2:InterpreterContextRunner = new InterpreterContextRunner("2", "par2") { + override def run(): Unit = {hasRun2.getAndSet(true)} + } + when(interpreterContext.getParagraphId).thenReturn("whatever",Nil:_*) + when(interpreterContext.getRunners()).thenReturn(Seq(runner1,runner2),null) + + //When + zeppelinContext.run(0, interpreterContext) + + //Then + hasRun1.get should be(true) + hasRun2.get should be(false) + } + + "ZeppelinContext" should "not run current paragraph" in { + //Given + val hasRun1 = new AtomicBoolean(false) + val hasRun2 = new AtomicBoolean(false) + val runner1:InterpreterContextRunner = new InterpreterContextRunner("1", "par1") { + override def run(): Unit = {hasRun1.getAndSet(true) } + } + val runner2:InterpreterContextRunner = new InterpreterContextRunner("2", "par2") { + override def run(): Unit = {hasRun2.getAndSet(true)} + } + when(interpreterContext.getParagraphId).thenReturn("par1",Nil:_*) + when(interpreterContext.getRunners()).thenReturn(Seq(runner1,runner2),null) + + //When + intercept[InterpreterException] { + zeppelinContext.run(0, interpreterContext) + } + + //Then + hasRun1.get should be(false) + hasRun2.get should be(false) + } + + "ZeppelinContext" should "run paragraphs by index and id" in { + //Given + val hasRun1 = new AtomicBoolean(false) + val hasRun2 = new AtomicBoolean(false) + val hasRun3 = new AtomicBoolean(false) + val runner1:InterpreterContextRunner = new InterpreterContextRunner("1", "par1") { + override def run(): Unit = {hasRun1.getAndSet(true) } + } + val runner2:InterpreterContextRunner = new InterpreterContextRunner("2", "par2") { + override def run(): Unit = {hasRun2.getAndSet(true)} + } + val runner3:InterpreterContextRunner = new InterpreterContextRunner("3", "par3") { + override def run(): Unit = {hasRun3.getAndSet(true)} + } + when(interpreterContext.getParagraphId).thenReturn("par10",Nil:_*) + when(interpreterContext.getRunners()).thenReturn(Seq(runner1,runner2,runner3): java.util.List[InterpreterContextRunner]) + + //When + zeppelinContext.run(List("par1",new Integer(1),"par3"):List[Any], interpreterContext) + + //Then + hasRun1.get should be(true) + hasRun2.get should be(true) + hasRun3.get should be(true) + } + + "ZeppelinContext" should "run all paragraphs except the current" in { + //Given + val hasRun1 = new AtomicBoolean(false) + val hasRun2 = new AtomicBoolean(false) + val hasRun3 = new AtomicBoolean(false) + val runner1:InterpreterContextRunner = new InterpreterContextRunner("1", "par1") { + override def run(): Unit = {hasRun1.getAndSet(true) } + } + val runner2:InterpreterContextRunner = new InterpreterContextRunner("2", "par2") { + override def run(): Unit = {hasRun2.getAndSet(true)} + } + val runner3:InterpreterContextRunner = new InterpreterContextRunner("3", "par3") { + override def run(): Unit = {hasRun3.getAndSet(true)} + } + when(interpreterContext.getParagraphId).thenReturn("par1",Nil:_*) + when(interpreterContext.getRunners()).thenReturn(Seq(runner1,runner2,runner3): java.util.List[InterpreterContextRunner]) + + //When + zeppelinContext.runAll(interpreterContext) + + //Then + hasRun1.get should be(false) + hasRun2.get should be(true) + hasRun3.get should be(true) + } + + "ZeppelinContext" should "list paragraph ids" in { + //Given + val runner1:InterpreterContextRunner = new InterpreterContextRunner("1", "par1") { + override def run(): Unit = {} + } + val runner2:InterpreterContextRunner = new InterpreterContextRunner("2", "par2") { + override 
def run(): Unit = {} + } + val runner3:InterpreterContextRunner = new InterpreterContextRunner("3", "par3") { + override def run(): Unit = {} + } + when(interpreterContext.getRunners()).thenReturn(Seq(runner1,runner2,runner3): java.util.List[InterpreterContextRunner]) + + //When + val actual: List[String] = zeppelinContext.listParagraphs + + //Then + actual should be(List("par1", "par2", "par3")) + } + + "ZeppelinContext" should "fetch Angular object by name" in { + //Given + registry.add("name1", "val", noteId) + + //When + val actual = zeppelinContext.angular("name1") + + //Then + actual should be("val") + } + + "ZeppelinContext" should "fetch null if no Angular object found" in { + //Given + + //When + val actual = zeppelinContext.angular("name2") + + //Then + assert(actual == null) + } + + "ZeppelinContext" should "bind Angular object by name" in { + //Given + + //When + zeppelinContext.angularBind("name3", "value") + + //Then + registry.get("name3", noteId).get() should be("value") + } + + "ZeppelinContext" should "update bound Angular object by name with new value" in { + //Given + registry.add("name4", "val1", noteId) + + //When + zeppelinContext.angularBind("name4", "val2") + + //Then + registry.get("name4", noteId).get() should be("val2") + } + + "ZeppelinContext" should "unbind Angular object by name" in { + //Given + registry.add("name5", "val1", noteId) + + //When + zeppelinContext.angularUnbind("name5") + + //Then + assert(registry.get("name5", noteId) == null ) + } + + "ZeppelinContext" should "add watch to Angular object by name" in { + //Given + registry.add("name6", "val1", noteId) + + val hasChanged = new AtomicBoolean(false) + val watcher = new AngularObjectWatcher(interpreterContext) { + override def watch(oldObject: scala.Any, newObject: scala.Any, context: InterpreterContext): Unit = { + hasChanged.getAndSet(true) + } + } + zeppelinContext.angularWatch("name6", watcher) + + //When + zeppelinContext.angularBind("name6", "val2") + + //Then + registry.get("name6", noteId).get() should be("val2") + + //Wait for the update to be effective + java.lang.Thread.sleep(100) + + hasChanged.get() should be(true) + } + + "ZeppelinContext" should "add watch to Angular object by name with anonymous function" in { + //Given + registry.add("name7", "val1", noteId) + + val hasChanged = new AtomicBoolean(false) + zeppelinContext.angularWatch("name7", (oldObject: scala.Any, newObject: scala.Any) => { + hasChanged.getAndSet(true) + }) + + //When + zeppelinContext.angularBind("name7", "val2") + + //Then + registry.get("name7", noteId).get() should be("val2") + hasChanged.get() should be(true) + } + + "ZeppelinContext" should "stop watching an Angular object using a given watcher" in { + //Given + registry.add("name8", "val1", noteId) + val hasChanged = new AtomicBoolean(false) + val watcher = new AngularObjectWatcher(interpreterContext) { + override def watch(oldObject: scala.Any, newObject: scala.Any, context: InterpreterContext): Unit = { + hasChanged.getAndSet(true) + } + } + + zeppelinContext.angularWatch("name8", watcher) + + //When + zeppelinContext.angularUnwatch("name8", watcher) + zeppelinContext.angularBind("name8", "val2") + + //Then + registry.get("name8", noteId).get() should be("val2") + hasChanged.get() should be(false) + } + + "ZeppelinContext" should "stop watching an Angular object for all watchers" in { + //Given + registry.add("name9", "val1", noteId) + val hasChanged1 = new AtomicBoolean(false) + val hasChanged2 = new AtomicBoolean(false) + val watcher1 = new 
AngularObjectWatcher(interpreterContext) { + override def watch(oldObject: scala.Any, newObject: scala.Any, context: InterpreterContext): Unit = { + hasChanged1.getAndSet(true) + } + } + val watcher2 = new AngularObjectWatcher(interpreterContext) { + override def watch(oldObject: scala.Any, newObject: scala.Any, context: InterpreterContext): Unit = { + hasChanged2.getAndSet(true) + } + } + + zeppelinContext.angularWatch("name9", watcher1) + zeppelinContext.angularWatch("name9", watcher2) + + //When + zeppelinContext.angularUnwatch("name9") + zeppelinContext.angularBind("name9", "val2") + + //Then + registry.get("name9", noteId).get() should be("val2") + hasChanged1.get() should be(false) + hasChanged2.get() should be(false) + } + + override def afterEach() { + try super.afterEach() // To be stackable, must call super.afterEach + stream = null + } +} + +class AngularListener extends AngularObjectRegistryListener { + override def onAdd(interpreterGroupId: String, `object`: AngularObject[_]): Unit = {} + + override def onUpdate(interpreterGroupId: String, `object`: AngularObject[_]): Unit = {} + + override def onRemove(interpreterGroupId: String, name: String, noteId: String): Unit = {} }
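
For reference, a minimal usage sketch of the display API introduced by this patch. It assumes a Spark interpreter paragraph where sc and z are already bound and where the RDD display functions have been registered (as SparkInterpreter.open() now does via SparkDisplayFunctionsHelper); the data and column labels are illustrative:

    // Display an RDD of case classes as a %table.
    case class Person(login: String, name: String, age: Int)

    val people = sc.parallelize(List(
      Person("jdoe", "John DOE", 32),
      Person("hsue", "Helen SUE", 27),
      Person("rsmith", "Richard SMITH", 45)))

    // Prints "%table Login\tName\tAge" plus one tab-separated row per element;
    // surplus labels are truncated and missing ones are padded as Column2, Column3, ...
    z.display(people, "Login", "Name", "Age")

    // Same output, restricted to the first two rows (maxResult = 2).
    z.display(people, 2, "Login", "Name", "Age")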
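Likewise, a sketch of a custom display function against the DisplayFunction contract defined above (canDisplay() guards display(), and output goes to displayParams.out). The DisplayMap class, its formatting, and the registration call are illustrative, not part of this patch:

    import org.apache.zeppelin.display.{DisplayFunction, DisplayParams}
    import scala.collection.JavaConversions.mapAsScalaMap

    class DisplayMap extends DisplayFunction {

      // Accept only java.util.Map instances; anything else is left to other functions.
      override def canDisplay(anyObject: Object): Boolean =
        anyObject.isInstanceOf[java.util.Map[_, _]]

      // Render key/value pairs as a two-column %table, honouring maxResult.
      override def display(anyObject: Object, displayParams: DisplayParams): Unit = {
        val map = anyObject.asInstanceOf[java.util.Map[Object, Object]]
        val rows = mapAsScalaMap(map)
          .take(displayParams.maxResult)
          .map { case (k, v) => s"$k\t$v" }
        displayParams.out.println(("%table Key\tValue" +: rows.toSeq).mkString("\n"))
      }
    }

    // Registration, e.g. from an interpreter's open():
    z.registerDisplayFunction(new DisplayMap)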