Skip to content

Commit 6ef5884

Browse files
committed
feat: Mean() works with TimeDelta(), #761
1 parent 469ee27 commit 6ef5884

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
-----
33

44
* feat: Lowercase the ``null_values`` provided to individual data types, since all comparisons to ``null_values`` are case-insensitive. (#770)
5+
* feat: :class:`.Mean` works with :class:`.TimeDelta`. (#761)
56
* fix: Allow consecutive calls to :meth:`.Table.group_by`. (#765)
67

78
1.7.1 - Jan 4, 2023

agate/aggregations/mean.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from agate.aggregations.base import Aggregation
22
from agate.aggregations.has_nulls import HasNulls
33
from agate.aggregations.sum import Sum
4-
from agate.data_types import Number
4+
from agate.data_types import Number, TimeDelta
55
from agate.exceptions import DataTypeError
66
from agate.warns import warn_null_calculation
77

@@ -18,13 +18,16 @@ def __init__(self, column_name):
1818
self._sum = Sum(column_name)
1919

2020
def get_aggregate_data_type(self, table):
21-
return Number()
21+
column = table.columns[self._column_name]
22+
23+
if isinstance(column.data_type, (Number, TimeDelta)):
24+
return column.data_type
2225

2326
def validate(self, table):
2427
column = table.columns[self._column_name]
2528

26-
if not isinstance(column.data_type, Number):
27-
raise DataTypeError('Mean can only be applied to columns containing Number data.')
29+
if not isinstance(column.data_type, (Number, TimeDelta)):
30+
raise DataTypeError('Mean can only be applied to columns containing Number or TimeDelta data.')
2831

2932
has_nulls = HasNulls(self._column_name).run(table)
3033

tests/test_aggregations.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,13 @@ def setUp(self):
184184
self.table = Table(self.rows, ['test', 'null'], [DateTime(), DateTime()])
185185

186186
self.time_delta_rows = [
187-
[datetime.timedelta(seconds=10), None],
188-
[datetime.timedelta(seconds=20), None],
187+
[datetime.timedelta(seconds=10), datetime.timedelta(seconds=15), None],
188+
[datetime.timedelta(seconds=20), None, None],
189189
]
190190

191-
self.time_delta_table = Table(self.time_delta_rows, ['test', 'null'], [TimeDelta(), TimeDelta()])
191+
self.time_delta_table = Table(
192+
self.time_delta_rows, ['test', 'mixed', 'null'], [TimeDelta(), TimeDelta(), TimeDelta()]
193+
)
192194

193195
def test_min(self):
194196
self.assertIsInstance(Min('test').get_aggregate_data_type(self.table), DateTime)
@@ -216,6 +218,27 @@ def test_max_time_delta(self):
216218
Max('test').validate(self.time_delta_table)
217219
self.assertEqual(Max('test').run(self.time_delta_table), datetime.timedelta(0, 20))
218220

221+
def test_mean(self):
222+
with self.assertWarns(NullCalculationWarning):
223+
Mean('mixed').validate(self.time_delta_table)
224+
225+
Mean('test').validate(self.time_delta_table)
226+
227+
self.assertEqual(Mean('test').run(self.time_delta_table), datetime.timedelta(seconds=15))
228+
229+
def test_mean_all_nulls(self):
230+
self.assertIsNone(Mean('null').run(self.time_delta_table))
231+
232+
def test_mean_with_nulls(self):
233+
warnings.simplefilter('ignore')
234+
235+
try:
236+
Mean('mixed').validate(self.time_delta_table)
237+
finally:
238+
warnings.resetwarnings()
239+
240+
self.assertAlmostEqual(Mean('mixed').run(self.time_delta_table), datetime.timedelta(seconds=15))
241+
219242
def test_sum(self):
220243
self.assertIsInstance(Sum('test').get_aggregate_data_type(self.time_delta_table), TimeDelta)
221244
Sum('test').validate(self.time_delta_table)

0 commit comments

Comments
 (0)