Skip to content

Commit

Permalink
Merge pull request #383 from weixuanfu2016/unit_tests_for_input_check
Browse files Browse the repository at this point in the history
Unit tests for input check
  • Loading branch information
rhiever authored Mar 22, 2017
2 parents 00c1a20 + b49f1ac commit 2f7a399
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,37 @@ def test_timeout_func():
ret_timeout = int(test_timeout_func())
assert ret_timeout == 1


def test_init_default_scoring():
"""Assert that TPOT intitializes with the correct default scoring function"""

tpot_obj = TPOTRegressor()
assert tpot_obj.scoring_function == 'neg_mean_squared_error'

tpot_obj = TPOTClassifier()
assert tpot_obj.scoring_function == 'accuracy'

def test_invaild_score_warning():
"""Assert that the TPOT fit function raises a ValueError when the scoring metrics is not available in SCORERS"""
try:
tpot_obj = TPOTClassifier(scoring='balanced_accuray') # typo for balanced_accuracy
assert False
except ValueError:
pass
try:
tpot_obj = TPOTClassifier(scoring='balanced_accuracy') # correct one
assert True
except:
assert False

def test_invaild_dataset_warning():
"""Assert that the TPOT fit function raises a ValueError when dataset is not in right format"""
tpot_obj = TPOTClassifier(random_state=42, population_size=1, offspring_size=2, generations=1, verbosity=0)
bad_training_classes = training_classes.reshape((1, len(training_classes)))# common mistake in classes
try:
tpot_obj.fit(training_features ,bad_training_classes) # typo for balanced_accuracy
assert False
except ValueError:
pass

def test_init_max_time_mins():
"""Assert that the TPOT init stores max run time and sets generations to 1000000"""
Expand Down Expand Up @@ -185,14 +209,14 @@ def test_random_ind_2():
assert expected_code == export_pipeline(pipeline, tpot_obj.operators, tpot_obj._pset)

def test_score():
"""Assert that the TPOT score function raises a ValueError when no optimized pipeline exists"""
"""Assert that the TPOT score function raises a RuntimeError when no optimized pipeline exists"""

tpot_obj = TPOTClassifier()

try:
tpot_obj.score(testing_features, testing_classes)
assert False # Should be unreachable
except ValueError:
except RuntimeError:
pass


Expand Down Expand Up @@ -292,14 +316,14 @@ def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):


def test_predict():
"""Assert that the TPOT predict function raises a ValueError when no optimized pipeline exists"""
"""Assert that the TPOT predict function raises a RuntimeError when no optimized pipeline exists"""

tpot_obj = TPOTClassifier()

try:
tpot_obj.predict(testing_features)
assert False # Should be unreachable
except ValueError:
except RuntimeError:
pass


Expand Down Expand Up @@ -445,13 +469,13 @@ def test_operators():


def test_export():
"""Assert that TPOT's export function throws a ValueError when no optimized pipeline exists"""
"""Assert that TPOT's export function throws a RuntimeError when no optimized pipeline exists"""
tpot_obj = TPOTClassifier()

try:
tpot_obj.export("test_export.py")
assert False # Should be unreachable
except ValueError:
except RuntimeError:
pass


Expand Down

0 comments on commit 2f7a399

Please sign in to comment.