diff --git a/src/sagemaker_core/main/resources.py b/src/sagemaker_core/main/resources.py index 58e4471c..3d85563a 100644 --- a/src/sagemaker_core/main/resources.py +++ b/src/sagemaker_core/main/resources.py @@ -16518,6 +16518,7 @@ def wrapper(*args, **kwargs): "s3_data_source": { "s3_uri": {"type": "string"}, "s3_data_type": {"type": "string"}, + "manifest_s3_uri": {"type": "string"}, } } }, @@ -27953,6 +27954,55 @@ def batch_put_metrics( response = client.batch_put_metrics(**operation_input_args) logger.debug(f"Response: {response}") + @classmethod + @Base.add_validate_call + def batch_get_metrics( + cls, + metric_queries: List[MetricQuery], + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional[BatchGetMetricsResponse]: + """ + Used to retrieve training metrics from SageMaker. + + Parameters: + metric_queries: Queries made to retrieve training metrics from SageMaker. + session: Boto3 session. + region: Region name. + + Returns: + BatchGetMetricsResponse + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + operation_input_args = { + "MetricQueries": metric_queries, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker-metrics" + ) + + logger.debug(f"Calling batch_get_metrics API") + response = client.batch_get_metrics(**operation_input_args) + logger.debug(f"Response: {response}") + + transformed_response = transform(response, "BatchGetMetricsResponse") + return BatchGetMetricsResponse(**transformed_response) + class UserProfile(Base): """ diff --git a/src/sagemaker_core/main/shapes.py b/src/sagemaker_core/main/shapes.py index 7c8a5ed1..a32b7438 100644 --- a/src/sagemaker_core/main/shapes.py +++ b/src/sagemaker_core/main/shapes.py @@ -406,6 +406,18 @@ class MetricQueryResult(Base): message: Optional[str] = Unassigned() +class BatchGetMetricsResponse(Base): + """ + BatchGetMetricsResponse + + Attributes + ---------------------- + metric_query_results: The results of a query to retrieve training metrics from SageMaker. + """ + + metric_query_results: Optional[List[MetricQueryResult]] = Unassigned() + + class BatchPutMetricsError(Base): """ BatchPutMetricsError diff --git a/src/sagemaker_core/tools/additional_operations.json b/src/sagemaker_core/tools/additional_operations.json index b910921d..9c7a8f7d 100644 --- a/src/sagemaker_core/tools/additional_operations.json +++ b/src/sagemaker_core/tools/additional_operations.json @@ -169,6 +169,14 @@ "return_type": "None", "method_type": "object", "service_name": "sagemaker-metrics" + }, + "BatchGetMetrics": { + "operation_name": "BatchGetMetrics", + "resource_name": "TrialComponent", + "method_name": "batch_get_metrics", + "return_type": "BatchGetMetricsResponse", + "method_type": "class", + "service_name": "sagemaker-metrics" } }, "HubContent": { diff --git a/src/sagemaker_core/tools/api_coverage.json b/src/sagemaker_core/tools/api_coverage.json index b1924b4a..962a4ef2 100644 --- a/src/sagemaker_core/tools/api_coverage.json +++ b/src/sagemaker_core/tools/api_coverage.json @@ -1 +1 @@ -{"SupportedAPIs": 338, "UnsupportedAPIs": 6} \ No newline at end of file +{"SupportedAPIs": 339, "UnsupportedAPIs": 5} \ No newline at end of file