diff --git a/README.md b/README.md index 1ec02997..fd3225fe 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ pip install m2cgen | | Classification | Regression | | --- | --- | --- | -| **Linear** | | | +| **Linear** | | | | **SVM** | | | | **Tree** | | | | **Random Forest** | | | diff --git a/m2cgen/assemblers/__init__.py b/m2cgen/assemblers/__init__.py index 8e62f3c3..577492ea 100644 --- a/m2cgen/assemblers/__init__.py +++ b/m2cgen/assemblers/__init__.py @@ -60,10 +60,11 @@ "SGDRegressor": LinearModelAssembler, "TheilSenRegressor": LinearModelAssembler, - # Logistic Regressors + # Linear Classifiers "LogisticRegression": LinearModelAssembler, "LogisticRegressionCV": LinearModelAssembler, "PassiveAggressiveClassifier": LinearModelAssembler, + "Perceptron": LinearModelAssembler, "RidgeClassifier": LinearModelAssembler, "RidgeClassifierCV": LinearModelAssembler, "SGDClassifier": LinearModelAssembler, diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index c8c96651..47d7203a 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -218,11 +218,15 @@ def classification_binary_random(model): regression(linear_model.SGDRegressor(random_state=RANDOM_SEED)), regression(linear_model.TheilSenRegressor(random_state=RANDOM_SEED)), - # Logistic Regression + # Linear Classifiers classification(linear_model.LogisticRegression( random_state=RANDOM_SEED)), classification(linear_model.LogisticRegressionCV( random_state=RANDOM_SEED)), + classification(linear_model.PassiveAggressiveClassifier( + random_state=RANDOM_SEED)), + classification(linear_model.Perceptron( + random_state=RANDOM_SEED)), classification(linear_model.RidgeClassifier(random_state=RANDOM_SEED)), classification(linear_model.RidgeClassifierCV()), classification(linear_model.SGDClassifier(random_state=RANDOM_SEED)), @@ -231,6 +235,10 @@ def classification_binary_random(model): random_state=RANDOM_SEED)), classification_binary(linear_model.LogisticRegressionCV( random_state=RANDOM_SEED)), + classification_binary(linear_model.PassiveAggressiveClassifier( + random_state=RANDOM_SEED)), + classification_binary(linear_model.Perceptron( + random_state=RANDOM_SEED)), classification_binary(linear_model.RidgeClassifier( random_state=RANDOM_SEED)), classification_binary(linear_model.RidgeClassifierCV()),