|
48 | 48 | validate_model_id_and_get_type, |
49 | 49 | verify_model_region_and_return_specs, |
50 | 50 | get_jumpstart_configs, |
51 | | - extract_metrics_from_deployment_configs, |
| 51 | + get_metrics_from_deployment_configs, |
52 | 52 | ) |
53 | 53 | from sagemaker.jumpstart.constants import JUMPSTART_LOGGER |
54 | 54 | from sagemaker.jumpstart.enums import JumpStartModelType |
@@ -868,7 +868,7 @@ def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]: |
868 | 868 | Returns: |
869 | 869 | Dict[str, List[str]]: Deployment config benchmark data. |
870 | 870 | """ |
871 | | - return extract_metrics_from_deployment_configs( |
| 871 | + return get_metrics_from_deployment_configs( |
872 | 872 | self._deployment_configs, |
873 | 873 | config_name, |
874 | 874 | ) |
@@ -905,20 +905,29 @@ def _convert_to_deployment_config_metadata( |
905 | 905 | "default_inference_instance_type" |
906 | 906 | ) |
907 | 907 |
|
908 | | - instance_rate = get_instance_rate_per_hour( |
909 | | - instance_type=default_inference_instance_type, region=self.region |
910 | | - ) |
911 | | - |
912 | 908 | benchmark_metrics = ( |
913 | 909 | metadata_config.benchmark_metrics.get(default_inference_instance_type) |
914 | 910 | if metadata_config.benchmark_metrics is not None |
915 | 911 | else None |
916 | 912 | ) |
917 | | - if instance_rate is not None: |
918 | | - if benchmark_metrics is not None: |
919 | | - benchmark_metrics.append(JumpStartBenchmarkStat(instance_rate)) |
| 913 | + |
| 914 | + should_fetch_instance_rate_metric = True |
| 915 | + if benchmark_metrics is not None: |
| 916 | + for benchmark_metric in benchmark_metrics: |
| 917 | + if benchmark_metric.name.lower() == "instance rate": |
| 918 | + should_fetch_instance_rate_metric = False |
| 919 | + break |
| 920 | + |
| 921 | + if should_fetch_instance_rate_metric: |
| 922 | + instance_rate = get_instance_rate_per_hour( |
| 923 | + instance_type=default_inference_instance_type, region=self.region |
| 924 | + ) |
| 925 | + instance_rate_metric = JumpStartBenchmarkStat(instance_rate) |
| 926 | + |
| 927 | + if benchmark_metrics is None: |
| 928 | + benchmark_metrics = [instance_rate_metric] |
920 | 929 | else: |
921 | | - benchmark_metrics = [JumpStartBenchmarkStat(instance_rate)] |
| 930 | + benchmark_metrics.append(instance_rate_metric) |
922 | 931 |
|
923 | 932 | init_kwargs = get_init_kwargs( |
924 | 933 | model_id=self.model_id, |
|
0 commit comments