Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update abstract PyFunc to utilise returned treatment_config #164

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions sdk/turing/ensembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class PyFunc(EnsemblerBase, mlflow.pyfunc.PythonModel, abc.ABC):
Abstract implementation of PyFunc Ensembler.
It leverages the contract of mlflow's PythonModel and implements its `predict` method.
"""

PREDICTION_COLUMN_PREFIX = '__predictions__'
TREATMENT_CONFIG_COLUMN_PREFIX = '__treatment_config__'

def load_context(self, context):
self.initialize(context.artifacts)
Expand All @@ -57,20 +57,27 @@ def initialize(self, artifacts: dict):

def predict(self, context, model_input: pandas.DataFrame) -> \
Union[numpy.ndarray, pandas.Series, pandas.DataFrame]:
prediction_columns = {
col: col[len(PyFunc.PREDICTION_COLUMN_PREFIX):]
for col in model_input.columns if col.startswith(PyFunc.PREDICTION_COLUMN_PREFIX)
}
prediction_columns = PyFunc._get_columns_with_header(model_input, PyFunc.PREDICTION_COLUMN_PREFIX)
treatment_config_columns = PyFunc._get_columns_with_header(model_input, PyFunc.TREATMENT_CONFIG_COLUMN_PREFIX)

return model_input \
.rename(columns=prediction_columns) \
.rename(columns=treatment_config_columns) \
.apply(lambda row:
self.ensemble(
features=row.drop(prediction_columns.values()),
features=row.drop(prediction_columns.values()).drop(treatment_config_columns.values()),
predictions=row[prediction_columns.values()],
treatment_config=None
treatment_config=row[treatment_config_columns.values()]
), axis=1, result_type='expand')

@staticmethod
def _get_columns_with_header(df: pandas.DataFrame, header: str):
deadlycoconuts marked this conversation as resolved.
Show resolved Hide resolved
selected_columns = {
col: col[len(header):]
for col in df.columns if col.startswith(header)
}
return selected_columns


@ApiObjectSpec(turing.generated.models.Ensembler)
class Ensembler(ApiObject):
Expand Down Expand Up @@ -273,7 +280,7 @@ def create(
conda_env: Union[str, Dict[str, Any]],
code_dir: Optional[List[str]] = None,
artifacts: Dict[str, str] = None,
) -> 'PyFuncEnsembler':
) -> 'PyFuncEnsembler':
"""
Save new pyfunc ensembler in the active project

Expand Down