Skip to content

Commit

Permalink
Fix up color handling and allow style functions. wireservice#26
Browse files Browse the repository at this point in the history
  • Loading branch information
nbedi committed Jun 13, 2016
1 parent 933386f commit 8e825f1
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
20 changes: 16 additions & 4 deletions leather/shapes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from leather import theme
from leather.shapes.base import Shape
from leather.utils import issequence


class CategoryShape(Shape):
Expand All @@ -26,11 +27,22 @@ def legend_labels(self, series, palette):
seen = set()
legend_values = [v for v in series.values(self._legend_dimension) if v not in seen and not seen.add(v)]

colors = list(palette)
color_count = len(colors)
if issequence(palette):
colors = list(palette)
color_count = len(colors)

for i, value in enumerate(legend_values):
label_colors.append((value, colors[i % color_count]))
for i, value in enumerate(legend_values):
if i >= color_count:
raise ValueError('Fill color must have length greater than or equal to the number of unique values in all categories.')

label_colors.append((value, colors[i]))

elif callable(palette):
# TODO
label_colors = []

else:
raise ValueError('Fill color must be a sequence of strings or a style function.')

return label_colors

Expand Down
21 changes: 9 additions & 12 deletions leather/shapes/grouped_bars.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from leather.series import CategorySeries
from leather.shapes.category import CategoryShape
from leather.utils import issequence, Y, Z
from leather.utils import Y, Z
from leather import theme


Expand All @@ -16,22 +16,17 @@ class GroupedBars(CategoryShape):
:param fill_color:
A sequence of colors to fill the bars. The sequence must have length
greater than or equal to the number of values in any category.
greater than or equal to the number of unique values in all categories.
You may also specify a :func:`.style_function`.
"""
def __init__(self, fill_color=None):
if fill_color and not issequence(fill_color):
raise ValueError('Fill color must be a sequence of strings.')

self._fill_color = fill_color
self._legend_dimension = Y

def validate_series(self, series):
"""
Verify this shape can be used to render a given series.
"""
if len(series.categories()) > len(self._fill_color):
raise ValueError('Fill color must have an element for every category in the series.')

if isinstance(series, CategorySeries):
raise ValueError('GroupedBars can only be used to render CategorySeries.')

Expand All @@ -49,9 +44,6 @@ def to_svg(self, width, height, x_scale, y_scale, series, palette):
else:
fill_color = list(palette)

if len(series.categories()) > len(fill_color):
raise ValueError('Fill color must have an element for every category in the series.')

label_colors = self.legend_labels(series, fill_color)

categories = series.categories()
Expand All @@ -78,7 +70,12 @@ def to_svg(self, width, height, x_scale, y_scale, series, palette):
bar_x = zero_x
bar_width = proj_x - zero_x

color = dict(label_colors)[d.y]
if callable(fill_color):
color = fill_color(d)
print(color)
else:
color = dict(label_colors)[d.y]

seen_counts[d.z] += 1

group.append(ET.Element('rect',
Expand Down
22 changes: 21 additions & 1 deletion tests/test_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def setUp(self):
self.shape = leather.GroupedBars()
self.linear = leather.Linear(0, 10)
self.ordinal = leather.Ordinal(['first', 'second', 'third'])
self.palette = (color for color in ['red', 'white', 'blue'])
self.palette = (color for color in ['red', 'white', 'blue', 'yellow'])
self.rows = [
(1, 'foo', 'first'),
(5, 'bar', 'first'),
Expand Down Expand Up @@ -203,6 +203,26 @@ def test_invalid_fill_color(self):

with self.assertRaises(ValueError):
shape = leather.GroupedBars('red')
shape.to_svg(100, 100, self.linear, self.ordinal, series, self.palette)

def test_style_function(self):
def color_code(d):
if d.y == 'foo':
return 'green'
else:
return 'black'

shape = leather.GroupedBars(color_code)
series = leather.CategorySeries(self.rows)

group = shape.to_svg(200, 100, self.linear, self.ordinal, series, self.palette)
rects = list(group)

self.assertEqual(rects[0].get('fill'), 'green')
self.assertEqual(rects[1].get('fill'), 'black')
self.assertEqual(rects[2].get('fill'), 'green')
self.assertEqual(rects[3].get('fill'), 'black')
self.assertEqual(rects[4].get('fill'), 'green')

def test_nulls(self):
series = leather.CategorySeries([
Expand Down

0 comments on commit 8e825f1

Please sign in to comment.