Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add scatter plot for Frame #719

Merged
merged 9 commits into from
Aug 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 62 additions & 12 deletions databricks/koalas/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def _get_standard_kind(kind):

if LooseVersion(pd.__version__) < LooseVersion('0.25'):
from pandas.plotting._core import _all_kinds, BarPlot, BoxPlot, HistPlot, MPLPlot, PiePlot, \
AreaPlot, LinePlot, BarhPlot
AreaPlot, LinePlot, BarhPlot, ScatterPlot
else:
from pandas.plotting._core import PlotAccessor
from pandas.plotting._matplotlib import BarPlot, BoxPlot, HistPlot, PiePlot, AreaPlot, \
LinePlot, BarhPlot
LinePlot, BarhPlot, ScatterPlot
from pandas.plotting._matplotlib.core import MPLPlot
_all_kinds = PlotAccessor._all_kinds

Expand Down Expand Up @@ -509,6 +509,16 @@ def _make_plot(self):
super(KoalasBarhPlot, self)._make_plot()


class KoalasScatterPlot(ScatterPlot, TopNPlot):
    """Scatter plot combining the ``ScatterPlot`` base with ``TopNPlot``.

    The input data is passed through ``get_top_n`` before being handed to
    ``ScatterPlot``, and ``set_result_text`` is invoked on the first axis
    right before the points are drawn.
    """

    def __init__(self, data, x, y, **kwargs):
        """Build the plot from the ``get_top_n`` view of ``data``.

        ``x`` and ``y`` are forwarded positionally to ``ScatterPlot``;
        any remaining keyword arguments are forwarded untouched.
        """
        sampled = self.get_top_n(data)
        super().__init__(sampled, x, y, **kwargs)

    def _make_plot(self):
        """Annotate the first axis, then delegate the actual drawing."""
        first_ax = self._get_ax(0)
        self.set_result_text(first_ax)
        super()._make_plot()


_klasses = [
KoalasHistPlot,
KoalasBarPlot,
Expand All @@ -517,6 +527,7 @@ def _make_plot(self):
KoalasAreaPlot,
KoalasLinePlot,
KoalasBarhPlot,
KoalasScatterPlot,
]
_plot_klass = {getattr(klass, '_kind'): klass for klass in _klasses}

Expand Down Expand Up @@ -651,15 +662,20 @@ def _plot(data, x=None, y=None, subplots=False,
else:
raise ValueError("%r is not a valid plot kind" % kind)

# check data type and do preprocess before applying plot
if isinstance(data, DataFrame):
if x is not None:
data = data.set_index(x)
# TODO: check if value of y is plottable
if y is not None:
data = data[y]
# scatter and hexbin are inherited from PlanePlot which require x and y
if kind in ('scatter', 'hexbin'):
plot_obj = klass(data, x, y, subplots=subplots, ax=ax, kind=kind, **kwds)
else:

# check data type and do preprocess before applying plot
if isinstance(data, DataFrame):
if x is not None:
data = data.set_index(x)
# TODO: check if value of y is plottable
if y is not None:
data = data[y]

plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
plot_obj.generate()
plot_obj.draw()
return plot_obj.result
Expand Down Expand Up @@ -1082,8 +1098,41 @@ def box(self, bw_method=None, ind=None, **kwds):
def hist(self, bw_method=None, ind=None, **kwds):
    # Not implemented for frames: build the "unsupported" callable for
    # pd.DataFrame.hist and invoke it immediately. The arguments are
    # accepted only to keep the signature and are otherwise ignored.
    not_supported = _unsupported_function(class_name='pd.DataFrame', method_name='hist')
    return not_supported()

def scatter(self, bw_method=None, ind=None, **kwds):
return _unsupported_function(class_name='pd.DataFrame', method_name='scatter')()
def scatter(self, x, y, s=None, c=None, **kwds):
    """
    Create a scatter plot with varying marker point size and color.

    The coordinates of each point are defined by two dataframe columns and
    filled circles are used to represent each point. This kind of plot is
    useful to see complex correlations between two variables. Points could
    be for instance natural 2D coordinates like longitude and latitude in
    a map or, in general, any pair of metrics that can be plotted against
    each other.

    Parameters
    ----------
    x : int or str
        The column name or column position to be used as horizontal
        coordinates for each point.
    y : int or str
        The column name or column position to be used as vertical
        coordinates for each point.
    s : scalar or array_like, optional
        The size of each point, e.g. a single scalar so that all points
        have the same size. Forwarded unchanged to the plot call.
    c : str, int or array_like, optional
        The color of each point, e.g. a single color string or the name
        of a column whose values are used to color the markers.
        Forwarded unchanged to the plot call.
    **kwds : optional
        Keyword arguments to pass on to
        :meth:`databricks.koalas.DataFrame.plot`.

    Returns
    -------
    :class:`matplotlib.axes.Axes` or numpy.ndarray of them

    See Also
    --------
    matplotlib.pyplot.scatter : Scatter plot using multiple input data
        formats.
    """
    # The accessor itself is callable; scatter is a thin alias that pins
    # ``kind`` and forwards everything else as keyword arguments.
    return self(kind="scatter", x=x, y=y, s=s, c=c, **kwds)


def plot_frame(data, x=None, y=None, kind='line', ax=None,
Expand Down Expand Up @@ -1116,6 +1165,7 @@ def plot_frame(data, x=None, y=None, kind='line', ax=None,
- 'density' : same as 'kde'
- 'area' : area plot
- 'pie' : pie plot
- 'scatter' : scatter plot
ax : matplotlib axes object
If not passed, uses gca()
x : label or position, default None
Expand Down
21 changes: 20 additions & 1 deletion databricks/koalas/tests/test_frame_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

from databricks import koalas
from databricks.koalas.exceptions import PandasNotImplementedError
Expand Down Expand Up @@ -173,10 +174,28 @@ def test_pie_plot_error_message(self):
error_message = "pie requires either y column or 'subplots=True'"
self.assertTrue(error_message in str(context.exception))

def test_scatter_plot(self):
    """Scatter plots from Koalas must match the ones pandas draws."""
    # Same random frame on both sides, modelled on the pandas scatter example.
    pandas_frame = pd.DataFrame(np.random.rand(50, 4), columns=['a', 'b', 'c', 'd'])
    koalas_frame = koalas.from_pandas(pandas_frame)

    # Accessor form: ``.plot.scatter(...)``.
    expected_ax = pandas_frame.plot.scatter(x='a', y='b')
    actual_ax = koalas_frame.plot.scatter(x='a', y='b')
    self.compare_plots(expected_ax, actual_ax)

    # Generic form: ``.plot(kind='scatter', ...)``.
    expected_ax = pandas_frame.plot(kind='scatter', x='a', y='b')
    actual_ax = koalas_frame.plot(kind='scatter', x='a', y='b')
    self.compare_plots(expected_ax, actual_ax)

    # check when keyword c is given as name of a column
    expected_ax = pandas_frame.plot.scatter(x='a', y='b', c='c', s=50)
    actual_ax = koalas_frame.plot.scatter(x='a', y='b', c='c', s=50)
    self.compare_plots(expected_ax, actual_ax)

def test_missing(self):
ks = self.kdf1

unsupported_functions = ['box', 'density', 'hexbin', 'hist', 'kde', 'scatter']
unsupported_functions = ['box', 'density', 'hexbin', 'hist', 'kde']

for name in unsupported_functions:
with self.assertRaisesRegex(PandasNotImplementedError,
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/frame.rst
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,4 @@ specific plotting methods of the form ``DataFrame.plot.<kind>``.
DataFrame.plot.barh
DataFrame.plot.bar
DataFrame.plot.pie
DataFrame.plot.scatter