Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Jan 12, 2021
1 parent 71e13d5 commit ea3d20a
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 15 deletions.
16 changes: 3 additions & 13 deletions databricks/koalas/plot/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,9 @@ def plot_histogram(data, **kwargs):
import plotly.graph_objs as go
from databricks import koalas as ks

assert "bins" in kwargs
bins = kwargs["bins"]
data, bins = HistogramPlotBase.prepare_hist_data(data, bins)

is_single_column = False
if isinstance(data, ks.Series):
is_single_column = True
kdf = data.to_frame()
elif isinstance(data, ks.DataFrame):
kdf = data
else:
raise RuntimeError("Unexpected type: [%s]" % type(data))

bins = kwargs.get("bins", 10)
is_single_column = isinstance(data, ks.Series)
kdf, bins = HistogramPlotBase.prepare_hist_data(data, bins)
output_series = HistogramPlotBase.compute_hist(kdf, bins)
bins = 0.5 * (bins[:-1] + bins[1:])
if is_single_column:
Expand Down
4 changes: 2 additions & 2 deletions databricks/koalas/tests/plot/test_frame_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_compute_hist_single_column(self):
expected_histogram = np.array([5, 4, 1, 0, 0, 0, 0, 0, 0, 1])
histogram = HistogramPlotBase.compute_hist(kdf[["a"]], bins)[0]
self.assert_eq(pd.Series(expected_bins), pd.Series(bins))
self.assert_eq(pd.Series(expected_histogram, name="__a_bucket"), histogram, almost=True)
self.assert_eq(pd.Series(expected_histogram, name="a"), histogram, almost=True)

def test_compute_hist_multi_columns(self):
expected_bins = np.linspace(1, 50, 11)
Expand All @@ -100,7 +100,7 @@ def test_compute_hist_multi_columns(self):
np.array([4, 1, 0, 0, 1, 3, 0, 0, 0, 2]),
]
histograms = HistogramPlotBase.compute_hist(kdf, bins)
expected_names = ["__a_bucket", "__b_bucket"]
expected_names = ["a", "b"]

for histogram, expected_histogram, expected_name in zip(
histograms, expected_histograms, expected_names
Expand Down
36 changes: 36 additions & 0 deletions databricks/koalas/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@
#
import unittest
from distutils.version import LooseVersion
import pprint

import pandas as pd
import numpy as np
from plotly import express
from plotly.subplots import make_subplots
import plotly.graph_objs as go

from databricks import koalas as ks
from databricks.koalas.config import set_option, reset_option
from databricks.koalas.testing.utils import ReusedSQLTestCase, TestUtils
from databricks.koalas.utils import name_like_string


@unittest.skipIf(
Expand Down Expand Up @@ -174,3 +178,35 @@ def check_pie_plot(kdf):
# index=pd.MultiIndex.from_tuples([("x", "y")] * 11),
# )
# check_pie_plot(kdf1)

def test_hist_plot(self):
    """Check that a DataFrame ``hist`` plot matches a hand-built plotly figure.

    The expected figure has one row with one bar subplot per column; each
    subplot's x-axis is titled with the column name and its y-axis with
    "count". Figures are compared through the pretty-printed form of their
    ``to_dict()`` output.
    """

    def check_hist_plot(kdf):
        # Bars are placed at the midpoint of each pair of adjacent edges.
        edges = np.array([1.0, 5.9, 10.8, 15.7, 20.6, 25.5, 30.4, 35.3, 40.2, 45.1, 50.0])
        centers = 0.5 * (edges[:-1] + edges[1:])
        counts = [
            np.array([5.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]),
            np.array([4.0, 3.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0]),
        ]
        labels = [name_like_string(kdf.columns[i]) for i in range(len(counts))]

        expected = make_subplots(rows=1, cols=len(counts))
        # Add every trace first, then the axis titles, in left-to-right order.
        for position, (label, heights) in enumerate(zip(labels, counts), start=1):
            expected.add_trace(go.Bar(x=centers, y=heights, name=label), row=1, col=position)
        # plotly names the first pair of axes "xaxis"/"yaxis" and the second
        # pair "xaxis2"/"yaxis2".
        for suffix, label in zip(("", "2"), labels):
            expected["layout"]["xaxis" + suffix]["title"] = label
            expected["layout"]["yaxis" + suffix]["title"] = "count"

        actual = kdf.plot(kind="hist")
        self.assertEqual(
            pprint.pformat(actual.to_dict()), pprint.pformat(expected.to_dict())
        )

    frame = self.kdf1
    check_hist_plot(frame)

    # Repeat the check with multi-level column labels.
    frame.columns = pd.MultiIndex.from_tuples([("x", "y"), ("y", "z")])
    check_hist_plot(frame)
25 changes: 25 additions & 0 deletions databricks/koalas/tests/plot/test_series_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
#
import unittest
from distutils.version import LooseVersion
import pprint

import pandas as pd
import numpy as np
from plotly import express

from databricks import koalas as ks
from databricks.koalas.config import set_option, reset_option
from databricks.koalas.testing.utils import ReusedSQLTestCase, TestUtils
from databricks.koalas.utils import name_like_string


@unittest.skipIf(
Expand Down Expand Up @@ -128,3 +131,25 @@ def test_pie_plot(self):
# self.assertEqual(
# kdf["a"].plot(kind="pie"), express.pie(pdf, values=pdf.columns[0], names=pdf.index),
# )

def test_hist_plot(self):
    """Check that a Series ``hist`` plot matches the equivalent ``express.bar`` figure.

    The expected figure is a bar chart whose x-axis is labelled with the
    series name and whose y-axis is labelled "count". Figures are compared
    through the pretty-printed form of their ``to_dict()`` output.
    """

    def check_hist_plot(kser):
        # Bars are placed at the midpoint of each pair of adjacent edges.
        edges = np.array([1.0, 5.9, 10.8, 15.7, 20.6, 25.5, 30.4, 35.3, 40.2, 45.1, 50.0])
        centers = 0.5 * (edges[:-1] + edges[1:])
        counts = pd.Series(
            np.array([5.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]),
            name=kser.name,
        )
        axis_labels = {"x": name_like_string(counts.name), "y": "count"}
        expected = express.bar(x=centers, y=counts, labels=axis_labels)

        actual = kser.plot(kind="hist")
        self.assertEqual(
            pprint.pformat(actual.to_dict()), pprint.pformat(expected.to_dict())
        )

    frame = self.kdf1
    check_hist_plot(frame["a"])

    # Repeat the check with a multi-level (tuple) column label.
    frame.columns = pd.MultiIndex.from_tuples([("x", "y")])
    check_hist_plot(frame[("x", "y")])

0 comments on commit ea3d20a

Please sign in to comment.