Commit 720e356
Support spark dataframe as input dataset and spark models as estimators (microsoft#934)
* add basic support to Spark dataframe; add support to SynapseML LightGBM model; update to pyspark>=3.2.0 to leverage pandas_on_Spark API
* clean code, add TODOs
* add sample_train_data for pyspark.pandas dataframe, fix bugs
* improve some functions, fix bugs
* fix dict change size during iteration
* update model predict
* update LightGBM model, update test
* update SynapseML LightGBM params
* update SynapseML and tests
* update TODOs
* Added support to roc_auc for spark models
* Added support to score of spark estimator
* Added test for automl score of spark estimator
* Added cv support to pyspark.pandas dataframe
* Update test, fix bugs
* Added tests
* Updated docs, tests, added a notebook
* Fix bugs in non-spark env
* Fix bugs and improve tests
* Fix uninstall pyspark
* Fix tests error
* Fix java.lang.OutOfMemoryError: Java heap space
* Fix test_performance
* Update test_sparkml to test_0sparkml to use the expected spark conf
* Remove unnecessary widgets in notebook
* Fix iloc java.lang.StackOverflowError
* fix pre-commit
* Added params check for spark dataframes
* Refactor code for train_test_split to a function
* Update train_test_split_pyspark
* Refactor if-else, remove unnecessary code
* Remove y from predict, remove mem control from n_iter compute
* Update workflow
* Improve _split_pyspark
* Fix test failure of too short training time
* Fix typos, improve docstrings
* Fix index errors of pandas_on_spark, add spark loss metric
* Fix typo of ndcgAtK
* Update NDCG metrics and tests
* Remove unused logger
* Use cache and count to ensure consistent indexes
* refactor for merge main
* fix errors of refactor
* Updated SparkLightGBMEstimator and cache
* Updated config2params
* Remove unused import
* Fix unknown parameters
* Update default_estimator_list
* Add unit tests for spark metrics
1 parent f33cb42 commit 720e356

24 files changed: 3017 additions, 235 deletions
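The diffs below are easier to follow with the intended end-to-end usage in mind. The following is a minimal sketch and not code from this commit: the estimator name "lgbm_spark" and the pandas-on-Spark input type come from the diffs on this page, while the toy data, column names, budget, and the VectorAssembler step are illustrative assumptions (SynapseML also has to be installed for the Spark LightGBM model to train); the notebook added by this commit is the authoritative example.

import os
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"  # same setting the flaml imports apply

import pandas as pd
import pyspark.pandas as ps
from pyspark.ml.feature import VectorAssembler

import flaml

# Toy data; the column names and sizes are made up for illustration.
pdf = pd.DataFrame({"x1": range(100), "x2": range(100, 200), "y": [i % 2 for i in range(100)]})
psdf = ps.from_pandas(pdf)

# spark.ml estimators expect the features assembled into a single vector column.
featurizer = VectorAssembler(inputCols=["x1", "x2"], outputCol="features")
sdf = featurizer.transform(psdf.to_spark(index_col="index"))["index", "features", "y"]
psdf = sdf.pandas_api(index_col="index")  # back to a pandas-on-Spark dataframe

automl = flaml.AutoML()
automl.fit(
    dataframe=psdf,                 # pandas-on-Spark input, newly supported by this commit
    label="y",
    task="classification",
    time_budget=30,
    estimator_list=["lgbm_spark"],  # SynapseML LightGBM, registered as "lgbm_spark" in ml.py below
)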

.github/workflows/python-package.yml

+10 -9
@@ -25,7 +25,6 @@ jobs:
       matrix:
         os: [ubuntu-latest, macos-latest, windows-2019]
         python-version: ["3.7", "3.8", "3.9", "3.10"]
-
     steps:
       - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
@@ -45,21 +44,18 @@ jobs:
           export CFLAGS="$CFLAGS -I/usr/local/opt/libomp/include"
           export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libomp/include"
           export LDFLAGS="$LDFLAGS -Wl,-rpath,/usr/local/opt/libomp/lib -L/usr/local/opt/libomp/lib -lomp"
-      - name: On Linux, install Spark stand-alone cluster and PySpark
-        if: matrix.os == 'ubuntu-latest'
+      - name: On Linux + python 3.8, install pyspark 3.2.3
+        if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.8'
         run: |
-          sudo apt-get update && sudo apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends ca-certificates-java ca-certificates openjdk-17-jdk-headless && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*
-          wget --progress=dot:giga "https://www.apache.org/dyn/closer.lua/spark/spark-3.3.0/spark-3.3.0-bin-hadoop2.tgz?action=download" -O - | tar -xzC /tmp; archive=$(basename "spark-3.3.0/spark-3.3.0-bin-hadoop2.tgz") bash -c "sudo mv -v /tmp/\${archive/%.tgz/} /spark"
-          pip install --no-cache-dir pyspark>=3.0
-          export SPARK_HOME=/spark
-          export PYTHONPATH=/spark/python/lib/py4j-0.10.9.5-src.zip:/spark/python
-          export PATH=$PATH:$SPARK_HOME/bin
+          python -m pip install --upgrade pip wheel
+          pip install pyspark==3.2.3
       - name: Install packages and dependencies
         run: |
           python -m pip install --upgrade pip wheel
           pip install -e .
           python -c "import flaml"
           pip install -e .[test]
+          pip list | grep "pyspark"
       - name: If linux, install ray 2
         if: matrix.os == 'ubuntu-latest'
         run: |
@@ -76,6 +72,11 @@ jobs:
         if: matrix.python-version != '3.10'
         run: |
           pip install -e .[vw]
+      - name: Uninstall pyspark on python 3.9
+        if: matrix.python-version == '3.9'
+        run: |
+          # Uninstall pyspark to test env without pyspark
+          pip uninstall -y pyspark
       - name: Lint with flake8
         run: |
           # stop the build if there are Python syntax errors or undefined names

.gitignore

+3
@@ -159,3 +159,6 @@ automl.pkl
 
 test/nlp/testtmp.py
 test/nlp/testtmpfl.py
+
+flaml/tune/spark/mylearner.py
+*.pkl

flaml/automl/automl.py

+50 -7
@@ -7,7 +7,6 @@
 import os
 import sys
 from typing import Callable, List, Union, Optional
-import inspect
 from functools import partial
 import numpy as np
 from sklearn.base import BaseEstimator
@@ -17,7 +16,6 @@
 
 from flaml.automl.state import SearchState, AutoMLState
 from flaml.automl.ml import (
-    compute_estimator,
     train_estimator,
     get_estimator_class,
 )
@@ -31,7 +29,6 @@
     N_SPLITS,
     SAMPLE_MULTIPLY_FACTOR,
 )
-from flaml.automl.data import concat
 
 # TODO check to see when we can remove these
 from flaml.automl.task.task import CLASSIFICATION, TS_FORECAST, Task
@@ -43,6 +40,34 @@
 from flaml.version import __version__ as flaml_version
 from flaml.tune.spark.utils import check_spark, get_broadcast_data
 
+try:
+    from flaml.automl.spark.utils import (
+        train_test_split_pyspark,
+        unique_pandas_on_spark,
+        len_labels,
+        unique_value_first_index,
+    )
+except ImportError:
+    train_test_split_pyspark = None
+    unique_pandas_on_spark = None
+    from flaml.automl.utils import (
+        len_labels,
+        unique_value_first_index,
+    )
+try:
+    os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
+    import pyspark.pandas as ps
+    from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
+    from pyspark.pandas.config import set_option, reset_option
+except ImportError:
+    ps = None
+
+    class psDataFrame:
+        pass
+
+    class psSeries:
+        pass
+
 
 try:
     import mlflow
@@ -511,7 +536,12 @@ def time_to_find_best_model(self) -> float:
         """Time taken to find best model in seconds."""
         return self.__dict__.get("_time_taken_best_iter")
 
-    def score(self, X: pd.DataFrame, y: pd.Series, **kwargs):
+    def score(
+        self,
+        X: Union[pd.DataFrame, psDataFrame],
+        y: Union[pd.Series, psSeries],
+        **kwargs,
+    ):
         estimator = getattr(self, "_trained_estimator", None)
         if estimator is None:
             logger.warning(
@@ -525,13 +555,14 @@ def score(self, X: pd.DataFrame, y: pd.Series, **kwargs):
 
     def predict(
         self,
-        X: Union[np.array, pd.DataFrame, List[str], List[List[str]]],
+        X: Union[np.array, pd.DataFrame, List[str], List[List[str]], psDataFrame],
         **pred_kwargs,
     ):
         """Predict label from features.
 
         Args:
-            X: A numpy array of featurized instances, shape n * m,
+            X: A numpy array or pandas dataframe or pyspark.pandas dataframe
+                of featurized instances, shape n * m,
                 or for time series forcast tasks:
                 a pandas dataframe with the first column containing
                 timestamp values (datetime type) or an integer n for
@@ -1859,7 +1890,19 @@ def is_to_reverse_metric(metric, task):
         error_metric = "customized metric"
         logger.info(f"Minimizing error metric: {error_metric}")
 
-        estimator_list = task.default_estimator_list(estimator_list)
+        is_spark_dataframe = isinstance(X_train, psDataFrame) or isinstance(
+            dataframe, psDataFrame
+        )
+        estimator_list = task.default_estimator_list(estimator_list, is_spark_dataframe)
+
+        if is_spark_dataframe and self._use_spark:
+            # For spark dataframe, use_spark must be False because spark models are trained in parallel themselves
+            self._use_spark = False
+            logger.warning(
+                "Spark dataframes support only spark.ml type models, which will be trained "
+                "with spark themselves, no need to start spark trials in flaml. "
+                "`use_spark` is set to False."
+            )
 
         # When no search budget is specified
        if no_budget:
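Per the extended score and predict signatures above, a fitted model can then be evaluated directly on pandas-on-Spark data. A hedged continuation of the hypothetical sketch at the top of this page (same psdf and automl objects; the comments describe the expected behavior, not verified output):

X_test = psdf[["features"]]          # pandas-on-Spark dataframe of assembled features
y_test = psdf["y"]                   # pandas-on-Spark series of labels

preds = automl.predict(X_test)       # psDataFrame is now an accepted input type for predict()
loss = automl.score(X_test, y_test)  # score() likewise accepts pandas-on-Spark X and y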

flaml/automl/data.py

+32
@@ -12,6 +12,22 @@
 from datetime import datetime
 from typing import TYPE_CHECKING, Union
 
+import os
+
+try:
+    os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
+    import pyspark.pandas as ps
+    from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
+except ImportError:
+    ps = None
+
+    class psDataFrame:
+        pass
+
+    class psSeries:
+        pass
+
+
 if TYPE_CHECKING:
     from flaml.automl.task import Task
 
@@ -198,6 +214,15 @@ def get_output_from_log(filename, time_budget):
 
 def concat(X1, X2):
     """concatenate two matrices vertically."""
+    if type(X1) != type(X2):
+        if isinstance(X2, (psDataFrame, psSeries)):
+            X1 = ps.from_pandas(pd.DataFrame(X1))
+        elif isinstance(X1, (psDataFrame, psSeries)):
+            X2 = ps.from_pandas(pd.DataFrame(X2))
+        else:
+            X1 = pd.DataFrame(X1)
+            X2 = pd.DataFrame(X2)
+
     if isinstance(X1, (DataFrame, Series)):
         df = pd.concat([X1, X2], sort=False)
         df.reset_index(drop=True, inplace=True)
@@ -206,6 +231,13 @@ def concat(X1, X2):
             if len(cat_columns):
                 df[cat_columns] = df[cat_columns].astype("category")
         return df
+    if isinstance(X1, (psDataFrame, psSeries)):
+        df = ps.concat([X1, X2], ignore_index=True)
+        if isinstance(X1, psDataFrame):
+            cat_columns = X1.select_dtypes(include="category").columns.values.tolist()
+            if len(cat_columns):
+                df[cat_columns] = df[cat_columns].astype("category")
+        return df
     if issparse(X1):
         return vstack((X1, X2))
     else:
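The new branches in concat() promote whichever input is plain pandas or numpy to pandas-on-Spark and then use ps.concat. A small illustrative check, not taken from the commit, assuming pyspark>=3.2 is installed:

import pandas as pd
import pyspark.pandas as ps
from flaml.automl.data import concat

pdf = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
psdf = ps.from_pandas(pd.DataFrame({"a": [3, 4], "b": ["z", "w"]}))

out = concat(pdf, psdf)  # pdf is promoted via ps.from_pandas, then ps.concat(..., ignore_index=True)
print(type(out))         # expected: a pyspark.pandas DataFrame
print(out.shape)         # expected: (4, 2)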

flaml/automl/ml.py

+39 -12
@@ -2,6 +2,7 @@
 # * Copyright (c) FLAML authors. All rights reserved.
 # * Licensed under the MIT License. See LICENSE file in the
 # * project root for license information.
+import os
 import time
 import numpy as np
 import pandas as pd
@@ -19,12 +20,6 @@
     mean_absolute_percentage_error,
     ndcg_score,
 )
-from sklearn.model_selection import (
-    RepeatedStratifiedKFold,
-    GroupKFold,
-    TimeSeriesSplit,
-    StratifiedGroupKFold,
-)
 from flaml.automl.model import (
     XGBoostSklearnEstimator,
     XGBoost_TS,
@@ -46,14 +41,33 @@
     TransformersEstimator,
     TemporalFusionTransformerEstimator,
     TransformersEstimatorModelSelection,
+    SparkLGBMEstimator,
 )
 from flaml.automl.data import group_counts
 from flaml.automl.task.task import TS_FORECAST, Task
 from flaml.automl.model import BaseEstimator
 
-import logging
+try:
+    from flaml.automl.spark.utils import len_labels
+except ImportError:
+    from flaml.automl.utils import len_labels
+try:
+    os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
+    from pyspark.sql.functions import col
+    import pyspark.pandas as ps
+    from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
+    from flaml.automl.spark.utils import to_pandas_on_spark, iloc_pandas_on_spark
+    from flaml.automl.spark.metrics import spark_metric_loss_score
+except ImportError:
+    ps = None
+
+    class psDataFrame:
+        pass
+
+    class psSeries:
+        pass
+
 
-logger = logging.getLogger(__name__)
 EstimatorSubclass = TypeVar("EstimatorSubclass", bound=BaseEstimator)
 
 sklearn_metric_name_set = {
@@ -124,6 +138,8 @@ def get_estimator_class(task: str, estimator_name: str) -> EstimatorSubclass:
         estimator_class = RF_TS if task in TS_FORECAST else RandomForestEstimator
     elif "lgbm" == estimator_name:
         estimator_class = LGBM_TS if task in TS_FORECAST else LGBMEstimator
+    elif "lgbm_spark" == estimator_name:
+        estimator_class = SparkLGBMEstimator
     elif "lrl1" == estimator_name:
         estimator_class = LRL1Classifier
     elif "lrl2" == estimator_name:
@@ -163,7 +179,15 @@ def metric_loss_score(
     groups=None,
 ):
     # y_processed_predict and y_processed_true are processed id labels if the original were the token labels
-    if is_in_sklearn_metric_name_set(metric_name):
+    if isinstance(y_processed_predict, (psDataFrame, psSeries)):
+        return spark_metric_loss_score(
+            metric_name,
+            y_processed_predict,
+            y_processed_true,
+            sample_weight,
+            groups,
+        )
+    elif is_in_sklearn_metric_name_set(metric_name):
         return sklearn_metric_loss_score(
             metric_name,
             y_processed_predict,
@@ -359,7 +383,10 @@ def sklearn_metric_loss_score(
 def get_y_pred(estimator, X, eval_metric, task: Task):
     if eval_metric in ["roc_auc", "ap", "roc_auc_weighted"] and task.is_binary():
         y_pred_classes = estimator.predict_proba(X)
-        y_pred = y_pred_classes[:, 1] if y_pred_classes.ndim > 1 else y_pred_classes
+        if isinstance(y_pred_classes, (psSeries, psDataFrame)):
+            y_pred = y_pred_classes
+        else:
+            y_pred = y_pred_classes[:, 1] if y_pred_classes.ndim > 1 else y_pred_classes
     elif eval_metric in [
         "log_loss",
         "roc_auc",
@@ -525,7 +552,7 @@ def compute_estimator(
     fit_kwargs: Optional[dict] = None,
     free_mem_ratio=0,
 ):
-    if not fit_kwargs:
+    if fit_kwargs is None:
         fit_kwargs = {}
 
     estimator_class = estimator_class or get_estimator_class(task, estimator_name)
@@ -605,7 +632,7 @@ def train_estimator(
         task=task,
         n_jobs=n_jobs,
     )
-    if not fit_kwargs:
+    if fit_kwargs is None:
         fit_kwargs = {}
 
     if isinstance(estimator, TransformersEstimator):
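With the new dispatch in metric_loss_score, pandas-on-Spark predictions are routed to spark_metric_loss_score instead of the sklearn path. A hedged sketch, not from the commit: the positional argument order is inferred from the hunk above, and which metric names the Spark implementation accepts is defined in flaml/automl/spark/metrics.py, which is not shown on this page:

import pyspark.pandas as ps
from flaml.automl.ml import metric_loss_score

y_true = ps.Series([0, 1, 1, 0])
y_pred = ps.Series([0, 1, 0, 0])

# pandas-on-Spark inputs take the spark_metric_loss_score branch above;
# plain lists or numpy arrays keep going through the sklearn metrics.
loss = metric_loss_score("accuracy", y_pred, y_true)
print(loss)  # expected to be 1 - accuracy (0.25 here) if "accuracy" is a supported spark metric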
