Skip to content

Commit e908546

Browse files
authored
Unknown values handling type added to the lib interface (#21)
1 parent aa62e30 commit e908546

13 files changed

+170
-11
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Changelog
22

3+
## 5.2.0
4+
- `UnknownValueHandlingType` enum added to the lib's public API
5+
36
## 5.1.2
47
- `ml_dataframe` 0.2.0 supported
58

autotest.sh

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/bin/bash
2+
3+
pub run build_runner test -- -p vm

lib/ml_preprocessing.dart

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ export 'package:ml_linalg/norm.dart';
22
export 'package:ml_preprocessing/src/encoder/encode_as_integer_labels.dart';
33
export 'package:ml_preprocessing/src/encoder/encode_as_one_hot_labels.dart';
44
export 'package:ml_preprocessing/src/encoder/encoder.dart';
5+
export 'package:ml_preprocessing/src/encoder/unknown_value_handling_type.dart';
56
export 'package:ml_preprocessing/src/normalizer/normalize.dart';
67
export 'package:ml_preprocessing/src/normalizer/normalizer.dart';
78
export 'package:ml_preprocessing/src/pipeline/pipeline.dart';
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import 'package:ml_preprocessing/src/encoder/encoder_impl.dart';
22
import 'package:ml_preprocessing/src/encoder/encoder_type.dart';
33
import 'package:ml_preprocessing/src/encoder/series_encoder/series_encoder_factory_impl.dart';
4+
import 'package:ml_preprocessing/src/encoder/unknown_value_handling_type.dart';
45
import 'package:ml_preprocessing/src/pipeline/pipeable.dart';
56

67
/// A factory function to use label categorical data encoder in pipeline
@@ -9,12 +10,15 @@ PipeableOperatorFn encodeAsIntegerLabels({
910
Iterable<String> featureNames,
1011
String headerPrefix = '',
1112
String headerPostfix = '',
13+
UnknownValueHandlingType unknownValueHandlingType =
14+
defaultUnknownValueHandlingType,
1215
}) => (data, {dtype}) => EncoderImpl(
1316
data,
1417
EncoderType.label,
15-
SeriesEncoderFactoryImpl(),
18+
const SeriesEncoderFactoryImpl(),
1619
featureIds: features,
1720
featureNames: featureNames,
1821
encodedHeaderPostfix: headerPostfix,
1922
encodedHeaderPrefix: headerPrefix,
23+
unknownValueHandlingType: unknownValueHandlingType,
2024
);
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import 'package:ml_preprocessing/src/encoder/encoder_impl.dart';
22
import 'package:ml_preprocessing/src/encoder/encoder_type.dart';
33
import 'package:ml_preprocessing/src/encoder/series_encoder/series_encoder_factory_impl.dart';
4+
import 'package:ml_preprocessing/src/encoder/unknown_value_handling_type.dart';
45
import 'package:ml_preprocessing/src/pipeline/pipeable.dart';
56

67
/// A factory function to use `one hot` categorical data encoder in pipeline
@@ -9,12 +10,15 @@ PipeableOperatorFn encodeAsOneHotLabels({
910
Iterable<String> featureNames,
1011
String headerPrefix = '',
1112
String headerPostfix = '',
13+
UnknownValueHandlingType unknownValueHandlingType =
14+
defaultUnknownValueHandlingType,
1215
}) => (data, {dtype}) => EncoderImpl(
1316
data,
1417
EncoderType.oneHot,
15-
SeriesEncoderFactoryImpl(),
18+
const SeriesEncoderFactoryImpl(),
1619
featureIds: features,
1720
featureNames: featureNames,
1821
encodedHeaderPostfix: headerPostfix,
1922
encodedHeaderPrefix: headerPrefix,
23+
unknownValueHandlingType: unknownValueHandlingType,
2024
);

lib/src/encoder/encoder.dart

+8-1
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ import 'package:ml_dataframe/ml_dataframe.dart';
22
import 'package:ml_preprocessing/src/encoder/encoder_impl.dart';
33
import 'package:ml_preprocessing/src/encoder/encoder_type.dart';
44
import 'package:ml_preprocessing/src/encoder/series_encoder/series_encoder_factory_impl.dart';
5+
import 'package:ml_preprocessing/src/encoder/unknown_value_handling_type.dart';
56
import 'package:ml_preprocessing/src/pipeline/pipeable.dart';
67

7-
final _seriesEncoderFactory = SeriesEncoderFactoryImpl();
8+
final _seriesEncoderFactory = const SeriesEncoderFactoryImpl();
89

910
/// Categorical data encoder factory
1011
abstract class Encoder implements Pipeable {
@@ -13,24 +14,30 @@ abstract class Encoder implements Pipeable {
1314
Iterable<String> featureNames,
1415
String headerPrefix,
1516
String headerPostfix,
17+
UnknownValueHandlingType unknownValueHandlingType =
18+
defaultUnknownValueHandlingType,
1619
}) => EncoderImpl(
1720
fittingData,
1821
EncoderType.oneHot,
1922
_seriesEncoderFactory,
2023
featureNames: featureNames,
2124
featureIds: featureIds,
25+
unknownValueHandlingType: unknownValueHandlingType,
2226
);
2327

2428
factory Encoder.label(DataFrame fittingData, {
2529
Iterable<int> featureIds,
2630
Iterable<String> featureNames,
2731
String headerPrefix,
2832
String headerPostfix,
33+
UnknownValueHandlingType unknownValueHandlingType =
34+
defaultUnknownValueHandlingType,
2935
}) => EncoderImpl(
3036
fittingData,
3137
EncoderType.label,
3238
_seriesEncoderFactory,
3339
featureNames: featureNames,
3440
featureIds: featureIds,
41+
unknownValueHandlingType: unknownValueHandlingType,
3542
);
3643
}

lib/src/encoder/encoder_impl.dart

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import 'package:ml_preprocessing/src/encoder/encoder_type.dart';
44
import 'package:ml_preprocessing/src/encoder/helpers/create_encoder_to_series_mapping.dart';
55
import 'package:ml_preprocessing/src/encoder/series_encoder/series_encoder.dart';
66
import 'package:ml_preprocessing/src/encoder/series_encoder/series_encoder_factory.dart';
7+
import 'package:ml_preprocessing/src/encoder/unknown_value_handling_type.dart';
78

89
class EncoderImpl implements Encoder {
910
EncoderImpl(
@@ -14,6 +15,8 @@ class EncoderImpl implements Encoder {
1415
Iterable<String> featureNames,
1516
String encodedHeaderPrefix = '',
1617
String encodedHeaderPostfix = '',
18+
UnknownValueHandlingType unknownValueHandlingType =
19+
defaultUnknownValueHandlingType,
1720
}) :
1821
_encoderBySeries = createEncoderToSeriesMapping(
1922
fittingData, featureNames, featureIds,
@@ -22,6 +25,7 @@ class EncoderImpl implements Encoder {
2225
series,
2326
headerPostfix: encodedHeaderPostfix,
2427
headerPrefix: encodedHeaderPrefix,
28+
unknownValueHandlingType: unknownValueHandlingType,
2529
));
2630

2731
final Map<String, SeriesEncoder> _encoderBySeries;
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import 'package:ml_dataframe/ml_dataframe.dart';
22
import 'package:ml_preprocessing/src/encoder/encoder_type.dart';
33
import 'package:ml_preprocessing/src/encoder/series_encoder/series_encoder.dart';
4+
import 'package:ml_preprocessing/src/encoder/unknown_value_handling_type.dart';
45

56
abstract class SeriesEncoderFactory {
67
SeriesEncoder createByType(EncoderType type, Series fittingData, {
78
String headerPrefix,
89
String headerPostfix,
10+
UnknownValueHandlingType unknownValueHandlingType,
911
});
10-
}
12+
}

lib/src/encoder/series_encoder/series_encoder_factory_impl.dart

+14-6
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,34 @@ import 'package:ml_preprocessing/src/encoder/series_encoder/label_series_encoder
44
import 'package:ml_preprocessing/src/encoder/series_encoder/one_hot_series_encoder.dart';
55
import 'package:ml_preprocessing/src/encoder/series_encoder/series_encoder.dart';
66
import 'package:ml_preprocessing/src/encoder/series_encoder/series_encoder_factory.dart';
7+
import 'package:ml_preprocessing/src/encoder/unknown_value_handling_type.dart';
78

89
class SeriesEncoderFactoryImpl implements SeriesEncoderFactory {
10+
const SeriesEncoderFactoryImpl();
11+
912
@override
1013
SeriesEncoder createByType(EncoderType type, Series fittingData, {
1114
String headerPrefix = '',
1215
String headerPostfix = '',
16+
UnknownValueHandlingType unknownValueHandlingType,
1317
}) {
1418
switch (type) {
1519
case EncoderType.label:
1620
return LabelSeriesEncoder(
17-
fittingData,
18-
headerPrefix: headerPrefix,
19-
headerPostfix: headerPostfix
21+
fittingData,
22+
headerPrefix: headerPrefix,
23+
headerPostfix: headerPostfix,
24+
unknownValueHandlingType: unknownValueHandlingType,
2025
);
26+
2127
case EncoderType.oneHot:
2228
return OneHotSeriesEncoder(
23-
fittingData,
24-
headerPrefix: headerPrefix,
25-
headerPostfix: headerPostfix
29+
fittingData,
30+
headerPrefix: headerPrefix,
31+
headerPostfix: headerPostfix,
32+
unknownValueHandlingType: unknownValueHandlingType,
2633
);
34+
2735
default:
2836
throw UnsupportedError('Unsupported encoder type - $type');
2937
}
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
enum UnknownValueHandlingType {
22
error, ignore,
33
}
4+
5+
const defaultUnknownValueHandlingType = UnknownValueHandlingType.ignore;

pubspec.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: ml_preprocessing
22
description: Popular algorithms of data preprocessing for machine learning
3-
version: 5.1.2
3+
version: 5.2.0
44
author: Ilia Gyrdymov <[email protected]>
55
homepage: https://github.com/gyrdym/ml_preprocessing
66

test/encoder/encoder_impl_test.dart

+80
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import 'package:ml_dataframe/ml_dataframe.dart';
2+
import 'package:ml_preprocessing/ml_preprocessing.dart';
23
import 'package:ml_preprocessing/src/encoder/encoder.dart';
34
import 'package:test/test.dart';
45

@@ -15,6 +16,13 @@ void main() {
1516
[ 55, 'F', 'category_val_3', 10 ],
1617
];
1718

19+
final unseenData = [
20+
['first', 'second', 'third', 'fourth'],
21+
[ 1, 'F', 'category_val_5', 10 ],
22+
[ 10, 'F', 'category_val_2', 20 ],
23+
[ 11, 'M', 'category_val_6', 10 ],
24+
];
25+
1826
group('Encoder.oneHot', () {
1927
test('should encode multiple columns', () {
2028
final dataFrame = DataFrame(data);
@@ -53,6 +61,42 @@ void main() {
5361
[ 55, 1, 0, 0, 0, 1, 1, 0, 0, ],
5462
]));
5563
});
64+
65+
test('should throw error if unknown value handling type is error', () {
66+
final trainingDataFrame = DataFrame(data);
67+
final unseenDataDataframe = DataFrame(unseenData);
68+
final encoder = Encoder.oneHot(
69+
trainingDataFrame,
70+
featureNames: ['second', 'third', 'fourth'],
71+
unknownValueHandlingType: UnknownValueHandlingType.error,
72+
);
73+
final actual = () => encoder
74+
.process(unseenDataDataframe)
75+
.toMatrix();
76+
final expected = throwsException;
77+
78+
expect(actual, expected);
79+
});
80+
81+
test('should ignore unknown value if unknown value handling type is ignpre', () {
82+
final trainingDataFrame = DataFrame(data);
83+
final unseenDataDataframe = DataFrame(unseenData);
84+
final encoder = Encoder.oneHot(
85+
trainingDataFrame,
86+
featureNames: ['second', 'third', 'fourth'],
87+
unknownValueHandlingType: UnknownValueHandlingType.ignore,
88+
);
89+
final actual = encoder
90+
.process(unseenDataDataframe)
91+
.toMatrix();
92+
final expected = [
93+
[ 1, 1, 0, 0, 0, 0, 1, 0, 0, ],
94+
[ 10, 1, 0, 0, 1, 0, 0, 1, 0, ],
95+
[ 11, 0, 1, 0, 0, 0, 1, 0, 0, ],
96+
];
97+
98+
expect(actual, expected);
99+
});
56100
});
57101

58102
group('Encoder.label', () {
@@ -92,6 +136,42 @@ void main() {
92136
[ 55, 0, 2, 0, ],
93137
]));
94138
});
139+
140+
test('should throw error if unknown value handling type is error', () {
141+
final trainingDataFrame = DataFrame(data);
142+
final unseenDataDataframe = DataFrame(unseenData);
143+
final encoder = Encoder.label(
144+
trainingDataFrame,
145+
featureNames: ['second', 'third', 'fourth'],
146+
unknownValueHandlingType: UnknownValueHandlingType.error,
147+
);
148+
final actual = () => encoder
149+
.process(unseenDataDataframe)
150+
.toMatrix();
151+
final expected = throwsException;
152+
153+
expect(actual, expected);
154+
});
155+
156+
test('should ignore unknown value if unknown value handling type is ignpre', () {
157+
final trainingDataFrame = DataFrame(data);
158+
final unseenDataDataframe = DataFrame(unseenData);
159+
final encoder = Encoder.label(
160+
trainingDataFrame,
161+
featureNames: ['second', 'third', 'fourth'],
162+
unknownValueHandlingType: UnknownValueHandlingType.ignore,
163+
);
164+
final actual = encoder
165+
.process(unseenDataDataframe)
166+
.toMatrix();
167+
final expected = [
168+
[ 1, 0, 3, 0, ],
169+
[ 10, 0, 1, 1, ],
170+
[ 11, 1, 3, 0, ],
171+
];
172+
173+
expect(actual, expected);
174+
});
95175
});
96176
});
97177
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import 'package:ml_dataframe/ml_dataframe.dart';
2+
import 'package:ml_preprocessing/src/encoder/encoder_type.dart';
3+
import 'package:ml_preprocessing/src/encoder/series_encoder/label_series_encoder.dart';
4+
import 'package:ml_preprocessing/src/encoder/series_encoder/one_hot_series_encoder.dart';
5+
import 'package:ml_preprocessing/src/encoder/series_encoder/series_encoder_factory_impl.dart';
6+
import 'package:test/test.dart';
7+
8+
void main() {
9+
group('SeriesEncoderFactoryImpl', () {
10+
final factory = const SeriesEncoderFactoryImpl();
11+
final series = Series(
12+
'some_series',
13+
<String>['value_1', 'value_2', 'value_3'],
14+
isDiscrete: true,
15+
);
16+
17+
test('should create LabelSeriesEncoder', () {
18+
final encoderType = EncoderType.label;
19+
final actual = factory.createByType(encoderType, series);
20+
final expected = isA<LabelSeriesEncoder>();
21+
22+
expect(actual, expected);
23+
});
24+
25+
test('should create OneHotSeriesEncoder', () {
26+
final encoderType = EncoderType.oneHot;
27+
final actual = factory.createByType(encoderType, series);
28+
final expected = isA<OneHotSeriesEncoder>();
29+
30+
expect(actual, expected);
31+
});
32+
33+
test('should throw exception if null is unknown encoder type is '
34+
'provided', () {
35+
final actual = () => factory.createByType(null, series);
36+
final expected = throwsUnsupportedError;
37+
38+
expect(actual, expected);
39+
});
40+
});
41+
}

0 commit comments

Comments
 (0)