diff --git a/databricks/koalas/plot/core.py b/databricks/koalas/plot/core.py
index dd5c0b508c..3114af6f87 100644
--- a/databricks/koalas/plot/core.py
+++ b/databricks/koalas/plot/core.py
@@ -330,11 +330,12 @@ def _get_plot_backend(backend=None):
backend = backend or get_option("plotting.backend")
+ # TODO: leverage koalas_plotting_backends, and remove the code below.
if backend == "matplotlib":
# Because matplotlib is an optional dependency and first-party backend,
# we need to attempt an import here to raise an ImportError if needed.
try:
- import databricks.koalas.plot as module
+ from databricks.koalas.plot import matplotlib as module
except ImportError:
raise ImportError(
"matplotlib is required for plotting when the "
@@ -343,16 +344,23 @@ def _get_plot_backend(backend=None):
KoalasPlotAccessor._backends["matplotlib"] = module
+ if backend == "plotly":
+ try:
+ # plotly is optional: import it first so a missing install fails here with a clear error.
+ import plotly # noqa: F401
+ from databricks.koalas.plot import plotly as module
+ except ImportError:
+ raise ImportError(
+ "plotly is required for plotting when the "
+ "backend 'plotly' is selected."
+ ) from None
+
+ KoalasPlotAccessor._backends["plotly"] = module
+
if backend in KoalasPlotAccessor._backends:
return KoalasPlotAccessor._backends[backend]
module = KoalasPlotAccessor._find_backend(backend)
-
- if backend == "plotly":
- from databricks.koalas.plot.plotly import plot_plotly
-
- module.plot = plot_plotly(module.plot)
-
KoalasPlotAccessor._backends[backend] = module
return module
@@ -360,7 +368,9 @@ def __call__(self, kind="line", backend=None, **kwargs):
plot_backend = KoalasPlotAccessor._get_plot_backend(backend)
plot_data = self.data
- if plot_backend.__name__ != "databricks.koalas.plot":
+ # TODO: make the 'databricks.koalas.plot.matplotlib' module implement the
+ # plot interface.
+ if plot_backend.__name__ != "databricks.koalas.plot.matplotlib":
data_preprocessor_map = {
"pie": TopNPlotBase().get_top_n,
"bar": TopNPlotBase().get_top_n,
diff --git a/databricks/koalas/plot/matplotlib.py b/databricks/koalas/plot/matplotlib.py
index a2b8b39597..ebaaeee37c 100644
--- a/databricks/koalas/plot/matplotlib.py
+++ b/databricks/koalas/plot/matplotlib.py
@@ -16,7 +16,7 @@
from distutils.version import LooseVersion
-import matplotlib
+import matplotlib as mat
import numpy as np
import pandas as pd
from matplotlib.axes._base import _process_plot_format
@@ -111,9 +111,7 @@ def update_dict(dictionary, rc_name, properties):
if dictionary is None:
dictionary = dict()
for prop_dict in properties:
- dictionary.setdefault(
- prop_dict, matplotlib.rcParams[rc_str.format(rc_name, prop_dict)]
- )
+ dictionary.setdefault(prop_dict, mat.rcParams[rc_str.format(rc_name, prop_dict)])
return dictionary
# Common property dictionaries loading from rc
@@ -203,7 +201,7 @@ def update_dict(dictionary, rc_name, properties):
if manage_ticks is not None:
should_manage_ticks = manage_ticks
- if LooseVersion(matplotlib.__version__) < LooseVersion("3.1.0"):
+ if LooseVersion(mat.__version__) < LooseVersion("3.1.0"):
extra_args = {"manage_xticks": should_manage_ticks}
else:
extra_args = {"manage_ticks": should_manage_ticks}
@@ -333,26 +331,26 @@ def rc_defaults(
):
# Missing arguments default to rcParams.
if whis is None:
- whis = matplotlib.rcParams["boxplot.whiskers"]
+ whis = mat.rcParams["boxplot.whiskers"]
if bootstrap is None:
- bootstrap = matplotlib.rcParams["boxplot.bootstrap"]
+ bootstrap = mat.rcParams["boxplot.bootstrap"]
if notch is None:
- notch = matplotlib.rcParams["boxplot.notch"]
+ notch = mat.rcParams["boxplot.notch"]
if vert is None:
- vert = matplotlib.rcParams["boxplot.vertical"]
+ vert = mat.rcParams["boxplot.vertical"]
if patch_artist is None:
- patch_artist = matplotlib.rcParams["boxplot.patchartist"]
+ patch_artist = mat.rcParams["boxplot.patchartist"]
if meanline is None:
- meanline = matplotlib.rcParams["boxplot.meanline"]
+ meanline = mat.rcParams["boxplot.meanline"]
if showmeans is None:
- showmeans = matplotlib.rcParams["boxplot.showmeans"]
+ showmeans = mat.rcParams["boxplot.showmeans"]
if showcaps is None:
- showcaps = matplotlib.rcParams["boxplot.showcaps"]
+ showcaps = mat.rcParams["boxplot.showcaps"]
if showbox is None:
- showbox = matplotlib.rcParams["boxplot.showbox"]
+ showbox = mat.rcParams["boxplot.showbox"]
if showfliers is None:
- showfliers = matplotlib.rcParams["boxplot.showfliers"]
+ showfliers = mat.rcParams["boxplot.showfliers"]
return dict(
whis=whis,
diff --git a/databricks/koalas/plot/plotly.py b/databricks/koalas/plot/plotly.py
index ac14b4dc9b..ad1de26ec8 100644
--- a/databricks/koalas/plot/plotly.py
+++ b/databricks/koalas/plot/plotly.py
@@ -15,22 +15,21 @@
#
import pandas as pd
-from databricks.koalas.plot import HistogramPlotBase
+from databricks.koalas.plot import HistogramPlotBase, name_like_string
-def plot_plotly(origin_plot):
- def plot(data, kind, **kwargs):
- # Koalas specific plots
- if kind == "pie":
- return plot_pie(data, **kwargs)
- if kind == "hist":
- # Note that here data is a Koalas DataFrame or Series unlike other type of plots.
- return plot_histogram(data, **kwargs)
+def plot(data, kind, **kwargs):
+ import plotly
- # Other plots.
- return origin_plot(data, kind, **kwargs)
+ # Koalas specific plots
+ if kind == "pie":
+ return plot_pie(data, **kwargs)
+ if kind == "hist":
+ # Note that here data is a Koalas DataFrame or Series unlike other type of plots.
+ return plot_histogram(data, **kwargs)
- return plot
+ # Other plots.
+ return plotly.plot(data, kind, **kwargs)
def plot_pie(data, **kwargs):
@@ -40,7 +39,6 @@ def plot_pie(data, **kwargs):
pdf = data.to_frame()
return express.pie(pdf, values=pdf.columns[0], names=pdf.index, **kwargs)
elif isinstance(data, pd.DataFrame):
- # DataFrame
values = kwargs.pop("y", None)
default_names = None
if values is not None:
@@ -57,37 +55,35 @@ def plot_pie(data, **kwargs):
def plot_histogram(data, **kwargs):
- from plotly import express
- from plotly.subplots import make_subplots
import plotly.graph_objs as go
- from databricks import koalas as ks
bins = kwargs.get("bins", 10)
- is_single_column = isinstance(data, ks.Series)
kdf, bins = HistogramPlotBase.prepare_hist_data(data, bins)
output_series = HistogramPlotBase.compute_hist(kdf, bins)
+ prev = float("%.9f" % bins[0]) # round to 9 decimal places to keep the bin labels tidy.
+ text_bins = []
+ for b in bins[1:]:
+ norm_b = float("%.9f" % b)
+ text_bins.append("(%s, %s)" % (prev, norm_b))
+ prev = norm_b
bins = 0.5 * (bins[:-1] + bins[1:])
- if is_single_column:
- output_series = list(output_series)
- assert len(output_series) == 1
- output_series = output_series[0]
- return express.bar(
- x=bins, y=output_series, labels={"x": str(output_series.name), "y": "count"}
- )
- else:
- output_series = list(output_series)
- fig = make_subplots(rows=1, cols=len(output_series))
- for i, series in enumerate(output_series):
- fig.add_trace(go.Bar(x=bins, y=series, name=series.name,), row=1, col=i + 1)
+ output_series = list(output_series)
+ bars = []
+ for series in output_series:
+ bars.append(
+ go.Bar(
+ x=bins,
+ y=series,
+ name=name_like_string(series.name),
+ text=text_bins,
+ hovertemplate=(
+ "variable=" + name_like_string(series.name) + "<br>value=%{text}<br>count=%{y}"
+ ),
+ )
+ )
- for i, series in enumerate(output_series):
- if i == 0:
- xaxis = "xaxis"
- yaxis = "yaxis"
- else:
- xaxis = "xaxis%s" % (i + 1)
- yaxis = "yaxis%s" % (i + 1)
- fig["layout"][xaxis]["title"] = str(series.name)
- fig["layout"][yaxis]["title"] = "count"
- return fig
+ fig = go.Figure(data=bars, layout=go.Layout(barmode="stack"))
+ fig["layout"]["xaxis"]["title"] = "value"
+ fig["layout"]["yaxis"]["title"] = "count"
+ return fig
diff --git a/databricks/koalas/tests/plot/test_frame_plot_plotly.py b/databricks/koalas/tests/plot/test_frame_plot_plotly.py
index 8c7fa68941..1e0880a818 100644
--- a/databricks/koalas/tests/plot/test_frame_plot_plotly.py
+++ b/databricks/koalas/tests/plot/test_frame_plot_plotly.py
@@ -20,7 +20,6 @@
import pandas as pd
import numpy as np
from plotly import express
-from plotly.subplots import make_subplots
import plotly.graph_objs as go
from databricks import koalas as ks
@@ -182,23 +181,37 @@ def check_pie_plot(kdf):
def test_hist_plot(self):
def check_hist_plot(kdf):
bins = np.array([1.0, 5.9, 10.8, 15.7, 20.6, 25.5, 30.4, 35.3, 40.2, 45.1, 50.0])
- bins = 0.5 * (bins[:-1] + bins[1:])
data = [
np.array([5.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,]),
np.array([4.0, 3.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0,]),
]
-
- fig = make_subplots(rows=1, cols=len(data))
- fig.add_trace(
- go.Bar(x=bins, y=data[0], name=name_like_string(kdf.columns[0]),), row=1, col=1
- )
- fig.add_trace(
- go.Bar(x=bins, y=data[1], name=name_like_string(kdf.columns[1]),), row=1, col=2
- )
- fig["layout"]["xaxis"]["title"] = name_like_string(kdf.columns[0])
+ prev = bins[0]
+ text_bins = []
+ for b in bins[1:]:
+ text_bins.append("(%s, %s)" % (prev, b))
+ prev = b
+ bins = 0.5 * (bins[:-1] + bins[1:])
+ name_a = name_like_string(kdf.columns[0])
+ name_b = name_like_string(kdf.columns[1])
+ bars = [
+ go.Bar(
+ x=bins,
+ y=data[0],
+ name=name_a,
+ text=text_bins,
+ hovertemplate=("variable=" + name_a + "<br>value=%{text}<br>count=%{y}"),
+ ),
+ go.Bar(
+ x=bins,
+ y=data[1],
+ name=name_b,
+ text=text_bins,
+ hovertemplate=("variable=" + name_b + "<br>value=%{text}<br>count=%{y}"),
+ ),
+ ]
+ fig = go.Figure(data=bars, layout=go.Layout(barmode="stack"))
+ fig["layout"]["xaxis"]["title"] = "value"
fig["layout"]["yaxis"]["title"] = "count"
- fig["layout"]["xaxis2"]["title"] = name_like_string(kdf.columns[1])
- fig["layout"]["yaxis2"]["title"] = "count"
self.assertEqual(
pprint.pformat(kdf.plot(kind="hist").to_dict()), pprint.pformat(fig.to_dict())
diff --git a/databricks/koalas/tests/plot/test_series_plot_plotly.py b/databricks/koalas/tests/plot/test_series_plot_plotly.py
index 834d04e02a..16750ccd64 100644
--- a/databricks/koalas/tests/plot/test_series_plot_plotly.py
+++ b/databricks/koalas/tests/plot/test_series_plot_plotly.py
@@ -20,6 +20,7 @@
import pandas as pd
import numpy as np
from plotly import express
+import plotly.graph_objs as go
from databricks import koalas as ks
from databricks.koalas.config import set_option, reset_option
@@ -135,13 +136,26 @@ def test_pie_plot(self):
def test_hist_plot(self):
def check_hist_plot(kser):
bins = np.array([1.0, 5.9, 10.8, 15.7, 20.6, 25.5, 30.4, 35.3, 40.2, 45.1, 50.0])
+ data = np.array([5.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,])
+ prev = bins[0] # lower edge of the first bin; labels below are one per bin
+ text_bins = []
+ for b in bins[1:]:
+ text_bins.append("(%s, %s)" % (prev, b))
+ prev = b
bins = 0.5 * (bins[:-1] + bins[1:])
- data = pd.Series(
- np.array([5.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,]), name=kser.name
- )
- fig = express.bar(
- x=bins, y=data, labels={"x": name_like_string(data.name), "y": "count"}
- )
+ name_a = name_like_string(kser.name)
+ bars = [
+ go.Bar(
+ x=bins,
+ y=data,
+ name=name_a,
+ text=text_bins,
+ hovertemplate=("variable=" + name_a + "<br>value=%{text}<br>count=%{y}"),
+ ),
+ ]
+ fig = go.Figure(data=bars, layout=go.Layout(barmode="stack"))
+ fig["layout"]["xaxis"]["title"] = "value"
+ fig["layout"]["yaxis"]["title"] = "count"
self.assertEqual(
pprint.pformat(kser.plot(kind="hist").to_dict()), pprint.pformat(fig.to_dict())