diff --git a/spark/src/main/resources/python/zeppelin_pyspark.py b/spark/src/main/resources/python/zeppelin_pyspark.py
index 3e6535fa4f9..e40f928e6ba 100644
--- a/spark/src/main/resources/python/zeppelin_pyspark.py
+++ b/spark/src/main/resources/python/zeppelin_pyspark.py
@@ -30,6 +30,13 @@
import ast
import traceback
+import base64
+from io import BytesIO
+try:
+ from StringIO import StringIO
+except ImportError:
+ from io import StringIO
+
# for back compatibility
from pyspark.sql import SQLContext, HiveContext, Row
@@ -51,13 +58,43 @@ class PyZeppelinContext(dict):
def __init__(self, zc):
self.z = zc
- def show(self, obj):
+
+ def show(self, obj, **kwargs):
from pyspark.sql import DataFrame
+
if isinstance(obj, DataFrame):
print(gateway.jvm.org.apache.zeppelin.spark.ZeppelinContext.showDF(self.z, obj._jdf))
+ elif hasattr(obj, '__name__') and obj.__name__ == "matplotlib.pyplot":
+ self.show_matplotlib(obj, **kwargs)
+ elif hasattr(obj, '__call__'):
+ obj() #error reporting
else:
print(str(obj))
+ def show_matplotlib(self, p, fmt="png", width="auto", height="auto",
+ **kwargs):
+ """Matplotlib show function
+ """
+ if fmt == "png":
+ img = BytesIO()
+ p.savefig(img, format=fmt)
+ img_str = b"data:image/png;base64,"
+ img_str += base64.b64encode(img.getvalue().strip())
+ img_tag = ""
+ # Decoding is necessary for Python 3 compability
+ img_str = img_str.decode("ascii")
+ img_str = img_tag.format(img=img_str, width=width, height=height)
+ elif fmt == "svg":
+ img = StringIO()
+ p.savefig(img, format=fmt)
+ img_str = img.getvalue()
+ else:
+ raise ValueError("fmt must be 'png' or 'svg'")
+
+ html = "%html