diff --git a/spark/src/main/resources/python/zeppelin_pyspark.py b/spark/src/main/resources/python/zeppelin_pyspark.py index 3e6535fa4f9..3a8425bed18 100644 --- a/spark/src/main/resources/python/zeppelin_pyspark.py +++ b/spark/src/main/resources/python/zeppelin_pyspark.py @@ -29,6 +29,12 @@ from pyspark.serializers import MarshalSerializer, PickleSerializer import ast import traceback +import base64 +from io import BytesIO +try: + from StringIO import StringIO +except ImportError: + from io import StringIO # for back compatibility from pyspark.sql import SQLContext, HiveContext, Row @@ -50,11 +56,16 @@ def flush(self): class PyZeppelinContext(dict): def __init__(self, zc): self.z = zc + self.max_result = 1000 - def show(self, obj): + def show(self, obj,**kwargs): from pyspark.sql import DataFrame - if isinstance(obj, DataFrame): - print(gateway.jvm.org.apache.zeppelin.spark.ZeppelinContext.showDF(self.z, obj._jdf)) + if isinstance(obj, DataFrame) and type(obj).__name__ == "DataFrame": + print(gateway.jvm.org.apache.zeppelin.spark.ZeppelinContext.showDF(self.z, obj._jdf)) + elif hasattr(obj, '__name__') and obj.__name__ == "matplotlib.pyplot": + self.show_matplotlib(obj, **kwargs) + elif hasattr(obj, '__call__'): + obj() #error reporting else: print(str(obj)) @@ -69,7 +80,31 @@ def __delitem__(self, key): self.z.remove(key) def __contains__(self, item): - return self.z.containsKey(item) + return self.z.containsKey(item) + + def show_matplotlib(self, p, fmt="png", width="auto", height="auto", + **kwargs): + """Matplotlib show function + """ + if fmt == "png": + img = BytesIO() + p.savefig(img, format=fmt) + img_str = b"data:image/png;base64," + img_str += base64.b64encode(img.getvalue().strip()) + img_tag = "" + # Decoding is necessary for Python 3 compability + img_str = img_str.decode("ascii") + img_str = img_tag.format(img=img_str, width=width, height=height) + elif fmt == "svg": + img = StringIO() + p.savefig(img, format=fmt) + img_str = img.getvalue() + else: + raise ValueError("fmt must be 'png' or 'svg'") + + html = "%html
{img}
" + print(html.format(width=width, height=height, img=img_str)) + img.close() def add(self, key, value): self.__setitem__(key, value)