Skip to content

Commit

Permalink
Fix fill color checks. wireservice#26
Browse files Browse the repository at this point in the history
  • Loading branch information
nbedi committed Jun 10, 2016
1 parent c423a09 commit d268298
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
14 changes: 10 additions & 4 deletions leather/shapes/grouped_bars.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class GroupedBars(CategoryShape):
Render a categorized series of data as grouped bars.
:param fill_color:
A sequence of colors to fill the bars. If the sequence is shorter than
the number of values in any category, the colors will be repeated.
A sequence of colors to fill the bars. The sequence must have length
greater than or equal to the number of values in any category.
"""
def __init__(self, fill_color=None):
self._fill_color = fill_color
Expand All @@ -26,8 +26,11 @@ def validate_series(self, series):
"""
Verify this shape can be used to render a given series.
"""
if issequence(self._fill_color) and len(series.categories()) > len(self._fill_color.keys()):
raise ValueError('fill_color must have an element for every category in the series.')
if len(series.categories()) > len(self._fill_color):
raise ValueError('Fill color must have an element for every category in the series.')

if isinstance(series, CategorySeries):
raise ValueError('GroupedBars can only be used to render CategorySeries.')

def to_svg(self, width, height, x_scale, y_scale, series, palette):
"""
Expand All @@ -43,6 +46,9 @@ def to_svg(self, width, height, x_scale, y_scale, series, palette):
else:
fill_color = list(palette)

if len(series.categories()) > len(fill_color):
raise ValueError('Fill color must have an element for every category in the series.')

label_colors = self.legend_labels(series, fill_color)

categories = series.categories()
Expand Down
6 changes: 6 additions & 0 deletions tests/test_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ def test_to_svg(self):
self.assertEqual(float(rects[1].get('width')), 100)
self.assertEqual(rects[1].get('fill'), 'white')

def test_invalid_fill_color(self):
series = leather.CategorySeries(self.rows)

with self.assertRaises(ValueError):
group = self.shape.to_svg(200, 100, self.linear, self.ordinal, series, ['one', 'two'])

def test_nulls(self):
series = leather.CategorySeries([
(0, 'foo', 'first'),
Expand Down

0 comments on commit d268298

Please sign in to comment.