Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
_retrieve_model_package_model_artifact_s3_uri,
)
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
from sagemaker.jumpstart.session_utils import get_model_info_from_training_job
from sagemaker.session import Session
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.base_deserializers import BaseDeserializer
Expand Down Expand Up @@ -815,7 +814,7 @@ def _add_config_name_to_kwargs(
config_name=kwargs.config_name,
)

if specs.training_configs and specs.training_configs.get_top_config_from_ranking().config_name:
if specs.training_configs and specs.training_configs.get_top_config_from_ranking():
kwargs.config_name = (
kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name
)
Expand Down
21 changes: 7 additions & 14 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,13 +588,9 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
model_type=kwargs.model_type,
config_name=kwargs.config_name,
)
if (
specs.inference_configs
and specs.inference_configs.get_top_config_from_ranking().config_name
):
kwargs.config_name = (
kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
)
if specs.inference_configs and specs.inference_configs.get_top_config_from_ranking():
default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
kwargs.config_name = kwargs.config_name or default_config_name

if not kwargs.config_name:
return kwargs
Expand All @@ -614,6 +610,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta

return kwargs


def _add_config_name_to_deploy_kwargs(
kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
) -> JumpStartModelInitKwargs:
Expand Down Expand Up @@ -643,13 +640,9 @@ def _add_config_name_to_deploy_kwargs(
specs=specs, training_config_name=training_config_name
)

if (
specs.inference_configs
and specs.inference_configs.get_top_config_from_ranking().config_name
):
kwargs.config_name = (
kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
)
if specs.inference_configs and specs.inference_configs.get_top_config_from_ranking():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems you just got rid of .config_name from the if-statement

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry extracted to default_config_name and removed the redundant check ..get_top_config_from_ranking() in the if clause. This is not needed since it fallbacks to None anyway

default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
kwargs.config_name = kwargs.config_name or default_config_name

return kwargs

Expand Down