Skip to content

Commit

Permalink
Make a better shape
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Jan 12, 2021
1 parent ea3d20a commit 4d09feb
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 84 deletions.
29 changes: 19 additions & 10 deletions databricks/koalas/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@

import pandas as pd
import numpy as np

from databricks.koalas.utils import name_like_string
from pyspark.ml.feature import Bucketizer
from pyspark.sql import functions as F
from pandas.core.base import PandasObject
from pandas.core.dtypes.inference import is_integer

from databricks.koalas.missing import unsupported_function
from databricks.koalas.config import get_option
from databricks.koalas.utils import name_like_string


class TopNPlotBase:
Expand Down Expand Up @@ -330,11 +329,12 @@ def _get_plot_backend(backend=None):

backend = backend or get_option("plotting.backend")

# TODO: leverage koalas_plotting_backends, and remove the code below.
if backend == "matplotlib":
# Because matplotlib is an optional dependency and first-party backend,
# we need to attempt an import here to raise an ImportError if needed.
try:
import databricks.koalas.plot as module
from databricks.koalas.plot import matplotlib as module
except ImportError:
raise ImportError(
"matplotlib is required for plotting when the "
Expand All @@ -343,24 +343,33 @@ def _get_plot_backend(backend=None):

KoalasPlotAccessor._backends["matplotlib"] = module

if backend == "plotly":
try:
# test if plotly can be imported
import plotly # noqa: F401
from databricks.koalas.plot import plotly as module
except ImportError:
raise ImportError(
"matplotlib is required for plotting when the "
"default backend 'matplotlib' is selected."
) from None

KoalasPlotAccessor._backends["plotly"] = module

if backend in KoalasPlotAccessor._backends:
return KoalasPlotAccessor._backends[backend]

module = KoalasPlotAccessor._find_backend(backend)

if backend == "plotly":
from databricks.koalas.plot.plotly import plot_plotly

module.plot = plot_plotly(module.plot)

KoalasPlotAccessor._backends[backend] = module
return module

def __call__(self, kind="line", backend=None, **kwargs):
plot_backend = KoalasPlotAccessor._get_plot_backend(backend)
plot_data = self.data

if plot_backend.__name__ != "databricks.koalas.plot":
# TODO: make 'databricks.koalas.plot.matplotlib' module to implement
# plot interface.
if plot_backend.__name__ != "databricks.koalas.plot.matplotlib":
data_preprocessor_map = {
"pie": TopNPlotBase().get_top_n,
"bar": TopNPlotBase().get_top_n,
Expand Down
28 changes: 13 additions & 15 deletions databricks/koalas/plot/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from distutils.version import LooseVersion

import matplotlib
import matplotlib as mat
import numpy as np
import pandas as pd
from matplotlib.axes._base import _process_plot_format
Expand Down Expand Up @@ -111,9 +111,7 @@ def update_dict(dictionary, rc_name, properties):
if dictionary is None:
dictionary = dict()
for prop_dict in properties:
dictionary.setdefault(
prop_dict, matplotlib.rcParams[rc_str.format(rc_name, prop_dict)]
)
dictionary.setdefault(prop_dict, mat.rcParams[rc_str.format(rc_name, prop_dict)])
return dictionary

# Common property dictionaries loading from rc
Expand Down Expand Up @@ -203,7 +201,7 @@ def update_dict(dictionary, rc_name, properties):
if manage_ticks is not None:
should_manage_ticks = manage_ticks

if LooseVersion(matplotlib.__version__) < LooseVersion("3.1.0"):
if LooseVersion(mat.__version__) < LooseVersion("3.1.0"):
extra_args = {"manage_xticks": should_manage_ticks}
else:
extra_args = {"manage_ticks": should_manage_ticks}
Expand Down Expand Up @@ -333,26 +331,26 @@ def rc_defaults(
):
# Missing arguments default to rcParams.
if whis is None:
whis = matplotlib.rcParams["boxplot.whiskers"]
whis = mat.rcParams["boxplot.whiskers"]
if bootstrap is None:
bootstrap = matplotlib.rcParams["boxplot.bootstrap"]
bootstrap = mat.rcParams["boxplot.bootstrap"]

if notch is None:
notch = matplotlib.rcParams["boxplot.notch"]
notch = mat.rcParams["boxplot.notch"]
if vert is None:
vert = matplotlib.rcParams["boxplot.vertical"]
vert = mat.rcParams["boxplot.vertical"]
if patch_artist is None:
patch_artist = matplotlib.rcParams["boxplot.patchartist"]
patch_artist = mat.rcParams["boxplot.patchartist"]
if meanline is None:
meanline = matplotlib.rcParams["boxplot.meanline"]
meanline = mat.rcParams["boxplot.meanline"]
if showmeans is None:
showmeans = matplotlib.rcParams["boxplot.showmeans"]
showmeans = mat.rcParams["boxplot.showmeans"]
if showcaps is None:
showcaps = matplotlib.rcParams["boxplot.showcaps"]
showcaps = mat.rcParams["boxplot.showcaps"]
if showbox is None:
showbox = matplotlib.rcParams["boxplot.showbox"]
showbox = mat.rcParams["boxplot.showbox"]
if showfliers is None:
showfliers = matplotlib.rcParams["boxplot.showfliers"]
showfliers = mat.rcParams["boxplot.showfliers"]

return dict(
whis=whis,
Expand Down
77 changes: 38 additions & 39 deletions databricks/koalas/plot/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,21 @@
#
import pandas as pd

from databricks.koalas.plot import HistogramPlotBase
from databricks.koalas.plot import HistogramPlotBase, name_like_string


def plot_plotly(origin_plot):
def plot(data, kind, **kwargs):
# Koalas specific plots
if kind == "pie":
return plot_pie(data, **kwargs)
if kind == "hist":
# Note that here data is a Koalas DataFrame or Series, unlike other types of plots.
return plot_histogram(data, **kwargs)
def plot(data, kind, **kwargs):
import plotly

# Other plots.
return origin_plot(data, kind, **kwargs)
# Koalas specific plots
if kind == "pie":
return plot_pie(data, **kwargs)
if kind == "hist":
# Note that here data is a Koalas DataFrame or Series, unlike other types of plots.
return plot_histogram(data, **kwargs)

return plot
# Other plots.
return plotly.plot(data, kind, **kwargs)


def plot_pie(data, **kwargs):
Expand All @@ -40,7 +39,6 @@ def plot_pie(data, **kwargs):
pdf = data.to_frame()
return express.pie(pdf, values=pdf.columns[0], names=pdf.index, **kwargs)
elif isinstance(data, pd.DataFrame):
# DataFrame
values = kwargs.pop("y", None)
default_names = None
if values is not None:
Expand All @@ -57,37 +55,38 @@ def plot_pie(data, **kwargs):


def plot_histogram(data, **kwargs):
from plotly import express
from plotly.subplots import make_subplots
import plotly.graph_objs as go
from databricks import koalas as ks

bins = kwargs.get("bins", 10)
is_single_column = isinstance(data, ks.Series)
kdf, bins = HistogramPlotBase.prepare_hist_data(data, bins)
assert len(bins) > 2, "the number of buckets must be higher than 2."
output_series = HistogramPlotBase.compute_hist(kdf, bins)
prev = float("%.9f" % bins[0]) # round to 9 decimal places to make the labels prettier.
text_bins = []
for b in bins[1:]:
norm_b = float("%.9f" % b)
text_bins.append("[%s, %s)" % (prev, norm_b))
prev = norm_b
text_bins[-1] = text_bins[-1][:-1] + "]" # replace ) with ] for the last bucket.

bins = 0.5 * (bins[:-1] + bins[1:])
if is_single_column:
output_series = list(output_series)
assert len(output_series) == 1
output_series = output_series[0]
return express.bar(
x=bins, y=output_series, labels={"x": str(output_series.name), "y": "count"}
)
else:
output_series = list(output_series)
fig = make_subplots(rows=1, cols=len(output_series))

for i, series in enumerate(output_series):
fig.add_trace(go.Bar(x=bins, y=series, name=series.name,), row=1, col=i + 1)
output_series = list(output_series)
bars = []
for series in output_series:
bars.append(
go.Bar(
x=bins,
y=series,
name=name_like_string(series.name),
text=text_bins,
hovertemplate=(
"variable=" + name_like_string(series.name) + "<br>value=%{text}<br>count=%{y}"
),
)
)

for i, series in enumerate(output_series):
if i == 0:
xaxis = "xaxis"
yaxis = "yaxis"
else:
xaxis = "xaxis%s" % (i + 1)
yaxis = "yaxis%s" % (i + 1)
fig["layout"][xaxis]["title"] = str(series.name)
fig["layout"][yaxis]["title"] = "count"
return fig
fig = go.Figure(data=bars, layout=go.Layout(barmode="stack"))
fig["layout"]["xaxis"]["title"] = "value"
fig["layout"]["yaxis"]["title"] = "count"
return fig
40 changes: 27 additions & 13 deletions databricks/koalas/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import pandas as pd
import numpy as np
from plotly import express
from plotly.subplots import make_subplots
import plotly.graph_objs as go

from databricks import koalas as ks
Expand Down Expand Up @@ -182,23 +181,38 @@ def check_pie_plot(kdf):
def test_hist_plot(self):
def check_hist_plot(kdf):
bins = np.array([1.0, 5.9, 10.8, 15.7, 20.6, 25.5, 30.4, 35.3, 40.2, 45.1, 50.0])
bins = 0.5 * (bins[:-1] + bins[1:])
data = [
np.array([5.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,]),
np.array([4.0, 3.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0,]),
]

fig = make_subplots(rows=1, cols=len(data))
fig.add_trace(
go.Bar(x=bins, y=data[0], name=name_like_string(kdf.columns[0]),), row=1, col=1
)
fig.add_trace(
go.Bar(x=bins, y=data[1], name=name_like_string(kdf.columns[1]),), row=1, col=2
)
fig["layout"]["xaxis"]["title"] = name_like_string(kdf.columns[0])
prev = bins[0]
text_bins = []
for b in bins[1:]:
text_bins.append("[%s, %s)" % (prev, b))
prev = b
text_bins[-1] = text_bins[-1][:-1] + "]"
bins = 0.5 * (bins[:-1] + bins[1:])
name_a = name_like_string(kdf.columns[0])
name_b = name_like_string(kdf.columns[1])
bars = [
go.Bar(
x=bins,
y=data[0],
name=name_a,
text=text_bins,
hovertemplate=("variable=" + name_a + "<br>value=%{text}<br>count=%{y}"),
),
go.Bar(
x=bins,
y=data[1],
name=name_b,
text=text_bins,
hovertemplate=("variable=" + name_b + "<br>value=%{text}<br>count=%{y}"),
),
]
fig = go.Figure(data=bars, layout=go.Layout(barmode="stack"))
fig["layout"]["xaxis"]["title"] = "value"
fig["layout"]["yaxis"]["title"] = "count"
fig["layout"]["xaxis2"]["title"] = name_like_string(kdf.columns[1])
fig["layout"]["yaxis2"]["title"] = "count"

self.assertEqual(
pprint.pformat(kdf.plot(kind="hist").to_dict()), pprint.pformat(fig.to_dict())
Expand Down
2 changes: 1 addition & 1 deletion databricks/koalas/tests/plot/test_series_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_plot_backends(self):
self.assertEqual(ks.options.plotting.backend, plot_backend)

module = KoalasPlotAccessor._get_plot_backend(plot_backend)
self.assertEqual(module.__name__, plot_backend)
self.assertEqual(module.__name__, "databricks.koalas.plot.plotly")

def test_plot_backends_incorrect(self):
fake_plot_backend = "none_plotting_module"
Expand Down
27 changes: 21 additions & 6 deletions databricks/koalas/tests/plot/test_series_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pandas as pd
import numpy as np
from plotly import express
import plotly.graph_objs as go

from databricks import koalas as ks
from databricks.koalas.config import set_option, reset_option
Expand Down Expand Up @@ -135,13 +136,27 @@ def test_pie_plot(self):
def test_hist_plot(self):
def check_hist_plot(kser):
bins = np.array([1.0, 5.9, 10.8, 15.7, 20.6, 25.5, 30.4, 35.3, 40.2, 45.1, 50.0])
data = np.array([5.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,])
prev = bins[0]
text_bins = []
for b in bins[1:]:
text_bins.append("[%s, %s)" % (prev, b))
prev = b
text_bins[-1] = text_bins[-1][:-1] + "]"
bins = 0.5 * (bins[:-1] + bins[1:])
data = pd.Series(
np.array([5.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,]), name=kser.name
)
fig = express.bar(
x=bins, y=data, labels={"x": name_like_string(data.name), "y": "count"}
)
name_a = name_like_string(kser.name)
bars = [
go.Bar(
x=bins,
y=data,
name=name_a,
text=text_bins,
hovertemplate=("variable=" + name_a + "<br>value=%{text}<br>count=%{y}"),
),
]
fig = go.Figure(data=bars, layout=go.Layout(barmode="stack"))
fig["layout"]["xaxis"]["title"] = "value"
fig["layout"]["yaxis"]["title"] = "count"

self.assertEqual(
pprint.pformat(kser.plot(kind="hist").to_dict()), pprint.pformat(fig.to_dict())
Expand Down

0 comments on commit 4d09feb

Please sign in to comment.