Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unit tests for input check #383

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,37 @@ def test_timeout_func():
ret_timeout = int(test_timeout_func())
assert ret_timeout == 1


def test_init_default_scoring():
"""Assert that TPOT intitializes with the correct default scoring function"""

tpot_obj = TPOTRegressor()
assert tpot_obj.scoring_function == 'neg_mean_squared_error'

tpot_obj = TPOTClassifier()
assert tpot_obj.scoring_function == 'accuracy'

def test_invaild_score_warning():
"""Assert that the TPOT fit function raises a ValueError when the scoring metrics is not available in SCORERS"""
try:
tpot_obj = TPOTClassifier(scoring='balanced_accuray') # typo for balanced_accuracy
assert False
except ValueError:
pass
try:
tpot_obj = TPOTClassifier(scoring='balanced_accuracy') # correct one
assert True
except:
assert False

def test_invaild_dataset_warning():
"""Assert that the TPOT fit function raises a ValueError when dataset is not in right format"""
tpot_obj = TPOTClassifier(random_state=42, population_size=1, offspring_size=2, generations=1, verbosity=0)
bad_training_classes = training_classes.reshape((1, len(training_classes)))# common mistake in classes
try:
tpot_obj.fit(training_features ,bad_training_classes) # typo for balanced_accuracy
assert False
except ValueError:
pass

def test_init_max_time_mins():
"""Assert that the TPOT init stores max run time and sets generations to 1000000"""
Expand Down Expand Up @@ -185,14 +209,14 @@ def test_random_ind_2():
assert expected_code == export_pipeline(pipeline, tpot_obj.operators, tpot_obj._pset)

def test_score():
"""Assert that the TPOT score function raises a ValueError when no optimized pipeline exists"""
"""Assert that the TPOT score function raises a RuntimeError when no optimized pipeline exists"""

tpot_obj = TPOTClassifier()

try:
tpot_obj.score(testing_features, testing_classes)
assert False # Should be unreachable
except ValueError:
except RuntimeError:
pass


Expand Down Expand Up @@ -292,14 +316,14 @@ def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):


def test_predict():
"""Assert that the TPOT predict function raises a ValueError when no optimized pipeline exists"""
"""Assert that the TPOT predict function raises a RuntimeError when no optimized pipeline exists"""

tpot_obj = TPOTClassifier()

try:
tpot_obj.predict(testing_features)
assert False # Should be unreachable
except ValueError:
except RuntimeError:
pass


Expand Down Expand Up @@ -445,13 +469,13 @@ def test_operators():


def test_export():
"""Assert that TPOT's export function throws a ValueError when no optimized pipeline exists"""
"""Assert that TPOT's export function throws a RuntimeError when no optimized pipeline exists"""
tpot_obj = TPOTClassifier()

try:
tpot_obj.export("test_export.py")
assert False # Should be unreachable
except ValueError:
except RuntimeError:
pass


Expand Down