Skip to content

Commit

Permalink
Implement (DataFrame|Series).plot.hist in plotly
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Jan 12, 2021
1 parent ea6ad98 commit 0eef796
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 4 deletions.
15 changes: 11 additions & 4 deletions databricks/koalas/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import pandas as pd
import numpy as np

from databricks.koalas.utils import name_like_string
from pyspark.ml.feature import Bucketizer
from pyspark.sql import functions as F
from pandas.core.base import PandasObject
Expand Down Expand Up @@ -143,8 +145,11 @@ def compute_hist(kdf, bins):

sdf = kdf._internal.spark_frame
scols = []
input_column_names = []
for label in kdf._internal.column_labels:
scols.append(kdf._internal.spark_column_for(label))
input_column_name = name_like_string(label)
input_column_names.append(input_column_name)
scols.append(kdf._internal.spark_column_for(label).alias(input_column_name))
sdf = sdf.select(*scols)

# 1. Make the bucket output flat to:
Expand Down Expand Up @@ -243,7 +248,7 @@ def compute_hist(kdf, bins):
# |0 |
# +-----------------+
output_series = []
for i, bucket_name in enumerate(bucket_names):
for i, (input_column_name, bucket_name) in enumerate(zip(input_column_names, bucket_names)):
current_bucket_result = result[result["__group_id"] == i]
# generates a pandas DF with one row for each bin
# we need this as some of the bins may be empty
Expand All @@ -252,8 +257,8 @@ def compute_hist(kdf, bins):
pdf = indexes.merge(current_bucket_result, how="left", on=["__bucket"]).fillna(0)[
["count"]
]
pdf.columns = [bucket_name]
output_series.append(pdf[bucket_name])
pdf.columns = [input_column_name]
output_series.append(pdf[input_column_name])

return output_series

Expand Down Expand Up @@ -363,6 +368,8 @@ def __call__(self, kind="line", backend=None, **kwargs):
"scatter": TopNPlotBase().get_top_n,
"area": SampledPlotBase().get_sampled,
"line": SampledPlotBase().get_sampled,
# if histogram is not supported, the backend will throw an exception
"hist": lambda data: data,
}
if not data_preprocessor_map[kind]:
raise NotImplementedError(
Expand Down
52 changes: 52 additions & 0 deletions databricks/koalas/plot/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
#
import pandas as pd

from databricks.koalas.plot import HistogramPlotBase


def plot_plotly(origin_plot):
def plot(data, kind, **kwargs):
# Koalas specific plots
if kind == "pie":
return plot_pie(data, **kwargs)
if kind == "hist":
# Note that here data is a Koalas DataFrame or Series, unlike other types of plots.
return plot_histogram(data, **kwargs)

# Other plots.
return origin_plot(data, kind, **kwargs)
Expand Down Expand Up @@ -49,3 +54,50 @@ def plot_pie(data, **kwargs):
)
else:
raise RuntimeError("Unexpected type: [%s]" % type(data))


def plot_histogram(data, **kwargs):
    """Render a histogram of a Koalas Series or DataFrame with plotly bars.

    The bucket counts come from ``HistogramPlotBase.compute_hist``; each
    bar is drawn at the midpoint of two consecutive bin edges.

    Parameters
    ----------
    data : ks.Series or ks.DataFrame
        Input to plot. It is first passed through
        ``HistogramPlotBase.prepare_hist_data`` together with ``bins``.
    **kwargs
        Must contain ``"bins"``; no other key is read by this function.

    Returns
    -------
    A ``plotly.express.bar`` figure when the prepared data is a Series,
    otherwise a one-row subplot figure with one bar trace per column.

    Raises
    ------
    RuntimeError
        If the prepared data is neither a Koalas Series nor a DataFrame.
    """
    from plotly import express
    from plotly.subplots import make_subplots
    import plotly.graph_objs as go
    from databricks import koalas as ks

    assert "bins" in kwargs
    data, edges = HistogramPlotBase.prepare_hist_data(data, kwargs["bins"])

    if isinstance(data, ks.Series):
        single_column = True
        frame = data.to_frame()
    elif isinstance(data, ks.DataFrame):
        single_column = False
        frame = data
    else:
        raise RuntimeError("Unexpected type: [%s]" % type(data))

    counts_per_column = list(HistogramPlotBase.compute_hist(frame, edges))
    # Bars are positioned at the centre of each bucket.
    centers = 0.5 * (edges[:-1] + edges[1:])

    if single_column:
        assert len(counts_per_column) == 1
        counts = counts_per_column[0]
        axis_labels = {"x": str(counts.name), "y": "count"}
        return express.bar(x=centers, y=counts, labels=axis_labels)

    fig = make_subplots(rows=1, cols=len(counts_per_column))
    for position, counts in enumerate(counts_per_column, start=1):
        fig.add_trace(go.Bar(x=centers, y=counts), row=1, col=position)
        # The first subplot's axes carry no numeric suffix in the layout.
        suffix = "" if position == 1 else "%s" % position
        fig["layout"]["xaxis" + suffix]["title"] = str(counts.name)
        fig["layout"]["yaxis" + suffix]["title"] = "count"
    return fig

0 comments on commit 0eef796

Please sign in to comment.