diff --git a/spark/src/main/resources/python/zeppelin_pyspark.py b/spark/src/main/resources/python/zeppelin_pyspark.py
index 3e6535fa4f9..3a8425bed18 100644
--- a/spark/src/main/resources/python/zeppelin_pyspark.py
+++ b/spark/src/main/resources/python/zeppelin_pyspark.py
@@ -29,6 +29,12 @@
from pyspark.serializers import MarshalSerializer, PickleSerializer
import ast
import traceback
+import base64
+from io import BytesIO
+try:
+ from StringIO import StringIO
+except ImportError:
+ from io import StringIO
# for back compatibility
from pyspark.sql import SQLContext, HiveContext, Row
@@ -50,11 +56,16 @@ def flush(self):
class PyZeppelinContext(dict):
def __init__(self, zc):
self.z = zc
+ self.max_result = 1000
- def show(self, obj):
+ def show(self, obj,**kwargs):
from pyspark.sql import DataFrame
- if isinstance(obj, DataFrame):
- print(gateway.jvm.org.apache.zeppelin.spark.ZeppelinContext.showDF(self.z, obj._jdf))
+ if isinstance(obj, DataFrame) and type(obj).__name__ == "DataFrame":
+ print(gateway.jvm.org.apache.zeppelin.spark.ZeppelinContext.showDF(self.z, obj._jdf))
+ elif hasattr(obj, '__name__') and obj.__name__ == "matplotlib.pyplot":
+ self.show_matplotlib(obj, **kwargs)
+ elif hasattr(obj, '__call__'):
+ obj() #error reporting
else:
print(str(obj))
@@ -69,7 +80,31 @@ def __delitem__(self, key):
self.z.remove(key)
def __contains__(self, item):
- return self.z.containsKey(item)
+ return self.z.containsKey(item)
+
+ def show_matplotlib(self, p, fmt="png", width="auto", height="auto",
+ **kwargs):
+ """Matplotlib show function
+ """
+ if fmt == "png":
+ img = BytesIO()
+ p.savefig(img, format=fmt)
+ img_str = b"data:image/png;base64,"
+ img_str += base64.b64encode(img.getvalue().strip())
+ img_tag = ""
+ # Decoding is necessary for Python 3 compability
+ img_str = img_str.decode("ascii")
+ img_str = img_tag.format(img=img_str, width=width, height=height)
+ elif fmt == "svg":
+ img = StringIO()
+ p.savefig(img, format=fmt)
+ img_str = img.getvalue()
+ else:
+ raise ValueError("fmt must be 'png' or 'svg'")
+
+ html = "%html