Skip to content

Commit

Permalink
Remove temporary pickle files (#1354)
Browse files Browse the repository at this point in the history
* Remove temporary pickle files

* Update version to 2.3.1

* Use TemporaryDirectory for pickle and log_artifact

* Fix 'CatBoostClassifier' object has no attribute '_get_param_names'
  • Loading branch information
thinkall authored Sep 21, 2024
1 parent c90946f commit 8e171bc
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
26 changes: 16 additions & 10 deletions flaml/fabric/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pickle
import random
import sys
import tempfile
import time
from typing import MutableMapping

Expand Down Expand Up @@ -55,12 +56,12 @@ def get_mlflow_log_latency(model_history=False):
sk_model = tree.DecisionTreeClassifier()
mlflow.sklearn.log_model(sk_model, "sk_models")
mlflow.sklearn.log_model(Pipeline([("estimator", sk_model)]), "sk_pipeline")
pickle_fpath = f"tmp_{int(time.time()*1000)}"
with open(pickle_fpath, "wb") as f:
pickle.dump(sk_model, f)
mlflow.log_artifact(pickle_fpath, "sk_model1")
mlflow.log_artifact(pickle_fpath, "sk_model2")
os.remove(pickle_fpath)
with tempfile.TemporaryDirectory() as tmpdir:
pickle_fpath = os.path.join(tmpdir, f"tmp_{int(time.time()*1000)}")
with open(pickle_fpath, "wb") as f:
pickle.dump(sk_model, f)
mlflow.log_artifact(pickle_fpath, "sk_model1")
mlflow.log_artifact(pickle_fpath, "sk_model2")
mlflow.set_tag("synapseml.ui.visible", "false") # not shown inline in fabric
mlflow.delete_run(run.info.run_id)
et = time.time()
Expand Down Expand Up @@ -348,12 +349,17 @@ def log_model(self, model, estimator, signature=None):
else:
mlflow.sklearn.log_model(model, estimator, signature=signature)

def _pickle_and_log_artifact(self, obj, artifact_name, pickle_fpath="temp_.pkl"):
def _pickle_and_log_artifact(self, obj, artifact_name, pickle_fname="temp_.pkl"):
if not self._do_log_model:
return
with open(pickle_fpath, "wb") as f:
pickle.dump(obj, f)
mlflow.log_artifact(pickle_fpath, artifact_name)
with tempfile.TemporaryDirectory() as tmpdir:
pickle_fpath = os.path.join(tmpdir, pickle_fname)
try:
with open(pickle_fpath, "wb") as f:
pickle.dump(obj, f)
mlflow.log_artifact(pickle_fpath, artifact_name)
except Exception as e:
logger.debug(f"Failed to pickle and log artifact {artifact_name}, error: {e}")

def pickle_and_log_automl_artifacts(self, automl, model, estimator, signature=None):
"""log automl artifacts to mlflow
Expand Down
2 changes: 1 addition & 1 deletion flaml/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.3.0"
__version__ = "2.3.1"
2 changes: 2 additions & 0 deletions test/automl/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def _check_mlflow_parameters(automl: AutoML, run_info: mlflow.entities.RunInfo):
t = pickle.load(f)
if __name__ == "__main__":
print(t)
if not hasattr(automl.model._model, "_get_param_names"):
return
for param in automl.model._model._get_param_names():
assert eval("t._final_estimator._model" + f".{param}") == eval(
"automl.model._model" + f".{param}"
Expand Down

0 comments on commit 8e171bc

Please sign in to comment.