Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2027,3 +2027,18 @@ def _get_ml_cache_info(self) -> List[str]:
return [item.string for item in ml_command_result.param.array.elements]

return []

def _query_model_size(self, model_ref_id: str) -> int:
    """Query the estimated in-memory size of a server-side cached model.

    Parameters
    ----------
    model_ref_id : str
        Id of the server-side object reference for the cached model.

    Returns
    -------
    int
        The model's estimated size in bytes, as reported by the server
        in the ``param.long`` field of the ``MlCommandResult``.
    """
    command = pb2.Command()
    # Singular message fields cannot be assigned directly in protobuf;
    # populate `get_model_size` via CopyFrom. (The previous code both
    # copied a GetModelSize message into the unrelated `read` field —
    # a CopyFrom type mismatch — and assigned `model_ref` directly.)
    command.ml_command.get_model_size.CopyFrom(
        pb2.MlCommand.GetModelSize(model_ref=pb2.ObjectRef(id=model_ref_id))
    )
    (_, properties, _) = self.execute_command(command)

    assert properties is not None

    ml_command_result = properties["ml_command_result"]
    return ml_command_result.param.long
50 changes: 26 additions & 24 deletions python/pyspark/sql/connect/proto/ml_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb1\r\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x46\n\x0b\x63lean_cache\x18\x07 \x01(\x0b\x32#.spark.connect.MlCommand.CleanCacheH\x00R\ncleanCache\x12M\n\x0eget_cache_info\x18\x08 \x01(\x0b\x32%.spark.connect.MlCommand.GetCacheInfoH\x00R\x0cgetCacheInfo\x12O\n\x0e\x63reate_summary\x18\t \x01(\x0b\x32&.spark.connect.MlCommand.CreateSummaryH\x00R\rcreateSummary\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ap\n\x06\x44\x65lete\x12\x33\n\x08obj_refs\x18\x01 \x03(\x0b\x32\x18.spark.connect.ObjectRefR\x07objRefs\x12"\n\nevict_only\x18\x02 \x01(\x08H\x00R\tevictOnly\x88\x01\x01\x42\r\n\x0b_evict_only\x1a\x0c\n\nCleanCache\x1a\x0e\n\x0cGetCacheInfo\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 
\x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ay\n\rCreateSummary\x12\x35\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x08modelRef\x12\x31\n\x07\x64\x61taset\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07\x63ommand"\xd5\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\x85\x02\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x12,\n\x0fwarning_message\x18\x05 \x01(\tH\x03R\x0ewarningMessage\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\x12\n\x10_warning_messageB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xc7\x0e\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x46\n\x0b\x63lean_cache\x18\x07 \x01(\x0b\x32#.spark.connect.MlCommand.CleanCacheH\x00R\ncleanCache\x12M\n\x0eget_cache_info\x18\x08 \x01(\x0b\x32%.spark.connect.MlCommand.GetCacheInfoH\x00R\x0cgetCacheInfo\x12O\n\x0e\x63reate_summary\x18\t \x01(\x0b\x32&.spark.connect.MlCommand.CreateSummaryH\x00R\rcreateSummary\x12M\n\x0eget_model_size\x18\n \x01(\x0b\x32%.spark.connect.MlCommand.GetModelSizeH\x00R\x0cgetModelSize\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ap\n\x06\x44\x65lete\x12\x33\n\x08obj_refs\x18\x01 \x03(\x0b\x32\x18.spark.connect.ObjectRefR\x07objRefs\x12"\n\nevict_only\x18\x02 \x01(\x08H\x00R\tevictOnly\x88\x01\x01\x42\r\n\x0b_evict_only\x1a\x0c\n\nCleanCache\x1a\x0e\n\x0cGetCacheInfo\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 
\x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ay\n\rCreateSummary\x12\x35\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x08modelRef\x12\x31\n\x07\x64\x61taset\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61taset\x1a\x45\n\x0cGetModelSize\x12\x35\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x08modelRefB\t\n\x07\x63ommand"\xd5\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\x85\x02\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x12,\n\x0fwarning_message\x18\x05 
\x01(\tH\x03R\x0ewarningMessage\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\x12\n\x10_warning_messageB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
)

_globals = globals()
Expand All @@ -54,27 +54,29 @@
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001"
_globals["_MLCOMMAND"]._serialized_start = 137
_globals["_MLCOMMAND"]._serialized_end = 1850
_globals["_MLCOMMAND_FIT"]._serialized_start = 712
_globals["_MLCOMMAND_FIT"]._serialized_end = 890
_globals["_MLCOMMAND_DELETE"]._serialized_start = 892
_globals["_MLCOMMAND_DELETE"]._serialized_end = 1004
_globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 1006
_globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 1018
_globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 1020
_globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 1034
_globals["_MLCOMMAND_WRITE"]._serialized_start = 1037
_globals["_MLCOMMAND_WRITE"]._serialized_end = 1447
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1349
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1407
_globals["_MLCOMMAND_READ"]._serialized_start = 1449
_globals["_MLCOMMAND_READ"]._serialized_end = 1530
_globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1533
_globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1716
_globals["_MLCOMMAND_CREATESUMMARY"]._serialized_start = 1718
_globals["_MLCOMMAND_CREATESUMMARY"]._serialized_end = 1839
_globals["_MLCOMMANDRESULT"]._serialized_start = 1853
_globals["_MLCOMMANDRESULT"]._serialized_end = 2322
_globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 2046
_globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2307
_globals["_MLCOMMAND"]._serialized_end = 2000
_globals["_MLCOMMAND_FIT"]._serialized_start = 791
_globals["_MLCOMMAND_FIT"]._serialized_end = 969
_globals["_MLCOMMAND_DELETE"]._serialized_start = 971
_globals["_MLCOMMAND_DELETE"]._serialized_end = 1083
_globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 1085
_globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 1097
_globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 1099
_globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 1113
_globals["_MLCOMMAND_WRITE"]._serialized_start = 1116
_globals["_MLCOMMAND_WRITE"]._serialized_end = 1526
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1428
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1486
_globals["_MLCOMMAND_READ"]._serialized_start = 1528
_globals["_MLCOMMAND_READ"]._serialized_end = 1609
_globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1612
_globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1795
_globals["_MLCOMMAND_CREATESUMMARY"]._serialized_start = 1797
_globals["_MLCOMMAND_CREATESUMMARY"]._serialized_end = 1918
_globals["_MLCOMMAND_GETMODELSIZE"]._serialized_start = 1920
_globals["_MLCOMMAND_GETMODELSIZE"]._serialized_end = 1989
_globals["_MLCOMMANDRESULT"]._serialized_start = 2003
_globals["_MLCOMMANDRESULT"]._serialized_end = 2472
_globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 2196
_globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2457
# @@protoc_insertion_point(module_scope)
29 changes: 29 additions & 0 deletions python/pyspark/sql/connect/proto/ml_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,26 @@ class MlCommand(google.protobuf.message.Message):
field_name: typing_extensions.Literal["dataset", b"dataset", "model_ref", b"model_ref"],
) -> None: ...

class GetModelSize(google.protobuf.message.Message):
    """Queries a cached model's estimated in-memory size (in bytes)."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    MODEL_REF_FIELD_NUMBER: builtins.int
    @property
    def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ObjectRef: ...
    def __init__(
        self,
        *,
        model_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ...,
    ) -> None: ...
    def HasField(
        self, field_name: typing_extensions.Literal["model_ref", b"model_ref"]
    ) -> builtins.bool: ...
    def ClearField(
        self, field_name: typing_extensions.Literal["model_ref", b"model_ref"]
    ) -> None: ...

FIT_FIELD_NUMBER: builtins.int
FETCH_FIELD_NUMBER: builtins.int
DELETE_FIELD_NUMBER: builtins.int
Expand All @@ -397,6 +417,7 @@ class MlCommand(google.protobuf.message.Message):
CLEAN_CACHE_FIELD_NUMBER: builtins.int
GET_CACHE_INFO_FIELD_NUMBER: builtins.int
CREATE_SUMMARY_FIELD_NUMBER: builtins.int
GET_MODEL_SIZE_FIELD_NUMBER: builtins.int
@property
def fit(self) -> global___MlCommand.Fit: ...
@property
Expand All @@ -415,6 +436,8 @@ class MlCommand(google.protobuf.message.Message):
def get_cache_info(self) -> global___MlCommand.GetCacheInfo: ...
@property
def create_summary(self) -> global___MlCommand.CreateSummary: ...
@property
def get_model_size(self) -> global___MlCommand.GetModelSize: ...
def __init__(
self,
*,
Expand All @@ -427,6 +450,7 @@ class MlCommand(google.protobuf.message.Message):
clean_cache: global___MlCommand.CleanCache | None = ...,
get_cache_info: global___MlCommand.GetCacheInfo | None = ...,
create_summary: global___MlCommand.CreateSummary | None = ...,
get_model_size: global___MlCommand.GetModelSize | None = ...,
) -> None: ...
def HasField(
self,
Expand All @@ -447,6 +471,8 @@ class MlCommand(google.protobuf.message.Message):
b"fit",
"get_cache_info",
b"get_cache_info",
"get_model_size",
b"get_model_size",
"read",
b"read",
"write",
Expand All @@ -472,6 +498,8 @@ class MlCommand(google.protobuf.message.Message):
b"fit",
"get_cache_info",
b"get_cache_info",
"get_model_size",
b"get_model_size",
"read",
b"read",
"write",
Expand All @@ -491,6 +519,7 @@ class MlCommand(google.protobuf.message.Message):
"clean_cache",
"get_cache_info",
"create_summary",
"get_model_size",
]
| None
): ...
Expand Down
6 changes: 6 additions & 0 deletions sql/connect/common/src/main/protobuf/spark/connect/ml.proto
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ message MlCommand {
CleanCache clean_cache = 7;
GetCacheInfo get_cache_info = 8;
CreateSummary create_summary = 9;
GetModelSize get_model_size = 10;
}

// Command for estimator.fit(dataset)
Expand Down Expand Up @@ -109,6 +110,11 @@ message MlCommand {
ObjectRef model_ref = 1;
Relation dataset = 2;
}

  // Queries the model's estimated in-memory size.
message GetModelSize {
ObjectRef model_ref = 1;
}
}

// The result of MlCommand
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
def clear(): Int = this.synchronized {
val size = cachedModel.size()
cachedModel.clear()
totalMLCacheSizeBytes.set(0L)
if (getMemoryControlEnabled) {
SparkFileUtils.cleanDirectory(new File(offloadedModelsDir.toString))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,15 @@ private[connect] object MLHandler extends Logging {
val createSummaryCmd = mlCommand.getCreateSummary
createModelSummary(sessionHolder, createSummaryCmd)

case proto.MlCommand.CommandCase.GET_MODEL_SIZE =>
val modelRefId = mlCommand.getGetModelSize.getModelRef.getId
val model = mlCache.get(modelRefId)
val modelSize = model.asInstanceOf[Model[_]].estimatedSize
proto.MlCommandResult
.newBuilder()
.setParam(LiteralValueProtoConverter.toLiteralProto(modelSize))
.build()

case other => throw MlUnsupportedException(s"$other not supported")
}
}
Expand Down