diff --git a/docs/manual/interpreters.md b/docs/manual/interpreters.md index c47d5369d8a..0aeabac8e16 100644 --- a/docs/manual/interpreters.md +++ b/docs/manual/interpreters.md @@ -82,3 +82,49 @@ interpreter.start() The above code will start interpreter thread inside your process. Once the interpreter is started you can configure zeppelin to connect to RemoteInterpreter by checking **Connect to existing process** checkbox and then provide **Host** and **Port** on which interpreter porocess is listening as shown in the image below: + + +## (Experimental) Interpreter Execution Hooks + +Zeppelin allows for users to specify additional code to be executed by an interpreter at pre and post-paragraph code execution. This is primarily useful if you need to run the same set of code for all of the paragraphs within your notebook at specific times. Currently, this feature is only available for the spark and pyspark interpreters. To specify your hook code, you may use '`z.registerHook()`. For example, enter the following into one paragraph: + +```python +%pyspark +z.registerHook("post_exec", "print 'This code should be executed before the parapgraph code!'") +z.registerHook("pre_exec", "print 'This code should be executed after the paragraph code!'") +``` + +These calls will not take into effect until the next time you run a paragraph. In another paragraph, enter +```python +%pyspark +print "This code should be entered into the paragraph by the user!" +``` + +The output should be: +``` +This code should be executed before the paragraph code! +This code should be entered into the paragraph by the user! +This code should be executed after the paragraph code! +``` + +If you ever need to know the hook code, use `z.getHook()`: +```python +%pyspark +print z.getHook("post_exec") +``` +``` +print 'This code should be executed after the paragraph code!' +``` +Any call to `z.registerHook()` will automatically overwrite what was previously registered. To completely unregister a hook event, use `z.unregisterHook(eventCode)`. Currently only `"post_exec"` and `"pre_exec"` are valid event codes for the Zeppelin Hook Registry system. + +Finally, the hook registry is internally shared by other interpreters in the same group. This would allow for hook code for one interpreter REPL to be set by another as follows: + +```scala +%spark +z.unregisterHook("post_exec", "pyspark") +``` +The API is identical for both the spark (scala) and pyspark (python) implementations. + +### Caveats +Calls to `z.registerHook("pre_exec", ...)` should be made with care. If there are errors in your specified hook code, this will cause the interpreter REPL to become unable to execute any code pass the pre-execute stage making it impossible for direct calls to `z.unregisterHook()` to take into effect. Current workarounds include calling `z.unregisterHook()` from a different interpreter REPL in the same interpreter group (see above) or manually restarting the interpreter group in the UI. + diff --git a/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java index 44c2a7461b1..248225d3c9c 100644 --- a/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java +++ b/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java @@ -49,6 +49,7 @@ import org.apache.zeppelin.interpreter.Interpreter; import org.apache.zeppelin.interpreter.InterpreterContext; import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterHookRegistry; import org.apache.zeppelin.interpreter.InterpreterProperty; import org.apache.zeppelin.interpreter.InterpreterResult; import org.apache.zeppelin.interpreter.InterpreterResult.Code; @@ -101,6 +102,7 @@ public class SparkInterpreter extends Interpreter { private SparkConf conf; private static SparkContext sc; private static SQLContext sqlc; + private static InterpreterHookRegistry hooks; private static SparkEnv env; private static Object sparkSession; // spark 2.x private static JobProgressListener sparkListener; @@ -778,8 +780,10 @@ public void open() { sqlc = getSQLContext(); dep = getDependencyResolver(); + + hooks = getInterpreterGroup().getInterpreterHookRegistry(); - z = new ZeppelinContext(sc, sqlc, null, dep, + z = new ZeppelinContext(sc, sqlc, null, dep, hooks, Integer.parseInt(getProperty("zeppelin.spark.maxResult"))); interpret("@transient val _binder = new java.util.HashMap[String, Object]()"); diff --git a/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java b/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java index 7465756e14f..92b50d0a3b8 100644 --- a/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java +++ b/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java @@ -28,11 +28,14 @@ import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.Map; +import java.util.HashMap; import org.apache.spark.SparkContext; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.catalyst.expressions.Attribute; import org.apache.zeppelin.annotation.ZeppelinApi; +import org.apache.zeppelin.annotation.Experimental; import org.apache.zeppelin.display.AngularObject; import org.apache.zeppelin.display.AngularObjectRegistry; import org.apache.zeppelin.display.AngularObjectWatcher; @@ -41,6 +44,7 @@ import org.apache.zeppelin.interpreter.InterpreterContext; import org.apache.zeppelin.interpreter.InterpreterContextRunner; import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterHookRegistry; import org.apache.zeppelin.spark.dep.SparkDependencyResolver; import org.apache.zeppelin.resource.Resource; import org.apache.zeppelin.resource.ResourcePool; @@ -53,19 +57,33 @@ * Spark context for zeppelin. */ public class ZeppelinContext { + // Map interpreter class name (to be used by hook registry) from + // given replName in parapgraph + private static final Map interpreterClassMap; + static { + interpreterClassMap = new HashMap(); + interpreterClassMap.put("spark", "org.apache.zeppelin.spark.SparkInterpreter"); + interpreterClassMap.put("sql", "org.apache.zeppelin.spark.SparkSqlInterpreter"); + interpreterClassMap.put("dep", "org.apache.zeppelin.spark.DepInterpreter"); + interpreterClassMap.put("pyspark", "org.apache.zeppelin.spark.PySparkInterpreter"); + } + private SparkDependencyResolver dep; private InterpreterContext interpreterContext; private int maxResult; private List supportedClasses; - + private InterpreterHookRegistry hooks; + public ZeppelinContext(SparkContext sc, SQLContext sql, InterpreterContext interpreterContext, SparkDependencyResolver dep, + InterpreterHookRegistry hooks, int maxResult) { this.sc = sc; this.sqlContext = sql; this.interpreterContext = interpreterContext; this.dep = dep; + this.hooks = hooks; this.maxResult = maxResult; this.supportedClasses = new ArrayList<>(); try { @@ -697,6 +715,90 @@ private void angularUnbind(String name, String noteId) { registry.remove(name, noteId, null); } + /** + * Get the interpreter class name from name entered in paragraph + * @param replName if replName is a valid className, return that instead. + */ + public String getClassNameFromReplName(String replName) { + for (String name : interpreterClassMap.values()) { + if (replName.equals(name)) { + return replName; + } + } + + if (replName.contains("spark.")) { + replName = replName.replace("spark.", ""); + } + return interpreterClassMap.get(replName); + } + + /** + * General function to register hook event + * @param event The type of event to hook to (pre_exec, post_exec) + * @param cmd The code to be executed by the interpreter on given event + * @param replName Name of the interpreter + */ + @Experimental + public void registerHook(String event, String cmd, String replName) { + String noteId = interpreterContext.getNoteId(); + String className = getClassNameFromReplName(replName); + hooks.register(noteId, className, event, cmd); + } + + /** + * registerHook() wrapper for current repl + * @param event The type of event to hook to (pre_exec, post_exec) + * @param cmd The code to be executed by the interpreter on given event + */ + @Experimental + public void registerHook(String event, String cmd) { + String className = interpreterContext.getClassName(); + registerHook(event, cmd, className); + } + + /** + * Get the hook code + * @param event The type of event to hook to (pre_exec, post_exec) + * @param replName Name of the interpreter + */ + @Experimental + public String getHook(String event, String replName) { + String noteId = interpreterContext.getNoteId(); + String className = getClassNameFromReplName(replName); + return hooks.get(noteId, className, event); + } + + /** + * getHook() wrapper for current repl + * @param event The type of event to hook to (pre_exec, post_exec) + */ + @Experimental + public String getHook(String event) { + String className = interpreterContext.getClassName(); + return getHook(event, className); + } + + /** + * Unbind code from given hook event + * @param event The type of event to hook to (pre_exec, post_exec) + * @param replName Name of the interpreter + */ + @Experimental + public void unregisterHook(String event, String replName) { + String noteId = interpreterContext.getNoteId(); + String className = getClassNameFromReplName(replName); + hooks.unregister(noteId, className, event); + } + + /** + * unregisterHook() wrapper for current repl + * @param event The type of event to hook to (pre_exec, post_exec) + */ + @Experimental + public void unregisterHook(String event) { + String className = interpreterContext.getClassName(); + unregisterHook(event, className); + } /** * Add object into resource pool diff --git a/spark/src/main/resources/python/zeppelin_pyspark.py b/spark/src/main/resources/python/zeppelin_pyspark.py index 53465c2cd80..b1076a653f5 100644 --- a/spark/src/main/resources/python/zeppelin_pyspark.py +++ b/spark/src/main/resources/python/zeppelin_pyspark.py @@ -80,16 +80,16 @@ def put(self, key, value): def get(self, key): return self.__getitem__(key) - def input(self, name, defaultValue = ""): + def input(self, name, defaultValue=""): return self.z.input(name, defaultValue) - def select(self, name, options, defaultValue = ""): + def select(self, name, options, defaultValue=""): # auto_convert to ArrayList doesn't match the method signature on JVM side tuples = list(map(lambda items: self.__tupleToScalaTuple2(items), options)) iterables = gateway.jvm.scala.collection.JavaConversions.collectionAsScalaIterable(tuples) return self.z.select(name, defaultValue, iterables) - def checkbox(self, name, options, defaultChecked = None): + def checkbox(self, name, options, defaultChecked=None): if defaultChecked is None: defaultChecked = list(map(lambda items: items[0], options)) optionTuples = list(map(lambda items: self.__tupleToScalaTuple2(items), options)) @@ -99,6 +99,23 @@ def checkbox(self, name, options, defaultChecked = None): checkedIterables = self.z.checkbox(name, defaultCheckedIterables, optionIterables) return gateway.jvm.scala.collection.JavaConversions.asJavaCollection(checkedIterables) + def registerHook(self, event, cmd, replName=None): + if replName is None: + self.z.registerHook(event, cmd) + else: + self.z.registerHook(event, cmd, replName) + + def unregisterHook(self, event, replName=None): + if replName is None: + self.z.unregisterHook(event) + else: + self.z.unregisterHook(event, replName) + + def getHook(self, event, replName=None): + if replName is None: + return self.z.getHook(event) + return self.z.getHook(event, replName) + def __tupleToScalaTuple2(self, tuple): if (len(tuple) == 2): return gateway.jvm.scala.Tuple2(tuple[0], tuple[1]) diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/Interpreter.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/Interpreter.java index 9678b4691df..3e323200f6e 100644 --- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/Interpreter.java +++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/Interpreter.java @@ -27,6 +27,7 @@ import com.google.gson.annotations.SerializedName; import org.apache.zeppelin.annotation.ZeppelinApi; +import org.apache.zeppelin.annotation.Experimental; import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; import org.apache.zeppelin.scheduler.Scheduler; import org.apache.zeppelin.scheduler.SchedulerFactory; @@ -203,6 +204,71 @@ public void setClassloaderUrls(URL[] classloaderUrls) { this.classloaderUrls = classloaderUrls; } + /** + * General function to register hook event + * @param noteId - Note to bind hook to + * @param event The type of event to hook to (pre_exec, post_exec) + * @param cmd The code to be executed by the interpreter on given event + */ + @Experimental + public void registerHook(String noteId, String event, String cmd) { + InterpreterHookRegistry hooks = interpreterGroup.getInterpreterHookRegistry(); + String className = getClassName(); + hooks.register(noteId, className, event, cmd); + } + + /** + * registerHook() wrapper for global scope + * @param event The type of event to hook to (pre_exec, post_exec) + * @param cmd The code to be executed by the interpreter on given event + */ + @Experimental + public void registerHook(String event, String cmd) { + registerHook(null, event, cmd); + } + + /** + * Get the hook code + * @param noteId - Note to bind hook to + * @param event The type of event to hook to (pre_exec, post_exec) + */ + @Experimental + public String getHook(String noteId, String event) { + InterpreterHookRegistry hooks = interpreterGroup.getInterpreterHookRegistry(); + String className = getClassName(); + return hooks.get(noteId, className, event); + } + + /** + * getHook() wrapper for global scope + * @param event The type of event to hook to (pre_exec, post_exec) + */ + @Experimental + public String getHook(String event) { + return getHook(null, event); + } + + /** + * Unbind code from given hook event + * @param noteId - Note to bind hook to + * @param event The type of event to hook to (pre_exec, post_exec) + */ + @Experimental + public void unregisterHook(String noteId, String event) { + InterpreterHookRegistry hooks = interpreterGroup.getInterpreterHookRegistry(); + String className = getClassName(); + hooks.unregister(noteId, className, event); + } + + /** + * unregisterHook() wrapper for global scope + * @param event The type of event to hook to (pre_exec, post_exec) + */ + @Experimental + public void unregisterHook(String event) { + unregisterHook(null, event); + } + @ZeppelinApi public Interpreter getInterpreterInTheSameSessionByClassName(String className) { synchronized (interpreterGroup) { diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterContext.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterContext.java index 21ca2e67b72..e33b9352252 100644 --- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterContext.java +++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterContext.java @@ -57,6 +57,7 @@ public static void remove() { private AngularObjectRegistry angularObjectRegistry; private ResourcePool resourcePool; private List runners; + private String className; public InterpreterContext(String noteId, String paragraphId, @@ -124,4 +125,11 @@ public List getRunners() { return runners; } + public String getClassName() { + return className; + } + + public void setClassName(String className) { + this.className = className; + } } diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java index bc56784b15b..ee53f8e7571 100644 --- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java +++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java @@ -45,6 +45,7 @@ public class InterpreterGroup extends ConcurrentHashMap>> registry = + new HashMap>>(); + + /** + * hookRegistry constructor. + * + * @param interpreterId The Id of the InterpreterGroup instance to bind to + */ + public InterpreterHookRegistry(final String interpreterId) { + this.interpreterId = interpreterId; + } + + /** + * Get the interpreterGroup id this instance is bound to + */ + public String getInterpreterId() { + return interpreterId; + } + + /** + * Adds a note to the registry + * + * @param noteId The Id of the Note instance to add + */ + public void addNote(String noteId) { + synchronized (registry) { + if (registry.get(noteId) == null) { + registry.put(noteId, new HashMap>()); + } + } + } + + /** + * Adds a className to the registry + * + * @param noteId The note id + * @param className The name of the interpreter repl to map the hooks to + */ + public void addRepl(String noteId, String className) { + synchronized (registry) { + addNote(noteId); + if (registry.get(noteId).get(className) == null) { + registry.get(noteId).put(className, new HashMap()); + } + } + } + + /** + * Register a hook for a specific event. + * + * @param noteId Denotes the note this instance belongs to + * @param className The name of the interpreter repl to map the hooks to + * @param event hook event (see constants defined in this class) + * @param cmd Code to be executed by the interpreter + */ + public void register(String noteId, String className, + String event, String cmd) throws IllegalArgumentException { + synchronized (registry) { + if (noteId == null) { + noteId = GLOBAL_KEY; + } + addRepl(noteId, className); + if (!event.equals(HookType.POST_EXEC) && !event.equals(HookType.PRE_EXEC) && + !event.equals(HookType.POST_EXEC_DEV) && !event.equals(HookType.PRE_EXEC_DEV)) { + throw new IllegalArgumentException("Must be " + HookType.POST_EXEC + ", " + + HookType.POST_EXEC_DEV + ", " + + HookType.PRE_EXEC + " or " + + HookType.PRE_EXEC_DEV); + } + registry.get(noteId).get(className).put(event, cmd); + } + } + + /** + * Unregister a hook for a specific event. + * + * @param noteId Denotes the note this instance belongs to + * @param className The name of the interpreter repl to map the hooks to + * @param event hook event (see constants defined in this class) + */ + public void unregister(String noteId, String className, String event) { + synchronized (registry) { + if (noteId == null) { + noteId = GLOBAL_KEY; + } + addRepl(noteId, className); + registry.get(noteId).get(className).remove(event); + } + } + + /** + * Get a hook for a specific event. + * + * @param noteId Denotes the note this instance belongs to + * @param className The name of the interpreter repl to map the hooks to + * @param event hook event (see constants defined in this class) + */ + public String get(String noteId, String className, String event) { + synchronized (registry) { + if (noteId == null) { + noteId = GLOBAL_KEY; + } + addRepl(noteId, className); + return registry.get(noteId).get(className).get(event); + } + } + + /** + * Container for hook event type constants + */ + public static final class HookType { + // Execute the hook code PRIOR to main paragraph code execution + public static final String PRE_EXEC = "pre_exec"; + + // Execute the hook code AFTER main paragraph code execution + public static final String POST_EXEC = "post_exec"; + + // Same as above but reserved for interpreter developers, in order to allow + // notebook users to use the above without overwriting registry settings + // that are initialized directly in subclasses of Interpreter. + public static final String PRE_EXEC_DEV = "pre_exec_dev"; + public static final String POST_EXEC_DEV = "post_exec_dev"; + } + +} diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/LazyOpenInterpreter.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/LazyOpenInterpreter.java index c62ab05eb1f..425ae20a4f1 100644 --- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/LazyOpenInterpreter.java +++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/LazyOpenInterpreter.java @@ -147,4 +147,34 @@ public void setInterpreterGroup(InterpreterGroup interpreterGroup) { public void setClassloaderUrls(URL [] urls) { intp.setClassloaderUrls(urls); } + + @Override + public void registerHook(String noteId, String event, String cmd) { + intp.registerHook(noteId, event, cmd); + } + + @Override + public void registerHook(String event, String cmd) { + intp.registerHook(event, cmd); + } + + @Override + public String getHook(String noteId, String event) { + return intp.getHook(noteId, event); + } + + @Override + public String getHook(String event) { + return intp.getHook(event); + } + + @Override + public void unregisterHook(String noteId, String event) { + intp.unregisterHook(noteId, event); + } + + @Override + public void unregisterHook(String event) { + intp.unregisterHook(event); + } } diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/remote/RemoteInterpreterServer.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/remote/RemoteInterpreterServer.java index 8344366e569..0a7b1ed6912 100644 --- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/remote/RemoteInterpreterServer.java +++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/remote/RemoteInterpreterServer.java @@ -33,6 +33,8 @@ import org.apache.zeppelin.display.*; import org.apache.zeppelin.helium.*; import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.interpreter.InterpreterHookRegistry.HookType; +import org.apache.zeppelin.interpreter.InterpreterHookListener; import org.apache.zeppelin.interpreter.InterpreterResult.Code; import org.apache.zeppelin.interpreter.dev.ZeppelinDevServer; import org.apache.zeppelin.interpreter.thrift.*; @@ -60,6 +62,7 @@ public class RemoteInterpreterServer InterpreterGroup interpreterGroup; AngularObjectRegistry angularObjectRegistry; + InterpreterHookRegistry hookRegistry; DistributedResourcePool resourcePool; private ApplicationLoader appLoader; @@ -152,7 +155,9 @@ public void createInterpreter(String interpreterGroupId, String noteId, String if (interpreterGroup == null) { interpreterGroup = new InterpreterGroup(interpreterGroupId); angularObjectRegistry = new AngularObjectRegistry(interpreterGroup.getId(), this); + hookRegistry = new InterpreterHookRegistry(interpreterGroup.getId()); resourcePool = new DistributedResourcePool(interpreterGroup.getId(), eventClient); + interpreterGroup.setInterpreterHookRegistry(hookRegistry); interpreterGroup.setAngularObjectRegistry(angularObjectRegistry); interpreterGroup.setResourcePool(resourcePool); @@ -290,6 +295,7 @@ public RemoteInterpreterResult interpret(String noteId, String className, String } Interpreter intp = getInterpreter(noteId, className); InterpreterContext context = convert(interpreterContext); + context.setClassName(intp.getClassName()); Scheduler scheduler = intp.getScheduler(); InterpretJobListener jobListener = new InterpretJobListener(); @@ -383,10 +389,61 @@ public Map info() { return infos; } + private void processInterpreterHooks(final String noteId) { + InterpreterHookListener hookListener = new InterpreterHookListener() { + @Override + public void onPreExecute(String script) { + String cmdDev = interpreter.getHook(noteId, HookType.PRE_EXEC_DEV); + String cmdUser = interpreter.getHook(noteId, HookType.PRE_EXEC); + + // User defined hook should be executed before dev hook + List cmds = Arrays.asList(cmdDev, cmdUser); + for (String cmd : cmds) { + if (cmd != null) { + script = cmd + '\n' + script; + } + } + + InterpretJob.this.script = script; + } + + @Override + public void onPostExecute(String script) { + String cmdDev = interpreter.getHook(noteId, HookType.POST_EXEC_DEV); + String cmdUser = interpreter.getHook(noteId, HookType.POST_EXEC); + + // User defined hook should be executed after dev hook + List cmds = Arrays.asList(cmdUser, cmdDev); + for (String cmd : cmds) { + if (cmd != null) { + script += '\n' + cmd; + } + } + + InterpretJob.this.script = script; + } + }; + hookListener.onPreExecute(script); + hookListener.onPostExecute(script); + } + @Override protected Object jobRun() throws Throwable { try { InterpreterContext.set(context); + + // Open the interpreter instance prior to calling interpret(). + // This is necessary because the earliest we can register a hook + // is from within the open() method. + LazyOpenInterpreter lazy = (LazyOpenInterpreter) interpreter; + if (!lazy.isOpen()) { + lazy.open(); + } + + // Add hooks to script from registry. + // Global scope first, followed by notebook scope + processInterpreterHooks(null); + processInterpreterHooks(context.getNoteId()); InterpreterResult result = interpreter.interpret(script, context); // data from context.out is prepended to InterpreterResult if both defined diff --git a/zeppelin-interpreter/src/test/java/org/apache/zeppelin/interpreter/InterpreterHookRegistryTest.java b/zeppelin-interpreter/src/test/java/org/apache/zeppelin/interpreter/InterpreterHookRegistryTest.java new file mode 100644 index 00000000000..7614e9eb204 --- /dev/null +++ b/zeppelin-interpreter/src/test/java/org/apache/zeppelin/interpreter/InterpreterHookRegistryTest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.interpreter; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Test; + +public class InterpreterHookRegistryTest { + + @Test + public void testBasic() { + final String PRE_EXEC = InterpreterHookRegistry.HookType.PRE_EXEC; + final String POST_EXEC = InterpreterHookRegistry.HookType.POST_EXEC; + final String PRE_EXEC_DEV = InterpreterHookRegistry.HookType.PRE_EXEC_DEV; + final String POST_EXEC_DEV = InterpreterHookRegistry.HookType.POST_EXEC_DEV; + final String GLOBAL_KEY = InterpreterHookRegistry.GLOBAL_KEY; + final String noteId = "note"; + final String className = "class"; + final String preExecHook = "pre"; + final String postExecHook = "post"; + InterpreterHookRegistry registry = new InterpreterHookRegistry("intpId"); + + // Test register() + registry.register(noteId, className, PRE_EXEC, preExecHook); + registry.register(noteId, className, POST_EXEC, postExecHook); + registry.register(noteId, className, PRE_EXEC_DEV, preExecHook); + registry.register(noteId, className, POST_EXEC_DEV, postExecHook); + + // Test get() + assertEquals(registry.get(noteId, className, PRE_EXEC), preExecHook); + assertEquals(registry.get(noteId, className, POST_EXEC), postExecHook); + assertEquals(registry.get(noteId, className, PRE_EXEC_DEV), preExecHook); + assertEquals(registry.get(noteId, className, POST_EXEC_DEV), postExecHook); + + // Test Unregister + registry.unregister(noteId, className, PRE_EXEC); + registry.unregister(noteId, className, POST_EXEC); + registry.unregister(noteId, className, PRE_EXEC_DEV); + registry.unregister(noteId, className, POST_EXEC_DEV); + assertNull(registry.get(noteId, className, PRE_EXEC)); + assertNull(registry.get(noteId, className, POST_EXEC)); + assertNull(registry.get(noteId, className, PRE_EXEC_DEV)); + assertNull(registry.get(noteId, className, POST_EXEC_DEV)); + + // Test Global Scope + registry.register(null, className, PRE_EXEC, preExecHook); + assertEquals(registry.get(GLOBAL_KEY, className, PRE_EXEC), preExecHook); + } + + @Test(expected = IllegalArgumentException.class) + public void testValidEventCode() { + InterpreterHookRegistry registry = new InterpreterHookRegistry("intpId"); + + // Test that only valid event codes ("pre_exec", "post_exec") are accepted + registry.register("foo", "bar", "baz", "whatever"); + } + +}