diff --git a/databricks/koalas/plot/core.py b/databricks/koalas/plot/core.py index dd5c0b508c..3114af6f87 100644 --- a/databricks/koalas/plot/core.py +++ b/databricks/koalas/plot/core.py @@ -330,11 +330,12 @@ def _get_plot_backend(backend=None): backend = backend or get_option("plotting.backend") + # TODO: leverage koalas_plotting_backends, and remove the codes below. if backend == "matplotlib": # Because matplotlib is an optional dependency and first-party backend, # we need to attempt an import here to raise an ImportError if needed. try: - import databricks.koalas.plot as module + from databricks.koalas.plot import matplotlib as module except ImportError: raise ImportError( "matplotlib is required for plotting when the " @@ -343,16 +344,23 @@ def _get_plot_backend(backend=None): KoalasPlotAccessor._backends["matplotlib"] = module + if backend == "plotly": + try: + # test if plotly can be imported + import plotly # noqa: F401 + from databricks.koalas.plot import plotly as module + except ImportError: + raise ImportError( + "plotly is required for plotting when the " + "backend 'plotly' is selected." + ) from None + + KoalasPlotAccessor._backends["plotly"] = module + if backend in KoalasPlotAccessor._backends: return KoalasPlotAccessor._backends[backend] module = KoalasPlotAccessor._find_backend(backend) - - if backend == "plotly": - from databricks.koalas.plot.plotly import plot_plotly - - module.plot = plot_plotly(module.plot) - KoalasPlotAccessor._backends[backend] = module return module @@ -360,7 +368,9 @@ def __call__(self, kind="line", backend=None, **kwargs): plot_backend = KoalasPlotAccessor._get_plot_backend(backend) plot_data = self.data - if plot_backend.__name__ != "databricks.koalas.plot": + # TODO: make 'databricks.koalas.plot.matplotlib' module to implement + # plot interface.
+ if plot_backend.__name__ != "databricks.koalas.plot.matplotlib": data_preprocessor_map = { "pie": TopNPlotBase().get_top_n, "bar": TopNPlotBase().get_top_n, diff --git a/databricks/koalas/plot/matplotlib.py b/databricks/koalas/plot/matplotlib.py index a2b8b39597..ebaaeee37c 100644 --- a/databricks/koalas/plot/matplotlib.py +++ b/databricks/koalas/plot/matplotlib.py @@ -16,7 +16,7 @@ from distutils.version import LooseVersion -import matplotlib +import matplotlib as mat import numpy as np import pandas as pd from matplotlib.axes._base import _process_plot_format @@ -111,9 +111,7 @@ def update_dict(dictionary, rc_name, properties): if dictionary is None: dictionary = dict() for prop_dict in properties: - dictionary.setdefault( - prop_dict, matplotlib.rcParams[rc_str.format(rc_name, prop_dict)] - ) + dictionary.setdefault(prop_dict, mat.rcParams[rc_str.format(rc_name, prop_dict)]) return dictionary # Common property dictionaries loading from rc @@ -203,7 +201,7 @@ def update_dict(dictionary, rc_name, properties): if manage_ticks is not None: should_manage_ticks = manage_ticks - if LooseVersion(matplotlib.__version__) < LooseVersion("3.1.0"): + if LooseVersion(mat.__version__) < LooseVersion("3.1.0"): extra_args = {"manage_xticks": should_manage_ticks} else: extra_args = {"manage_ticks": should_manage_ticks} @@ -333,26 +331,26 @@ def rc_defaults( ): # Missing arguments default to rcParams. 
if whis is None: - whis = matplotlib.rcParams["boxplot.whiskers"] + whis = mat.rcParams["boxplot.whiskers"] if bootstrap is None: - bootstrap = matplotlib.rcParams["boxplot.bootstrap"] + bootstrap = mat.rcParams["boxplot.bootstrap"] if notch is None: - notch = matplotlib.rcParams["boxplot.notch"] + notch = mat.rcParams["boxplot.notch"] if vert is None: - vert = matplotlib.rcParams["boxplot.vertical"] + vert = mat.rcParams["boxplot.vertical"] if patch_artist is None: - patch_artist = matplotlib.rcParams["boxplot.patchartist"] + patch_artist = mat.rcParams["boxplot.patchartist"] if meanline is None: - meanline = matplotlib.rcParams["boxplot.meanline"] + meanline = mat.rcParams["boxplot.meanline"] if showmeans is None: - showmeans = matplotlib.rcParams["boxplot.showmeans"] + showmeans = mat.rcParams["boxplot.showmeans"] if showcaps is None: - showcaps = matplotlib.rcParams["boxplot.showcaps"] + showcaps = mat.rcParams["boxplot.showcaps"] if showbox is None: - showbox = matplotlib.rcParams["boxplot.showbox"] + showbox = mat.rcParams["boxplot.showbox"] if showfliers is None: - showfliers = matplotlib.rcParams["boxplot.showfliers"] + showfliers = mat.rcParams["boxplot.showfliers"] return dict( whis=whis, diff --git a/databricks/koalas/plot/plotly.py b/databricks/koalas/plot/plotly.py index ac14b4dc9b..ad1de26ec8 100644 --- a/databricks/koalas/plot/plotly.py +++ b/databricks/koalas/plot/plotly.py @@ -15,22 +15,21 @@ # import pandas as pd -from databricks.koalas.plot import HistogramPlotBase +from databricks.koalas.plot import HistogramPlotBase, name_like_string -def plot_plotly(origin_plot): - def plot(data, kind, **kwargs): - # Koalas specific plots - if kind == "pie": - return plot_pie(data, **kwargs) - if kind == "hist": - # Note that here data is a Koalas DataFrame or Series unlike other type of plots. - return plot_histogram(data, **kwargs) +def plot(data, kind, **kwargs): + import plotly - # Other plots. 
- return origin_plot(data, kind, **kwargs) + # Koalas specific plots + if kind == "pie": + return plot_pie(data, **kwargs) + if kind == "hist": + # Note that here data is a Koalas DataFrame or Series unlike other type of plots. + return plot_histogram(data, **kwargs) - return plot + # Other plots. + return plotly.plot(data, kind, **kwargs) def plot_pie(data, **kwargs): @@ -40,7 +39,6 @@ def plot_pie(data, **kwargs): pdf = data.to_frame() return express.pie(pdf, values=pdf.columns[0], names=pdf.index, **kwargs) elif isinstance(data, pd.DataFrame): - # DataFrame values = kwargs.pop("y", None) default_names = None if values is not None: @@ -57,37 +55,35 @@ def plot_pie(data, **kwargs): def plot_histogram(data, **kwargs): - from plotly import express - from plotly.subplots import make_subplots import plotly.graph_objs as go - from databricks import koalas as ks bins = kwargs.get("bins", 10) - is_single_column = isinstance(data, ks.Series) kdf, bins = HistogramPlotBase.prepare_hist_data(data, bins) output_series = HistogramPlotBase.compute_hist(kdf, bins) + prev = float("%.9f" % bins[0]) # to make it prettier, truncate. 
+ text_bins = [] + for b in bins[1:]: + norm_b = float("%.9f" % b) + text_bins.append("(%s, %s)" % (prev, norm_b)) + prev = norm_b bins = 0.5 * (bins[:-1] + bins[1:]) - if is_single_column: - output_series = list(output_series) - assert len(output_series) == 1 - output_series = output_series[0] - return express.bar( - x=bins, y=output_series, labels={"x": str(output_series.name), "y": "count"} - ) - else: - output_series = list(output_series) - fig = make_subplots(rows=1, cols=len(output_series)) - for i, series in enumerate(output_series): - fig.add_trace(go.Bar(x=bins, y=series, name=series.name,), row=1, col=i + 1) + output_series = list(output_series) + bars = [] + for series in output_series: + bars.append( + go.Bar( + x=bins, + y=series, + name=name_like_string(series.name), + text=text_bins, + hovertemplate=( + "variable=" + name_like_string(series.name) + "<br>value=%{text}<br>count=%{y}" + ), + ) + ) - for i, series in enumerate(output_series): - if i == 0: - xaxis = "xaxis" - yaxis = "yaxis" - else: - xaxis = "xaxis%s" % (i + 1) - yaxis = "yaxis%s" % (i + 1) - fig["layout"][xaxis]["title"] = str(series.name) - fig["layout"][yaxis]["title"] = "count" - return fig + fig = go.Figure(data=bars, layout=go.Layout(barmode="stack")) + fig["layout"]["xaxis"]["title"] = "value" + fig["layout"]["yaxis"]["title"] = "count" + return fig diff --git a/databricks/koalas/tests/plot/test_frame_plot_plotly.py b/databricks/koalas/tests/plot/test_frame_plot_plotly.py index 8c7fa68941..1e0880a818 100644 --- a/databricks/koalas/tests/plot/test_frame_plot_plotly.py +++ b/databricks/koalas/tests/plot/test_frame_plot_plotly.py @@ -20,7 +20,6 @@ import pandas as pd import numpy as np from plotly import express -from plotly.subplots import make_subplots import plotly.graph_objs as go from databricks import koalas as ks @@ -182,23 +181,37 @@ def check_pie_plot(kdf): def test_hist_plot(self): def check_hist_plot(kdf): bins = np.array([1.0, 5.9, 10.8, 15.7, 20.6, 25.5, 30.4, 35.3, 40.2, 45.1, 50.0]) - bins = 0.5 * (bins[:-1] + bins[1:]) data = [ np.array([5.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,]), np.array([4.0, 3.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0,]), ] - - fig = make_subplots(rows=1, cols=len(data)) - fig.add_trace( - go.Bar(x=bins, y=data[0], name=name_like_string(kdf.columns[0]),), row=1, col=1 - ) - fig.add_trace( - go.Bar(x=bins, y=data[1], name=name_like_string(kdf.columns[1]),), row=1, col=2 - ) - fig["layout"]["xaxis"]["title"] = name_like_string(kdf.columns[0]) + prev = bins[0] + text_bins = [] + for b in bins[1:]: + text_bins.append("(%s, %s)" % (prev, b)) + prev = b + bins = 0.5 * (bins[:-1] + bins[1:]) + name_a = name_like_string(kdf.columns[0]) + name_b = name_like_string(kdf.columns[1]) + bars = [ + go.Bar( + x=bins, + y=data[0], + name=name_a, + text=text_bins, + hovertemplate=("variable=" + name_a + "<br>value=%{text}<br>count=%{y}"), + ), + go.Bar( + x=bins, + y=data[1], + name=name_b, + text=text_bins, + hovertemplate=("variable=" + name_b + "<br>value=%{text}<br>count=%{y}"), + ), + ] + fig = go.Figure(data=bars, layout=go.Layout(barmode="stack")) + fig["layout"]["xaxis"]["title"] = "value" fig["layout"]["yaxis"]["title"] = "count" - fig["layout"]["xaxis2"]["title"] = name_like_string(kdf.columns[1]) - fig["layout"]["yaxis2"]["title"] = "count" self.assertEqual( pprint.pformat(kdf.plot(kind="hist").to_dict()), pprint.pformat(fig.to_dict()) diff --git a/databricks/koalas/tests/plot/test_series_plot_plotly.py b/databricks/koalas/tests/plot/test_series_plot_plotly.py index 834d04e02a..16750ccd64 100644 --- a/databricks/koalas/tests/plot/test_series_plot_plotly.py +++ b/databricks/koalas/tests/plot/test_series_plot_plotly.py @@ -20,6 +20,7 @@ import pandas as pd import numpy as np from plotly import express +import plotly.graph_objs as go from databricks import koalas as ks from databricks.koalas.config import set_option, reset_option @@ -135,13 +136,26 @@ def test_pie_plot(self): def test_hist_plot(self): def check_hist_plot(kser): bins = np.array([1.0, 5.9, 10.8, 15.7, 20.6, 25.5, 30.4, 35.3, 40.2, 45.1, 50.0]) + data = np.array([5.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,]) + prev = bins[0] + text_bins = [] + for b in bins[1:]: + text_bins.append("(%s, %s)" % (prev, b)) + prev = b bins = 0.5 * (bins[:-1] + bins[1:]) - data = pd.Series( - np.array([5.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,]), name=kser.name - ) - fig = express.bar( - x=bins, y=data, labels={"x": name_like_string(data.name), "y": "count"} - ) + name_a = name_like_string(kser.name) + bars = [ + go.Bar( + x=bins, + y=data, + name=name_a, + text=text_bins, + hovertemplate=("variable=" + name_a + "<br>value=%{text}<br>count=%{y}"), + ), + ] + fig = go.Figure(data=bars, layout=go.Layout(barmode="stack")) + fig["layout"]["xaxis"]["title"] = "value" + fig["layout"]["yaxis"]["title"] = "count" self.assertEqual( pprint.pformat(kser.plot(kind="hist").to_dict()), pprint.pformat(fig.to_dict())