2 changes: 1 addition & 1 deletion src/sagemaker/serve/builder/model_builder.py
@@ -169,7 +169,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
in order for model builder to build the artifacts correctly (according
to the model server). Possible values for this argument are
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
``TRITON``, and``TGI``.
``TRITON``, ``TGI``, and ``TEI``.
model_metadata (Optional[Dict[str, Any]]): Dictionary used to override model metadata.
Currently, ``HF_TASK`` is overridable for HuggingFace models. ``HF_TASK`` should be set for
new models without task metadata in the Hub; adding unsupported task types will throw
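With ``TEI`` added to the accepted ``model_server`` values, an embedding model can be routed through the new server. A minimal sketch of the intended call pattern (the model ID and sample payloads are illustrative assumptions, not part of this change):

from sagemaker.serve import ModelBuilder, SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.types import ModelServer

# Representative input/output pair for an embedding model (hypothetical values).
schema = SchemaBuilder(sample_input="hello world", sample_output=[[0.1, 0.2, 0.3]])

model_builder = ModelBuilder(
    model="BAAI/bge-base-en-v1.5",  # hypothetical Hugging Face model ID
    schema_builder=schema,
    model_server=ModelServer.TEI,
    mode=Mode.LOCAL_CONTAINER,  # or Mode.SAGEMAKER_ENDPOINT
)
model = model_builder.build()
predictor = model.deploy()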
18 changes: 10 additions & 8 deletions src/sagemaker/serve/builder/tei_builder.py
@@ -25,7 +25,7 @@
_get_nb_instance,
)
from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
from sagemaker.serve.utils.predictors import TgiLocalModePredictor
from sagemaker.serve.utils.predictors import TeiLocalModePredictor
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
@@ -74,16 +74,16 @@ def _prepare_for_mode(self):
def _get_client_translators(self):
"""Placeholder docstring"""

def _set_to_tgi(self):
def _set_to_tei(self):
"""Placeholder docstring"""
if self.model_server != ModelServer.TGI:
if self.model_server != ModelServer.TEI:
messaging = (
"HuggingFace Model ID support on model server: "
f"{self.model_server} is not currently supported. "
f"Defaulting to {ModelServer.TGI}"
f"Defaulting to {ModelServer.TEI}"
)
logger.warning(messaging)
self.model_server = ModelServer.TGI
self.model_server = ModelServer.TEI

def _create_tei_model(self, **kwargs) -> Type[Model]:
"""Placeholder docstring"""
@@ -142,7 +142,7 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
if self.mode == Mode.LOCAL_CONTAINER:
timeout = kwargs.get("model_data_download_timeout")

predictor = TgiLocalModePredictor(
predictor = TeiLocalModePredictor(
self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
)

@@ -180,7 +180,9 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
if "endpoint_logging" not in kwargs:
kwargs["endpoint_logging"] = True

if not self.nb_instance_type and "instance_type" not in kwargs:
if self.nb_instance_type and "instance_type" not in kwargs:
kwargs.update({"instance_type": self.nb_instance_type})
elif not self.nb_instance_type and "instance_type" not in kwargs:
raise ValueError(
"Instance type must be provided when deploying " "to SageMaker Endpoint mode."
)
@@ -216,7 +218,7 @@ def _build_for_tei(self):
"""Placeholder docstring"""
self.secret_key = None

self._set_to_tgi()
self._set_to_tei()

self.pysdk_model = self._build_for_hf_tei()
return self.pysdk_model
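The reworked instance-type branch above now draws from two sources instead of rejecting outright: an explicit ``instance_type`` in ``kwargs`` wins, the notebook-detected ``nb_instance_type`` is the fallback, and only when both are absent does deployment fail. Extracted as a standalone sketch (the function name is illustrative):

def _resolve_instance_type(nb_instance_type, kwargs):
    """Mirror of the deploy-wrapper branch: fall back to the notebook's
    detected instance type before refusing to deploy."""
    if nb_instance_type and "instance_type" not in kwargs:
        kwargs.update({"instance_type": nb_instance_type})
    elif not nb_instance_type and "instance_type" not in kwargs:
        raise ValueError(
            "Instance type must be provided when deploying to SageMaker Endpoint mode."
        )
    return kwargs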
15 changes: 15 additions & 0 deletions src/sagemaker/serve/mode/local_container_mode.py
@@ -21,6 +21,7 @@
from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
from sagemaker.serve.model_server.triton.server import LocalTritonServer
from sagemaker.serve.model_server.tgi.server import LocalTgiServing
from sagemaker.serve.model_server.tei.server import LocalTeiServing
from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer
from sagemaker.session import Session

@@ -69,6 +70,7 @@ def __init__(
self.container = None
self.secret_key = None
self._ping_container = None
self._invoke_serving = None

def load(self, model_path: str = None):
"""Placeholder docstring"""
@@ -156,6 +158,19 @@ def create_server(
env_vars=env_vars if env_vars else self.env_vars,
)
self._ping_container = self._tensorflow_serving_deep_ping
elif self.model_server == ModelServer.TEI:
tei_serving = LocalTeiServing()
tei_serving._start_tei_serving(
client=self.client,
image=image,
model_path=model_path if model_path else self.model_path,
secret_key=secret_key,
env_vars=env_vars if env_vars else self.env_vars,
)
tei_serving.schema_builder = self.schema_builder
self.container = tei_serving.container
self._ping_container = tei_serving._tei_deep_ping
self._invoke_serving = tei_serving._invoke_tei_serving

# allow some time for container to be ready
time.sleep(10)
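The new ``ModelServer.TEI`` branch mirrors the other servers but also stores an ``_invoke_serving`` pointer, so local predictors can route requests through ``LocalTeiServing`` without knowing the server type. A sketch of the resulting flow (the driver function is an assumption for illustration; only the two pointers come from this diff):

def smoke_test_local_tei(mode, predictor):
    """Exercise the pointers set in the TEI branch of create_server()."""
    # mode._ping_container -> LocalTeiServing._tei_deep_ping
    is_healthy, _ = mode._ping_container(predictor)
    if not is_healthy:
        raise RuntimeError("TEI container did not pass the deep ping")
    # mode._invoke_serving -> LocalTeiServing._invoke_tei_serving(request, content_type, accept)
    return mode._invoke_serving(
        '{"inputs": "hello world"}', "application/json", "application/json"
    )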
27 changes: 21 additions & 6 deletions src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
@@ -6,6 +6,7 @@
import logging
from typing import Type

from sagemaker.serve.model_server.tei.server import SageMakerTeiServing
from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing
from sagemaker.session import Session
from sagemaker.serve.utils.types import ModelServer
@@ -37,6 +38,8 @@ def __init__(self, inference_spec: Type[InferenceSpec], model_server: ModelServe
self.inference_spec = inference_spec
self.model_server = model_server

self._tei_serving = SageMakerTeiServing()

def load(self, model_path: str):
"""Placeholder docstring"""
path = Path(model_path)
@@ -66,8 +69,9 @@ def prepare(
+ "session to be created or supply `sagemaker_session` into @serve.invoke."
) from e

upload_artifacts = None
if self.model_server == ModelServer.TORCHSERVE:
return self._upload_torchserve_artifacts(
upload_artifacts = self._upload_torchserve_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
secret_key=secret_key,
@@ -76,7 +80,7 @@
)

if self.model_server == ModelServer.TRITON:
return self._upload_triton_artifacts(
upload_artifacts = self._upload_triton_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
secret_key=secret_key,
@@ -85,15 +89,15 @@
)

if self.model_server == ModelServer.DJL_SERVING:
return self._upload_djl_artifacts(
upload_artifacts = self._upload_djl_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
s3_model_data_url=s3_model_data_url,
image=image,
)

if self.model_server == ModelServer.TGI:
return self._upload_tgi_artifacts(
upload_artifacts = self._upload_tgi_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
s3_model_data_url=s3_model_data_url,
@@ -102,20 +106,31 @@
)

if self.model_server == ModelServer.MMS:
return self._upload_server_artifacts(
upload_artifacts = self._upload_server_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
s3_model_data_url=s3_model_data_url,
image=image,
)

if self.model_server == ModelServer.TENSORFLOW_SERVING:
return self._upload_tensorflow_serving_artifacts(
upload_artifacts = self._upload_tensorflow_serving_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
secret_key=secret_key,
s3_model_data_url=s3_model_data_url,
image=image,
)

if self.model_server == ModelServer.TEI:
upload_artifacts = self._tei_serving._upload_tei_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
s3_model_data_url=s3_model_data_url,
image=image,
)

if upload_artifacts:
return upload_artifacts

raise ValueError("%s model server is not supported" % self.model_server)
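Collecting the result in ``upload_artifacts`` instead of returning from each branch gives every server, including the new TEI path, a single exit point and lets unrecognized servers fall through to the ``ValueError``. For ``ModelServer.TEI``, the value ``prepare()`` returns takes this shape, per ``_upload_tei_artifacts`` below (bucket and prefix are illustrative):

(
    {
        "S3DataSource": {
            "CompressionType": "None",
            "S3DataType": "S3Prefix",
            "S3Uri": "s3://<bucket>/<code-key-prefix>/code/",
        }
    },
    {"TRANSFORMERS_CACHE": "/opt/ml/model/", "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/"},
)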
160 changes: 160 additions & 0 deletions src/sagemaker/serve/model_server/tei/server.py
@@ -0,0 +1,160 @@
"""Module for Local TEI Serving"""

from __future__ import absolute_import

import requests
import logging
from pathlib import Path
from docker.types import DeviceRequest
from sagemaker import Session, fw_utils
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
from sagemaker.base_predictor import PredictorBase
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join
from sagemaker.s3 import S3Uploader
from sagemaker.local.utils import get_docker_host


MODE_DIR_BINDING = "/opt/ml/model/"
_SHM_SIZE = "2G"
_DEFAULT_ENV_VARS = {
"TRANSFORMERS_CACHE": "/opt/ml/model/",
"HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
}

logger = logging.getLogger(__name__)


class LocalTeiServing:
"""LocalTeiServing class"""

def _start_tei_serving(
self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
):
"""Starts a local tei serving container.

Args:
client: Docker client
image: Image to use
model_path: Path to the model
secret_key: Secret key to use for authentication
env_vars: Environment variables to set
"""
if env_vars and secret_key:
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

        # Run the TEI container detached, with all available GPUs, host networking,
        # and the model's code/ directory mounted read-write at /opt/ml/model/.
        self.container = client.containers.run(
image,
shm_size=_SHM_SIZE,
device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
network_mode="host",
detach=True,
auto_remove=True,
volumes={
Path(model_path).joinpath("code"): {
"bind": MODE_DIR_BINDING,
"mode": "rw",
},
},
environment=_update_env_vars(env_vars),
)

def _invoke_tei_serving(self, request: object, content_type: str, accept: str):
"""Invokes a local tei serving container.

Args:
request: Request to send
content_type: Content type to use
accept: Accept to use
"""
try:
response = requests.post(
f"http://{get_docker_host()}:8080/invocations",
data=request,
headers={"Content-Type": content_type, "Accept": accept},
timeout=600,
)
response.raise_for_status()
return response.content
except Exception as e:
raise Exception("Unable to send request to the local container server") from e

def _tei_deep_ping(self, predictor: PredictorBase):
"""Checks if the local tei serving container is up and running.

If the container is not up and running, it will raise an exception.
"""
        response = None
        try:
            response = predictor.predict(self.schema_builder.sample_input)
            return (True, response)
        # pylint: disable=broad-except
        except Exception as e:
            if "422 Client Error: Unprocessable Entity for url" in str(e):
                raise LocalModelInvocationException(str(e))
            return (False, response)


class SageMakerTeiServing:
"""SageMakerTeiServing class"""

def _upload_tei_artifacts(
self,
model_path: str,
sagemaker_session: Session,
s3_model_data_url: str = None,
image: str = None,
env_vars: dict = None,
):
"""Uploads the model artifacts to S3.

Args:
model_path: Path to the model
sagemaker_session: SageMaker session
s3_model_data_url: S3 model data URL
image: Image to use
env_vars: Environment variables to set
"""
if s3_model_data_url:
bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
else:
bucket, key_prefix = None, None

code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)

bucket, code_key_prefix = determine_bucket_and_prefix(
bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session
)

code_dir = Path(model_path).joinpath("code")

s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code")

logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location)

model_data_url = S3Uploader.upload(
str(code_dir),
s3_location,
None,
sagemaker_session,
)

model_data = {
"S3DataSource": {
"CompressionType": "None",
"S3DataType": "S3Prefix",
"S3Uri": model_data_url + "/",
}
}

return (model_data, _update_env_vars(env_vars))


def _update_env_vars(env_vars: dict) -> dict:
"""Placeholder docstring"""
updated_env_vars = {}
updated_env_vars.update(_DEFAULT_ENV_VARS)
if env_vars:
updated_env_vars.update(env_vars)
return updated_env_vars
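Taken together, ``LocalTeiServing`` backs local-container mode and ``SageMakerTeiServing`` backs endpoint mode. A minimal local smoke test, assuming a TEI container image and a prepared model directory with a ``code/`` subfolder (both placeholders here, not values from this change):

import docker

from sagemaker.serve.model_server.tei.server import LocalTeiServing

serving = LocalTeiServing()
serving._start_tei_serving(
    client=docker.from_env(),
    image="<tei-image-uri>",  # hypothetical TEI deep learning container image
    model_path="/tmp/my-model",  # must contain a code/ subdirectory to mount
    secret_key="not-a-real-secret",
    env_vars={"HF_MODEL_ID": "BAAI/bge-base-en-v1.5"},  # hypothetical model ID
)
# Once the container is up, requests go through the local invocations route:
raw = serving._invoke_tei_serving(
    '{"inputs": "hello world"}', "application/json", "application/json"
)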