diff --git a/python/pyspark/ml/tests/connect/test_connect_cache.py b/python/pyspark/ml/tests/connect/test_connect_cache.py index f911ab22286c0..b6c801f32eaf0 100644 --- a/python/pyspark/ml/tests/connect/test_connect_cache.py +++ b/python/pyspark/ml/tests/connect/test_connect_cache.py @@ -51,6 +51,9 @@ def test_delete_model(self): # the `model._summary` holds another ref to the remote model. assert model._java_obj._ref_count == 2 + model_size = spark.client._query_model_size(model._java_obj.ref_id) + assert isinstance(model_size, int) and model_size > 0 + model2 = model.copy() cache_info = spark.client._get_ml_cache_info() self.assertEqual(len(cache_info), 1) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 6ac4cc1894c72..9d2e18ebb7600 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -2027,3 +2027,15 @@ def _get_ml_cache_info(self) -> List[str]: return [item.string for item in ml_command_result.param.array.elements] return [] + + def _query_model_size(self, model_ref_id: str) -> int: + command = pb2.Command() + command.ml_command.get_model_size.CopyFrom( + pb2.MlCommand.GetModelSize(model_ref=pb2.ObjectRef(id=model_ref_id)) + ) + (_, properties, _) = self.execute_command(command) + + assert properties is not None + + ml_command_result = properties["ml_command_result"] + return ml_command_result.param.long diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py b/python/pyspark/sql/connect/proto/ml_pb2.py index 1ede558b94140..4c1b4038c35e3 100644 --- a/python/pyspark/sql/connect/proto/ml_pb2.py +++ b/python/pyspark/sql/connect/proto/ml_pb2.py @@ -40,7 +40,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb1\r\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x46\n\x0b\x63lean_cache\x18\x07 \x01(\x0b\x32#.spark.connect.MlCommand.CleanCacheH\x00R\ncleanCache\x12M\n\x0eget_cache_info\x18\x08 \x01(\x0b\x32%.spark.connect.MlCommand.GetCacheInfoH\x00R\x0cgetCacheInfo\x12O\n\x0e\x63reate_summary\x18\t \x01(\x0b\x32&.spark.connect.MlCommand.CreateSummaryH\x00R\rcreateSummary\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ap\n\x06\x44\x65lete\x12\x33\n\x08obj_refs\x18\x01 \x03(\x0b\x32\x18.spark.connect.ObjectRefR\x07objRefs\x12"\n\nevict_only\x18\x02 \x01(\x08H\x00R\tevictOnly\x88\x01\x01\x42\r\n\x0b_evict_only\x1a\x0c\n\nCleanCache\x1a\x0e\n\x0cGetCacheInfo\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ay\n\rCreateSummary\x12\x35\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x08modelRef\x12\x31\n\x07\x64\x61taset\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07\x63ommand"\xd5\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\x85\x02\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x12,\n\x0fwarning_message\x18\x05 \x01(\tH\x03R\x0ewarningMessage\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\x12\n\x10_warning_messageB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xc7\x0e\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x46\n\x0b\x63lean_cache\x18\x07 \x01(\x0b\x32#.spark.connect.MlCommand.CleanCacheH\x00R\ncleanCache\x12M\n\x0eget_cache_info\x18\x08 \x01(\x0b\x32%.spark.connect.MlCommand.GetCacheInfoH\x00R\x0cgetCacheInfo\x12O\n\x0e\x63reate_summary\x18\t \x01(\x0b\x32&.spark.connect.MlCommand.CreateSummaryH\x00R\rcreateSummary\x12M\n\x0eget_model_size\x18\n \x01(\x0b\x32%.spark.connect.MlCommand.GetModelSizeH\x00R\x0cgetModelSize\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ap\n\x06\x44\x65lete\x12\x33\n\x08obj_refs\x18\x01 \x03(\x0b\x32\x18.spark.connect.ObjectRefR\x07objRefs\x12"\n\nevict_only\x18\x02 \x01(\x08H\x00R\tevictOnly\x88\x01\x01\x42\r\n\x0b_evict_only\x1a\x0c\n\nCleanCache\x1a\x0e\n\x0cGetCacheInfo\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ay\n\rCreateSummary\x12\x35\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x08modelRef\x12\x31\n\x07\x64\x61taset\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61taset\x1a\x45\n\x0cGetModelSize\x12\x35\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x08modelRefB\t\n\x07\x63ommand"\xd5\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\x85\x02\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x12,\n\x0fwarning_message\x18\x05 \x01(\tH\x03R\x0ewarningMessage\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\x12\n\x10_warning_messageB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -54,27 +54,29 @@ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001" 
_globals["_MLCOMMAND"]._serialized_start = 137 - _globals["_MLCOMMAND"]._serialized_end = 1850 - _globals["_MLCOMMAND_FIT"]._serialized_start = 712 - _globals["_MLCOMMAND_FIT"]._serialized_end = 890 - _globals["_MLCOMMAND_DELETE"]._serialized_start = 892 - _globals["_MLCOMMAND_DELETE"]._serialized_end = 1004 - _globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 1006 - _globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 1018 - _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 1020 - _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 1034 - _globals["_MLCOMMAND_WRITE"]._serialized_start = 1037 - _globals["_MLCOMMAND_WRITE"]._serialized_end = 1447 - _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1349 - _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1407 - _globals["_MLCOMMAND_READ"]._serialized_start = 1449 - _globals["_MLCOMMAND_READ"]._serialized_end = 1530 - _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1533 - _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1716 - _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_start = 1718 - _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_end = 1839 - _globals["_MLCOMMANDRESULT"]._serialized_start = 1853 - _globals["_MLCOMMANDRESULT"]._serialized_end = 2322 - _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 2046 - _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2307 + _globals["_MLCOMMAND"]._serialized_end = 2000 + _globals["_MLCOMMAND_FIT"]._serialized_start = 791 + _globals["_MLCOMMAND_FIT"]._serialized_end = 969 + _globals["_MLCOMMAND_DELETE"]._serialized_start = 971 + _globals["_MLCOMMAND_DELETE"]._serialized_end = 1083 + _globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 1085 + _globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 1097 + _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 1099 + _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 1113 + _globals["_MLCOMMAND_WRITE"]._serialized_start = 1116 + 
_globals["_MLCOMMAND_WRITE"]._serialized_end = 1526 + _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1428 + _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1486 + _globals["_MLCOMMAND_READ"]._serialized_start = 1528 + _globals["_MLCOMMAND_READ"]._serialized_end = 1609 + _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1612 + _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1795 + _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_start = 1797 + _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_end = 1918 + _globals["_MLCOMMAND_GETMODELSIZE"]._serialized_start = 1920 + _globals["_MLCOMMAND_GETMODELSIZE"]._serialized_end = 1989 + _globals["_MLCOMMANDRESULT"]._serialized_start = 2003 + _globals["_MLCOMMANDRESULT"]._serialized_end = 2472 + _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 2196 + _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2457 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi b/python/pyspark/sql/connect/proto/ml_pb2.pyi index 0a72c207b5264..156ef846a8d10 100644 --- a/python/pyspark/sql/connect/proto/ml_pb2.pyi +++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi @@ -388,6 +388,26 @@ class MlCommand(google.protobuf.message.Message): field_name: typing_extensions.Literal["dataset", b"dataset", "model_ref", b"model_ref"], ) -> None: ... + class GetModelSize(google.protobuf.message.Message): + """This is for query the model estimated in-memory size""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MODEL_REF_FIELD_NUMBER: builtins.int + @property + def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ObjectRef: ... + def __init__( + self, + *, + model_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref"] + ) -> builtins.bool: ... 
+ def ClearField( + self, field_name: typing_extensions.Literal["model_ref", b"model_ref"] + ) -> None: ... + FIT_FIELD_NUMBER: builtins.int FETCH_FIELD_NUMBER: builtins.int DELETE_FIELD_NUMBER: builtins.int @@ -397,6 +417,7 @@ class MlCommand(google.protobuf.message.Message): CLEAN_CACHE_FIELD_NUMBER: builtins.int GET_CACHE_INFO_FIELD_NUMBER: builtins.int CREATE_SUMMARY_FIELD_NUMBER: builtins.int + GET_MODEL_SIZE_FIELD_NUMBER: builtins.int @property def fit(self) -> global___MlCommand.Fit: ... @property @@ -415,6 +436,8 @@ class MlCommand(google.protobuf.message.Message): def get_cache_info(self) -> global___MlCommand.GetCacheInfo: ... @property def create_summary(self) -> global___MlCommand.CreateSummary: ... + @property + def get_model_size(self) -> global___MlCommand.GetModelSize: ... def __init__( self, *, @@ -427,6 +450,7 @@ class MlCommand(google.protobuf.message.Message): clean_cache: global___MlCommand.CleanCache | None = ..., get_cache_info: global___MlCommand.GetCacheInfo | None = ..., create_summary: global___MlCommand.CreateSummary | None = ..., + get_model_size: global___MlCommand.GetModelSize | None = ..., ) -> None: ... def HasField( self, @@ -447,6 +471,8 @@ class MlCommand(google.protobuf.message.Message): b"fit", "get_cache_info", b"get_cache_info", + "get_model_size", + b"get_model_size", "read", b"read", "write", @@ -472,6 +498,8 @@ class MlCommand(google.protobuf.message.Message): b"fit", "get_cache_info", b"get_cache_info", + "get_model_size", + b"get_model_size", "read", b"read", "write", @@ -491,6 +519,7 @@ class MlCommand(google.protobuf.message.Message): "clean_cache", "get_cache_info", "create_summary", + "get_model_size", ] | None ): ... 
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto index 3497284af4ab8..ef5c406dedd26 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto @@ -39,6 +39,7 @@ message MlCommand { CleanCache clean_cache = 7; GetCacheInfo get_cache_info = 8; CreateSummary create_summary = 9; + GetModelSize get_model_size = 10; } // Command for estimator.fit(dataset) @@ -109,6 +110,11 @@ message MlCommand { ObjectRef model_ref = 1; Relation dataset = 2; } + + // This is for querying the model's estimated in-memory size + message GetModelSize { + ObjectRef model_ref = 1; + } } // The result of MlCommand diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala index 4de4f238e41a9..40f1172677a50 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala @@ -448,6 +448,15 @@ private[connect] object MLHandler extends Logging { val createSummaryCmd = mlCommand.getCreateSummary createModelSummary(sessionHolder, createSummaryCmd) + case proto.MlCommand.CommandCase.GET_MODEL_SIZE => + val modelRefId = mlCommand.getGetModelSize.getModelRef.getId + val model = mlCache.get(modelRefId) + val modelSize = model.asInstanceOf[Model[_]].estimatedSize + proto.MlCommandResult + .newBuilder() + .setParam(LiteralValueProtoConverter.toLiteralProto(modelSize)) + .build() + case other => throw MlUnsupportedException(s"$other not supported") } }