-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-42412][WIP] Initial PR of Spark connect ML #40297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 34 commits
684a9e3
97ab924
86f2fad
4f17d6c
6582037
ba4f580
941550e
c4473a6
606168d
1003787
f9f3542
ed24307
c1f9162
e178de3
130bd1e
eee1013
d72fba0
36bc69b
870c994
e500be8
7c44e5c
e87aa53
18876c2
23744b8
33be464
bb43f01
ec89d40
2ebcf45
6e12a22
9ae327b
36e6d33
e5278cb
c80414d
66c472b
39b2f24
2d4377d
e316a82
caebf75
6689a1d
2ad0e13
49e0e0c
809bd00
def7aa5
581a5ee
296099c
93b00f8
f227df9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,193 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| syntax = 'proto3'; | ||
|
|
||
| package spark.connect; | ||
|
|
||
| import "spark/connect/expressions.proto"; | ||
| import "spark/connect/relations.proto"; | ||
| import "spark/connect/ml_common.proto"; | ||
|
|
||
| option java_multiple_files = true; | ||
| option java_package = "org.apache.spark.connect.proto"; | ||
WeichenXu123 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| // MlEvaluator represents a ML Evaluator | ||
| message MlEvaluator { | ||
| // The name of the evaluator in the registry | ||
| string name = 1; | ||
| // param settings for the evaluator | ||
| MlParams params = 2; | ||
| // unique id of the evaluator | ||
| string uid = 3; | ||
| } | ||
|
|
||
|
|
||
| // a MlCommand is a type container that has exactly one ML command set | ||
| message MlCommand { | ||
| oneof ml_command_type { | ||
| // call `estimator.fit` and returns a model | ||
| Fit fit = 1; | ||
| // get model attribute | ||
| FetchModelAttr fetch_model_attr = 2; | ||
| // get model summary attribute | ||
| FetchModelSummaryAttr fetch_model_summary_attr = 3; | ||
| // load model | ||
| LoadModel load_model = 4; | ||
| // save model | ||
| SaveModel save_model = 5; | ||
| // call `evaluator.evaluate` | ||
| Evaluate evaluate = 6; | ||
| // save estimator or transformer | ||
| SaveStage save_stage = 7; | ||
| // load estimator or transformer | ||
| LoadStage load_stage = 8; | ||
| // save estimator | ||
| SaveEvaluator save_evaluator = 9; | ||
| // load estimator | ||
| LoadEvaluator load_evaluator = 10; | ||
| // copy model, returns new model reference id | ||
| CopyModel copy_model = 11; | ||
| // delete server side model object by model reference id | ||
| DeleteModel delete_model = 12; | ||
| } | ||
|
|
||
| message Fit { | ||
| MlStage estimator = 1; | ||
| Relation dataset = 2; | ||
| } | ||
|
|
||
| message Evaluate { | ||
| MlEvaluator evaluator = 1; | ||
| } | ||
|
|
||
| message LoadModel { | ||
zhengruifeng marked this conversation as resolved.
Show resolved
Hide resolved
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would this work with arbitrary model for example provided by Spark NLP?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For current PR, it does not support third-party estimators.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we want to support 3rd-party algorithm without registry, then inevitably we have to use java reflection to invoke methods (e.g. We need to invoke
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Btw, supporting 3rd-party estimators is risky, because in shared cluster we will binpack the spark workers across different customers (according to @mengxr 's explanation) |
||
| string name = 1; | ||
| string path = 2; | ||
| } | ||
|
|
||
| message SaveModel { | ||
| int64 model_ref_id = 1; | ||
| string path = 2; // saving path | ||
| bool overwrite = 3; | ||
| map<string, string> options = 4; // saving options | ||
| } | ||
|
|
||
| message LoadStage { | ||
| string name = 1; | ||
| string path = 2; | ||
| MlStage.StageType type = 3; | ||
| } | ||
|
|
||
| message SaveStage { | ||
| MlStage stage = 1; | ||
| string path = 2; // saving path | ||
| bool overwrite = 3; | ||
| map<string, string> options = 4; // saving options | ||
| } | ||
|
|
||
| message LoadEvaluator { | ||
| string name = 1; | ||
| string path = 2; | ||
| } | ||
|
|
||
| message SaveEvaluator { | ||
| MlEvaluator evaluator = 1; | ||
| string path = 2; // saving path | ||
| bool overwrite = 3; | ||
| map<string, string> options = 4; // saving options | ||
| } | ||
|
|
||
| message FetchModelAttr { | ||
| int64 model_ref_id = 1; | ||
| string name = 2; | ||
| } | ||
|
|
||
| message FetchModelSummaryAttr { | ||
| int64 model_ref_id = 1; | ||
| string name = 2; | ||
| MlParams params = 3; | ||
|
|
||
| // Evaluation dataset that it uses to computes | ||
| // the summary attribute | ||
| // If not set, get attributes from | ||
| // model.summary (i.e. the summary on training dataset) | ||
| optional Relation evaluation_dataset = 4; | ||
| } | ||
|
|
||
| message CopyModel { | ||
| int64 model_ref_id = 1; | ||
| } | ||
|
|
||
| message DeleteModel { | ||
| int64 model_ref_id = 1; | ||
| } | ||
| } | ||
|
|
||
|
|
||
| message MlCommandResponse { | ||
| oneof ml_command_response_type { | ||
| Expression.Literal literal = 1; | ||
| ModelInfo model_info = 2; | ||
| Vector vector = 3; | ||
zhengruifeng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Matrix matrix = 4; | ||
| MlStage stage = 5; | ||
| } | ||
| message ModelInfo { | ||
| int64 model_ref_id = 1; | ||
| string model_uid = 2; | ||
| MlParams params = 3; | ||
| } | ||
| } | ||
|
|
||
|
|
||
| message Vector { | ||
| oneof one_of { | ||
| Dense dense = 1; | ||
| Sparse sparse = 2; | ||
| } | ||
| message Dense { | ||
| repeated double value = 1; | ||
| } | ||
| message Sparse { | ||
| int32 size = 1; | ||
| repeated double index = 2; | ||
| repeated double value = 3; | ||
| } | ||
| } | ||
|
|
||
| message Matrix { | ||
| oneof one_of { | ||
| Dense dense = 1; | ||
| Sparse sparse = 2; | ||
| } | ||
| message Dense { | ||
| int32 num_rows = 1; | ||
| int32 num_cols = 2; | ||
| repeated double value = 3; | ||
| bool is_transposed = 4; | ||
| } | ||
| message Sparse { | ||
| int32 num_rows = 1; | ||
| int32 num_cols = 2; | ||
| repeated double colptr = 3; | ||
| repeated double row_index = 4; | ||
| repeated double value = 5; | ||
| bool is_transposed = 6; | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| syntax = 'proto3'; | ||
|
|
||
| package spark.connect; | ||
|
|
||
| import "spark/connect/expressions.proto"; | ||
|
|
||
| option java_multiple_files = true; | ||
| option java_package = "org.apache.spark.connect.proto"; | ||
|
|
||
|
|
||
| // MlParams stores param settings for | ||
| // ML Estimator / Transformer / Model / Evaluator | ||
| message MlParams { | ||
| // user-supplied params | ||
| map<string, Expression.Literal> params = 1; | ||
| // default params | ||
| map<string, Expression.Literal> default_params = 2; | ||
| } | ||
|
|
||
| // MlStage stores ML stage data (Estimator or Transformer) | ||
| message MlStage { | ||
| // The name of the stage in the registry | ||
| string name = 1; | ||
| // param settings for the stage | ||
| MlParams params = 2; | ||
| // unique id of the stage | ||
| string uid = 3; | ||
| StageType type = 4; | ||
| enum StageType { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this knowledge actually required on the client?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or we can make server side infer the stage type from stage name, |
||
| UNSPECIFIED = 0; | ||
| ESTIMATOR = 1; | ||
| TRANSFORMER = 2; | ||
|
||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ import "google/protobuf/any.proto"; | |
| import "spark/connect/expressions.proto"; | ||
| import "spark/connect/types.proto"; | ||
| import "spark/connect/catalog.proto"; | ||
| import "spark/connect/ml_common.proto"; | ||
|
|
||
| option java_multiple_files = true; | ||
| option java_package = "org.apache.spark.connect.proto"; | ||
|
|
@@ -82,13 +83,50 @@ message Relation { | |
| // Catalog API (experimental / unstable) | ||
| Catalog catalog = 200; | ||
|
|
||
| // ML relation | ||
| MlRelation ml_relation = 300; | ||
|
|
||
| // This field is used to mark extensions to the protocol. When plugins generate arbitrary | ||
| // relations they can add them here. During the planning the correct resolution is done. | ||
| google.protobuf.Any extension = 998; | ||
| Unknown unknown = 999; | ||
| } | ||
| } | ||
|
|
||
| message MlRelation { | ||
| oneof ml_relation_type { | ||
| ModelTransform model_transform = 1; | ||
| FeatureTransform feature_transform = 2; | ||
| ModelAttr model_attr = 3; | ||
| ModelSummaryAttr model_summary_attr = 4; | ||
| } | ||
| message ModelTransform { | ||
| Relation input = 1; | ||
| int64 model_ref_id = 2; | ||
|
||
| MlParams params = 3; | ||
| } | ||
| message FeatureTransform { | ||
zhengruifeng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Relation input = 1; | ||
| MlStage transformer = 2; | ||
| } | ||
| message ModelAttr { | ||
| int64 model_ref_id = 1; | ||
| string name = 2; | ||
| } | ||
| message ModelSummaryAttr { | ||
| int64 model_ref_id = 1; | ||
| string name = 2; | ||
| MlParams params = 3; | ||
|
|
||
| // Evaluation dataset that it uses to computes | ||
| // the summary attribute | ||
| // If not set, get attributes from | ||
| // model.summary (i.e. the summary on training dataset) | ||
| optional Relation evaluation_dataset = 4; | ||
| } | ||
| } | ||
|
|
||
|
|
||
| // Used for testing purposes only. | ||
| message Unknown {} | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.