Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix joblib model loading #542

Merged
merged 2 commits into from
Sep 15, 2022
Merged

fix joblib model loading #542

merged 2 commits into from
Sep 15, 2022

Conversation

amfonelic
Copy link
Contributor

@amfonelic amfonelic commented Sep 8, 2022

Hi! I faced the problem that some models that were dumped using joblib cannot be loaded using pickle in m2cgen i.e. DecisionTreeClassifier from scikit-learn. This leads to undefined behavior further down in the code.

Works file:

from pickle import dump
f = open("model.pkl","wb")
clf = DecisionTreeClassifier()
clf = clf.fit(X_train,y_train)
dump(clf, f) 

Won't work:

from joblib import dump
clf = DecisionTreeClassifier()
clf = clf.fit(X_train,y_train)
dump(clf, "model.pkl") 

Traceback looks like:

Traceback (most recent call last):
  File "~/miniconda3/bin/m2cgen", line 8, in <module>
    sys.exit(main())
  File "~/miniconda3/lib/python3.9/site-packages/m2cgen/cli.py", line 137, in main
    print(generate_code(args))
  File "~/miniconda3/lib/python3.9/site-packages/m2cgen/cli.py", line 132, in generate_code
    return exporter(model, **kwargs)
  File "~/miniconda3/lib/python3.9/site-packages/m2cgen/exporters.py", line 33, in export_to_java
    return _export(model, interpreter)
  File "~/miniconda3/lib/python3.9/site-packages/m2cgen/exporters.py", line 458, in _export
    assembler_cls = get_assembler_cls(model)
  File "~/miniconda3/lib/python3.9/site-packages/m2cgen/assemblers/__init__.py", line 147, in get_assembler_cls
    raise NotImplementedError(f"Model '{model_name}' is not supported")
NotImplementedError: Model 'numpy_ndarray' is not supported

By the way, I'm not the only one with this problem.
#287

@@ -109,7 +114,8 @@ def generate_code(args):
sys.setrecursionlimit(args.recursion_limit)

with args.infile as f:
model = pickle.load(f)
pickle_lib = __import__(args.lib)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unconditional loading of whatever module name is supplied into the CLI feels like a security concern to me. Can we perhaps maintain a list of supported libraries and check whether the supplied argument matches one of them?

parser.add_argument(
"--saved-by", "-sb",
type=str,
dest="lib",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we be a little more specific? Perhaps pickle_lib?

Copy link
Member

@izeigerman izeigerman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for submitting this PR! A few comments, but is good to go otherwise.

@amfonelic
Copy link
Contributor Author

Thanks for review! Fixed issues you mentioned. Glad to be of service, you make brilliant software :)

@izeigerman
Copy link
Member

@pinktoxin thanks for addressing comments! Would you mind fixing the following linter errors please:

./tests/test_cli.py:126:1: E302 expected 2 blank lines, found 1
./tests/test_cli.py:132:1: E302 expected 2 blank lines, found 1
./m2cgen/cli.py:106:22: E231 missing whitespace after ','

Copy link
Member

@izeigerman izeigerman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing comments! 🚀

@izeigerman izeigerman merged commit b0f76ef into BayesWitnesses:master Sep 15, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Development

Successfully merging this pull request may close these issues.

2 participants