Skip to content
Closed
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
684a9e3
update
WeichenXu123 Mar 3, 2023
97ab924
update
WeichenXu123 Mar 6, 2023
86f2fad
update
WeichenXu123 Mar 6, 2023
4f17d6c
update
WeichenXu123 Mar 6, 2023
6582037
merge master & fix
WeichenXu123 Mar 7, 2023
ba4f580
update
WeichenXu123 Mar 8, 2023
941550e
update
WeichenXu123 Mar 8, 2023
c4473a6
update
WeichenXu123 Mar 9, 2023
606168d
update
WeichenXu123 Mar 9, 2023
1003787
update
WeichenXu123 Mar 9, 2023
f9f3542
update
WeichenXu123 Mar 9, 2023
ed24307
fix
WeichenXu123 Mar 10, 2023
c1f9162
merge master
WeichenXu123 Mar 10, 2023
e178de3
update
WeichenXu123 Mar 10, 2023
130bd1e
update
WeichenXu123 Mar 10, 2023
eee1013
fix
WeichenXu123 Mar 11, 2023
d72fba0
update
WeichenXu123 Mar 11, 2023
36bc69b
update
WeichenXu123 Mar 13, 2023
870c994
merge master
WeichenXu123 Mar 13, 2023
e500be8
update
WeichenXu123 Mar 13, 2023
7c44e5c
update
WeichenXu123 Mar 13, 2023
e87aa53
merge master
WeichenXu123 Mar 13, 2023
18876c2
update
WeichenXu123 Mar 13, 2023
23744b8
model gc
WeichenXu123 Mar 13, 2023
33be464
try_remote_ml_class
WeichenXu123 Mar 13, 2023
bb43f01
format
WeichenXu123 Mar 14, 2023
ec89d40
format
WeichenXu123 Mar 14, 2023
2ebcf45
fix tests
WeichenXu123 Mar 14, 2023
6e12a22
update
WeichenXu123 Mar 14, 2023
9ae327b
doctests
WeichenXu123 Mar 14, 2023
36e6d33
merge master
WeichenXu123 Mar 14, 2023
e5278cb
Merge branch 'master' into spark-connect-ml-1
WeichenXu123 Mar 14, 2023
c80414d
add proto comments
WeichenXu123 Mar 15, 2023
66c472b
address comments
WeichenXu123 Mar 15, 2023
39b2f24
model_ref message
WeichenXu123 Mar 16, 2023
2d4377d
move to ml.connect
WeichenXu123 Mar 16, 2023
e316a82
Merge branch 'master' into spark-connect-ml-1
WeichenXu123 Mar 16, 2023
caebf75
fix
WeichenXu123 Mar 16, 2023
6689a1d
merge master
WeichenXu123 Mar 20, 2023
2ad0e13
update
WeichenXu123 Mar 20, 2023
49e0e0c
update
WeichenXu123 Mar 21, 2023
809bd00
update
WeichenXu123 Mar 21, 2023
def7aa5
update doctest
WeichenXu123 Mar 21, 2023
581a5ee
update
WeichenXu123 Mar 21, 2023
296099c
merge master
WeichenXu123 Mar 21, 2023
93b00f8
update
WeichenXu123 Mar 21, 2023
f227df9
update
WeichenXu123 Mar 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import "spark/connect/commands.proto";
import "spark/connect/expressions.proto";
import "spark/connect/relations.proto";
import "spark/connect/types.proto";
import "spark/connect/ml.proto";

option java_multiple_files = true;
option java_package = "org.apache.spark.connect.proto";
Expand All @@ -36,6 +37,7 @@ message Plan {
oneof op_type {
Relation root = 1;
Command command = 2;
MlCommand ml_command = 3;
}
}

Expand Down Expand Up @@ -261,6 +263,9 @@ message ExecutePlanResponse {
// Special case for executing SQL commands.
SqlCommandResult sql_command_result = 5;

// ML command response
MlCommandResponse ml_command_result = 100;

// Support arbitrary result objects.
google.protobuf.Any extension = 999;
}
Expand Down
193 changes: 193 additions & 0 deletions connector/connect/common/src/main/protobuf/spark/connect/ml.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

syntax = 'proto3';

package spark.connect;

import "spark/connect/expressions.proto";
import "spark/connect/relations.proto";
import "spark/connect/ml_common.proto";

option java_multiple_files = true;
option java_package = "org.apache.spark.connect.proto";


// MlEvaluator represents a ML Evaluator
message MlEvaluator {
// The name of the evaluator in the registry
string name = 1;
// param settings for the evaluator
MlParams params = 2;
// unique id of the evaluator
string uid = 3;
}


// a MlCommand is a type container that has exactly one ML command set
message MlCommand {
oneof ml_command_type {
// call `estimator.fit` and returns a model
Fit fit = 1;
// get model attribute
FetchModelAttr fetch_model_attr = 2;
// get model summary attribute
FetchModelSummaryAttr fetch_model_summary_attr = 3;
// load model
LoadModel load_model = 4;
// save model
SaveModel save_model = 5;
// call `evaluator.evaluate`
Evaluate evaluate = 6;
// save estimator or transformer
SaveStage save_stage = 7;
// load estimator or transformer
LoadStage load_stage = 8;
// save estimator
SaveEvaluator save_evaluator = 9;
// load estimator
LoadEvaluator load_evaluator = 10;
// copy model, returns new model reference id
CopyModel copy_model = 11;
// delete server side model object by model reference id
DeleteModel delete_model = 12;
}

message Fit {
MlStage estimator = 1;
Relation dataset = 2;
}

message Evaluate {
MlEvaluator evaluator = 1;
}

message LoadModel {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this work with arbitrary model for example provided by Spark NLP?

Copy link
Contributor Author

@WeichenXu123 WeichenXu123 Mar 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For current PR, it does not support third-party estimators.
We need to register related class for 3rd-party algorithm to AlgorithmRegistry class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to support 3rd-party algorithm without registry, then inevitably we have to use java reflection to invoke methods (e.g. We need to invoke XXXModel.load to load model, which is unsafe.

Copy link
Contributor Author

@WeichenXu123 WeichenXu123 Mar 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, supporting 3rd-party estimators is risky, because in shared cluster we will binpack the spark workers across different customers (according to @mengxr 's explanation)
But 3rd-party estimators implementation might invoke RDD transformation (e.g. RDD.map) that we cannot isolate them by container. So it is risky if we allow user uses 3rd-party estimators on shared cluster.

string name = 1;
string path = 2;
}

message SaveModel {
int64 model_ref_id = 1;
string path = 2; // saving path
bool overwrite = 3;
map<string, string> options = 4; // saving options
}

message LoadStage {
string name = 1;
string path = 2;
MlStage.StageType type = 3;
}

message SaveStage {
MlStage stage = 1;
string path = 2; // saving path
bool overwrite = 3;
map<string, string> options = 4; // saving options
}

message LoadEvaluator {
string name = 1;
string path = 2;
}

message SaveEvaluator {
MlEvaluator evaluator = 1;
string path = 2; // saving path
bool overwrite = 3;
map<string, string> options = 4; // saving options
}

message FetchModelAttr {
int64 model_ref_id = 1;
string name = 2;
}

message FetchModelSummaryAttr {
int64 model_ref_id = 1;
string name = 2;
MlParams params = 3;

// Evaluation dataset that it uses to computes
// the summary attribute
// If not set, get attributes from
// model.summary (i.e. the summary on training dataset)
optional Relation evaluation_dataset = 4;
}

message CopyModel {
int64 model_ref_id = 1;
}

message DeleteModel {
int64 model_ref_id = 1;
}
}


message MlCommandResponse {
oneof ml_command_response_type {
Expression.Literal literal = 1;
ModelInfo model_info = 2;
Vector vector = 3;
Matrix matrix = 4;
MlStage stage = 5;
}
message ModelInfo {
int64 model_ref_id = 1;
string model_uid = 2;
MlParams params = 3;
}
}


message Vector {
oneof one_of {
Dense dense = 1;
Sparse sparse = 2;
}
message Dense {
repeated double value = 1;
}
message Sparse {
int32 size = 1;
repeated double index = 2;
repeated double value = 3;
}
}

message Matrix {
oneof one_of {
Dense dense = 1;
Sparse sparse = 2;
}
message Dense {
int32 num_rows = 1;
int32 num_cols = 2;
repeated double value = 3;
bool is_transposed = 4;
}
message Sparse {
int32 num_rows = 1;
int32 num_cols = 2;
repeated double colptr = 3;
repeated double row_index = 4;
repeated double value = 5;
bool is_transposed = 6;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

syntax = 'proto3';

package spark.connect;

import "spark/connect/expressions.proto";

option java_multiple_files = true;
option java_package = "org.apache.spark.connect.proto";


// MlParams stores param settings for
// ML Estimator / Transformer / Model / Evaluator
message MlParams {
// user-supplied params
map<string, Expression.Literal> params = 1;
// default params
map<string, Expression.Literal> default_params = 2;
}

// MlStage stores ML stage data (Estimator or Transformer)
message MlStage {
// The name of the stage in the registry
string name = 1;
// param settings for the stage
MlParams params = 2;
// unique id of the stage
string uid = 3;
StageType type = 4;
enum StageType {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this knowledge actually required on the client?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or we can make server side infer the stage type from stage name,
but let client fill the stage type is easier for code.

UNSPECIFIED = 0;
ESTIMATOR = 1;
TRANSFORMER = 2;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we normally name enums like this

    STAGE_TYPE_UNSPECIFIED = 0;
    STAGE_TYPE_ESTIMATOR = 1;
    STAGE_TYPE_TRANSFORMER = 2;

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import "google/protobuf/any.proto";
import "spark/connect/expressions.proto";
import "spark/connect/types.proto";
import "spark/connect/catalog.proto";
import "spark/connect/ml_common.proto";

option java_multiple_files = true;
option java_package = "org.apache.spark.connect.proto";
Expand Down Expand Up @@ -82,13 +83,50 @@ message Relation {
// Catalog API (experimental / unstable)
Catalog catalog = 200;

// ML relation
MlRelation ml_relation = 300;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// relations they can add them here. During the planning the correct resolution is done.
google.protobuf.Any extension = 998;
Unknown unknown = 999;
}
}

message MlRelation {
oneof ml_relation_type {
ModelTransform model_transform = 1;
FeatureTransform feature_transform = 2;
ModelAttr model_attr = 3;
ModelSummaryAttr model_summary_attr = 4;
}
message ModelTransform {
Relation input = 1;
int64 model_ref_id = 2;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suggestion here is to maybe wrap the moddel_ref_id into an extra message object that becomes easier to extend.

message ModelRef {
  int64 id = 1;
}

That said, is there a reason the ID is numeric vs a string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ID is generated from a increamental counter. So I think int64 type should be fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

message ModelRef {
int64 id = 1;
}

This sounds good.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ID is generated from a increamental counter.

Using random UUID might be a better idea , if we want to support server failover in future (we need to persist status and restore it, random UUID can help avoiding reusing ID that is generated before.)

MlParams params = 3;
}
message FeatureTransform {
Relation input = 1;
MlStage transformer = 2;
}
message ModelAttr {
int64 model_ref_id = 1;
string name = 2;
}
message ModelSummaryAttr {
int64 model_ref_id = 1;
string name = 2;
MlParams params = 3;

// Evaluation dataset that it uses to computes
// the summary attribute
// If not set, get attributes from
// model.summary (i.e. the summary on training dataset)
optional Relation evaluation_dataset = 4;
}
}


// Used for testing purposes only.
message Unknown {}

Expand Down
Loading