Skip to content

Commit 939c77c

Browse files
committed
chore: always include inference script if available
1 parent 29243ae commit 939c77c

File tree

3 files changed: 6 additions and 47 deletions

3 files changed: 6 additions and 47 deletions

src/sagemaker/jumpstart/artifacts.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ def _retrieve_model_uri(
177177
region: Optional[str] = None,
178178
tolerate_vulnerable_model: bool = False,
179179
tolerate_deprecated_model: bool = False,
180-
include_script: bool = False,
181180
):
182181
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
183182
@@ -198,8 +197,6 @@ def _retrieve_model_uri(
198197
tolerate_deprecated_model (bool): True if deprecated versions of model
199198
specifications should be tolerated (exception not raised). If False, raises
200199
an exception if the version of the model is deprecated.
201-
include_script (bool): True if script artifact should be packaged alongside model
202-
tarball. (Default: False).
203200
Returns:
204201
str: the model artifact S3 URI for the corresponding model.
205202
@@ -221,24 +218,14 @@ def _retrieve_model_uri(
221218
tolerate_deprecated_model=tolerate_deprecated_model,
222219
)
223220

224-
error_msg_no_combined_artifact = (
225-
"No combined script and model tarball available "
226-
f"for {model_id} with version {model_version} for {model_scope}."
227-
)
228-
229221
if model_scope == JumpStartScriptScope.INFERENCE:
230-
if not include_script:
231-
model_artifact_key = model_specs.hosting_artifact_key
232-
else:
233-
model_artifact_key = getattr(model_specs, "hosting_prepacked_artifact_key", None)
234-
if model_artifact_key is None:
235-
raise ValueError(error_msg_no_combined_artifact)
222+
model_artifact_key = (
223+
getattr(model_specs, "hosting_prepacked_artifact_key", None)
224+
or model_specs.hosting_artifact_key
225+
)
236226

237227
elif model_scope == JumpStartScriptScope.TRAINING:
238-
if not include_script:
239-
model_artifact_key = model_specs.training_artifact_key
240-
else:
241-
raise ValueError(error_msg_no_combined_artifact)
228+
model_artifact_key = model_specs.training_artifact_key
242229

243230
bucket = os.environ.get(
244231
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE

src/sagemaker/model_uris.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def retrieve(
3030
model_scope: Optional[str] = None,
3131
tolerate_vulnerable_model: bool = False,
3232
tolerate_deprecated_model: bool = False,
33-
include_script: bool = False,
3433
) -> str:
3534
"""Retrieves the model artifact Amazon S3 URI for the model matching the given arguments.
3635
@@ -49,8 +48,6 @@ def retrieve(
4948
tolerate_deprecated_model (bool): ``True`` if deprecated versions of model
5049
specifications should be tolerated without raising an exception. If ``False``, raises
5150
an exception if the version of the model is deprecated. (Default: False).
52-
include_script (bool): True if script artifact should be packaged alongside model
53-
tarball. (Default: False).
5451
Returns:
5552
str: The model artifact S3 URI for the corresponding model.
5653
@@ -71,5 +68,4 @@ def retrieve(
7168
region,
7269
tolerate_vulnerable_model,
7370
tolerate_deprecated_model,
74-
include_script,
7571
)

tests/unit/sagemaker/model_uris/jumpstart/test_combined_artifact.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
from mock.mock import patch
1616

1717
from sagemaker import model_uris
18-
import pytest
1918

20-
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec, get_special_model_spec
19+
from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec
2120

2221

2322
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -32,31 +31,8 @@ def test_jumpstart_combined_artifacts(patched_get_model_specs):
3231
model_scope="inference",
3332
model_id=model_id_combined_model_artifact,
3433
model_version="*",
35-
include_script=True,
3634
)
3735
assert (
3836
uri == "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/"
3937
"prepack/v1.0.0/infer-prepack-huggingface-text2text-flan-t5-xxl-fp16.tar.gz"
4038
)
41-
42-
with pytest.raises(ValueError):
43-
model_uris.retrieve(
44-
region="us-west-2",
45-
model_scope="training",
46-
model_id=model_id_combined_model_artifact,
47-
model_version="*",
48-
include_script=True,
49-
)
50-
51-
patched_get_model_specs.side_effect = get_prototype_model_spec
52-
53-
model_id_combined_model_artifact_unsupported = "xgboost-classification-model"
54-
55-
with pytest.raises(ValueError):
56-
model_uris.retrieve(
57-
region="us-west-2",
58-
model_scope="inference",
59-
model_id=model_id_combined_model_artifact_unsupported,
60-
model_version="*",
61-
include_script=True,
62-
)

0 commit comments

Comments
 (0)