Skip to content

Commit

Permalink
ENH: Add scatter plot for Frame (#719)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesdong1991 authored and HyukjinKwon committed Aug 30, 2019
1 parent 5537d71 commit a1efa61
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 13 deletions.
74 changes: 62 additions & 12 deletions databricks/koalas/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def _get_standard_kind(kind):

if LooseVersion(pd.__version__) < LooseVersion('0.25'):
from pandas.plotting._core import _all_kinds, BarPlot, BoxPlot, HistPlot, MPLPlot, PiePlot, \
AreaPlot, LinePlot, BarhPlot
AreaPlot, LinePlot, BarhPlot, ScatterPlot
else:
from pandas.plotting._core import PlotAccessor
from pandas.plotting._matplotlib import BarPlot, BoxPlot, HistPlot, PiePlot, AreaPlot, \
LinePlot, BarhPlot
LinePlot, BarhPlot, ScatterPlot
from pandas.plotting._matplotlib.core import MPLPlot
_all_kinds = PlotAccessor._all_kinds

Expand Down Expand Up @@ -509,6 +509,16 @@ def _make_plot(self):
super(KoalasBarhPlot, self)._make_plot()


class KoalasScatterPlot(ScatterPlot, TopNPlot):

def __init__(self, data, x, y, **kwargs):
super().__init__(self.get_top_n(data), x, y, **kwargs)

def _make_plot(self):
self.set_result_text(self._get_ax(0))
super(KoalasScatterPlot, self)._make_plot()


_klasses = [
KoalasHistPlot,
KoalasBarPlot,
Expand All @@ -517,6 +527,7 @@ def _make_plot(self):
KoalasAreaPlot,
KoalasLinePlot,
KoalasBarhPlot,
KoalasScatterPlot,
]
_plot_klass = {getattr(klass, '_kind'): klass for klass in _klasses}

Expand Down Expand Up @@ -651,15 +662,20 @@ def _plot(data, x=None, y=None, subplots=False,
else:
raise ValueError("%r is not a valid plot kind" % kind)

# check data type and do preprocess before applying plot
if isinstance(data, DataFrame):
if x is not None:
data = data.set_index(x)
# TODO: check if value of y is plottable
if y is not None:
data = data[y]
# scatter and hexbin are inherited from PlanePlot which require x and y
if kind in ('scatter', 'hexbin'):
plot_obj = klass(data, x, y, subplots=subplots, ax=ax, kind=kind, **kwds)
else:

# check data type and do preprocess before applying plot
if isinstance(data, DataFrame):
if x is not None:
data = data.set_index(x)
# TODO: check if value of y is plottable
if y is not None:
data = data[y]

plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
plot_obj.generate()
plot_obj.draw()
return plot_obj.result
Expand Down Expand Up @@ -1082,8 +1098,41 @@ def box(self, bw_method=None, ind=None, **kwds):
def hist(self, bw_method=None, ind=None, **kwds):
return _unsupported_function(class_name='pd.DataFrame', method_name='hist')()

def scatter(self, bw_method=None, ind=None, **kwds):
return _unsupported_function(class_name='pd.DataFrame', method_name='scatter')()
def scatter(self, x, y, s=None, c=None, **kwds):
"""
Create a scatter plot with varying marker point size and color.
The coordinates of each point are defined by two dataframe columns and
filled circles are used to represent each point. This kind of plot is
useful to see complex correlations between two variables. Points could
be for instance natural 2D coordinates like longitude and latitude in
a map or, in general, any pair of metrics that can be plotted against
each other.
Parameters
----------
x : int or str
The column name or column position to be used as horizontal
coordinates for each point.
y : int or str
The column name or column position to be used as vertical
coordinates for each point.
s : scalar or array_like, optional
c : str, int or array_like, optional
**kwds: Optional
Keyword arguments to pass on to :meth:`databricks.koalas.DataFrame.plot`.
Returns
-------
:class:`matplotlib.axes.Axes` or numpy.ndarray of them
See Also
--------
matplotlib.pyplot.scatter : Scatter plot using multiple input data
formats.
"""
return self(kind="scatter", x=x, y=y, s=s, c=c, **kwds)


def plot_frame(data, x=None, y=None, kind='line', ax=None,
Expand Down Expand Up @@ -1116,6 +1165,7 @@ def plot_frame(data, x=None, y=None, kind='line', ax=None,
- 'density' : same as 'kde'
- 'area' : area plot
- 'pie' : pie plot
- 'scatter' : scatter plot
ax : matplotlib axes object
If not passed, uses gca()
x : label or position, default None
Expand Down
21 changes: 20 additions & 1 deletion databricks/koalas/tests/test_frame_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

from databricks import koalas
from databricks.koalas.exceptions import PandasNotImplementedError
Expand Down Expand Up @@ -173,10 +174,28 @@ def test_pie_plot_error_message(self):
error_message = "pie requires either y column or 'subplots=True'"
self.assertTrue(error_message in str(context.exception))

def test_scatter_plot(self):
# Use pandas scatter plot example
pdf = pd.DataFrame(np.random.rand(50, 4), columns=['a', 'b', 'c', 'd'])
kdf = koalas.from_pandas(pdf)

ax1 = pdf.plot.scatter(x='a', y='b')
ax2 = kdf.plot.scatter(x='a', y='b')
self.compare_plots(ax1, ax2)

ax1 = pdf.plot(kind='scatter', x='a', y='b')
ax2 = kdf.plot(kind='scatter', x='a', y='b')
self.compare_plots(ax1, ax2)

# check when keyword c is given as name of a column
ax1 = pdf.plot.scatter(x='a', y='b', c='c', s=50)
ax2 = kdf.plot.scatter(x='a', y='b', c='c', s=50)
self.compare_plots(ax1, ax2)

def test_missing(self):
ks = self.kdf1

unsupported_functions = ['box', 'density', 'hexbin', 'hist', 'kde', 'scatter']
unsupported_functions = ['box', 'density', 'hexbin', 'hist', 'kde']

for name in unsupported_functions:
with self.assertRaisesRegex(PandasNotImplementedError,
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/frame.rst
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,4 @@ specific plotting methods of the form ``DataFrame.plot.<kind>``.
DataFrame.plot.barh
DataFrame.plot.bar
DataFrame.plot.pie
DataFrame.plot.scatter

0 comments on commit a1efa61

Please sign in to comment.