diff --git a/src/python/nimbusml/pipeline.py b/src/python/nimbusml/pipeline.py index 9d13b5e1..692e1dea 100644 --- a/src/python/nimbusml/pipeline.py +++ b/src/python/nimbusml/pipeline.py @@ -1913,9 +1913,10 @@ def _extract_classes(self, y): self._add_classes(unique_classes) def _extract_classes_from_headers(self, headers): - classes = [x.replace('Score.', '') for x in headers] - classes = np.array(classes).astype(self.last_node.classes_.dtype) - self._add_classes(classes) + if hasattr(self.last_node, 'classes_'): + classes = [x.replace('Score.', '') for x in headers] + classes = np.array(classes).astype(self.last_node.classes_.dtype) + self._add_classes(classes) def _add_classes(self, classes): # Create classes_ attribute similar to scikit diff --git a/src/python/nimbusml/tests/pipeline/test_predict_proba_decision_function.py b/src/python/nimbusml/tests/pipeline/test_predict_proba_decision_function.py index 5da37807..f6cc1c70 100644 --- a/src/python/nimbusml/tests/pipeline/test_predict_proba_decision_function.py +++ b/src/python/nimbusml/tests/pipeline/test_predict_proba_decision_function.py @@ -120,7 +120,7 @@ def test_pass_predict_proba_multiclass_3class(self): s, 38.0, decimal=4, - err_msg=invalid_decision_function_output) + err_msg=invalid_predict_proba_output) assert_equal(set(clf.classes_), {'Blue', 'Green', 'Red'}) def test_pass_predict_proba_multiclass_with_pipeline_adds_classes(self): @@ -137,7 +137,7 @@ def test_pass_predict_proba_multiclass_with_pipeline_adds_classes(self): s, 38.0, decimal=4, - err_msg=invalid_decision_function_output) + err_msg=invalid_predict_proba_output) assert_equal(set(clf.classes_), expected_classes) assert_equal(set(pipeline.classes_), expected_classes) @@ -150,9 +150,34 @@ def test_pass_predict_proba_multiclass_3class_retains_classes_type(self): s, 38.0, decimal=4, - err_msg=invalid_decision_function_output) + err_msg=invalid_predict_proba_output) assert_equal(set(clf.classes_), {0, 1, 2}) + def test_predict_proba_multiclass_3class_no_y_input_implies_no_classes_attribute(self): + X_train = X_train_3class_int.join(y_train_3class_int) + X_test = X_test_3class_int.join(y_test_3class_int) + + clf = FastLinearClassifier(number_of_threads=1, label='Label') + clf.fit(X_train) + + if hasattr(clf, 'classes_'): + # The classes_ attribute is currently not supported + # when fitting when there is no y input specified. + self.fail("classes_ attribute not expected.") + + s = clf.predict_proba(X_test).sum() + assert_almost_equal( + s, + 38.0, + decimal=4, + err_msg=invalid_predict_proba_output) + + if hasattr(clf, 'classes_'): + # The classes_ attribute is currently not supported + # when predicting when there was no y input specified + # during fitting. + self.fail("classes_ attribute not expected.") + def test_fail_predict_proba_multiclass_with_pipeline(self): check_unsupported_predict_proba(self, Pipeline( [NaiveBayesClassifier()]), X_train, y_train, X_test) @@ -242,6 +267,31 @@ def test_pass_decision_function_multiclass_3class_retains_classes_type(self): err_msg=invalid_decision_function_output) assert_equal(set(clf.classes_), {0, 1, 2}) + def test_decision_function_multiclass_3class_no_y_input_implies_no_classes_attribute(self): + X_train = X_train_3class_int.join(y_train_3class_int) + X_test = X_test_3class_int.join(y_test_3class_int) + + clf = FastLinearClassifier(number_of_threads=1, label='Label') + clf.fit(X_train) + + if hasattr(clf, 'classes_'): + # The classes_ attribute is currently not supported + # when fitting when there is no y input specified. + self.fail("classes_ attribute not expected.") + + s = clf.decision_function(X_test).sum() + assert_almost_equal( + s, + 38.0, + decimal=4, + err_msg=invalid_decision_function_output) + + if hasattr(clf, 'classes_'): + # The classes_ attribute is currently not supported + # when predicting when there was no y input specified + # during fitting. + self.fail("classes_ attribute not expected.") + def test_fail_decision_function_multiclass(self): check_unsupported_decision_function( self, LogisticRegressionClassifier(), X_train, y_train, X_test)