Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions text-embeddings-inference/Dockerfile-hpu
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
FROM python:3.9 as tei
WORKDIR /tmp
RUN git clone https://github.com/huggingface/text-embeddings-inference.git && cd text-embeddings-inference/ && git checkout a059696a33f3b2cd28ce5e69d3195d5b03189d96

FROM lukemathwalker/cargo-chef:latest-rust-1.73-bookworm AS chef
WORKDIR /usr/src

ENV SCCACHE=0.5.4
ENV RUSTC_WRAPPER=/usr/local/bin/sccache

RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
chmod +x /usr/local/bin/sccache

FROM chef AS planner

COPY --from=tei /tmp/text-embeddings-inference/backends backends
COPY --from=tei /tmp/text-embeddings-inference/core core
COPY --from=tei /tmp/text-embeddings-inference/router router
COPY --from=tei /tmp/text-embeddings-inference/Cargo.toml Cargo.toml
COPY --from=tei /tmp/text-embeddings-inference/Cargo.lock Cargo.lock

RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

ARG ACTIONS_CACHE_URL
ARG ACTIONS_RUNTIME_TOKEN
ARG SCCACHE_GHA_ENABLED

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s

COPY --from=tei /tmp/text-embeddings-inference/backends backends
COPY --from=tei /tmp/text-embeddings-inference/core core
COPY --from=tei /tmp/text-embeddings-inference/router router
COPY --from=tei /tmp/text-embeddings-inference/Cargo.toml Cargo.toml
COPY --from=tei /tmp/text-embeddings-inference/Cargo.lock Cargo.lock

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP

FROM builder as http-builder

RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s

FROM builder as grpc-builder

COPY --from=tei /tmp/text-embeddings-inference/proto proto

RUN cargo build --release --bin text-embeddings-router -F grpc -F python --no-default-features && sccache -s

FROM vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest as base

ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80

WORKDIR /usr/src
COPY --from=tei /tmp/text-embeddings-inference/backends backends
COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
COPY backends/python/server/requirements.txt backends/python/server/requirements.txt
RUN cd backends/python/server && \
make install

FROM base as grpc

COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]

FROM base

COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]
14 changes: 14 additions & 0 deletions text-embeddings-inference/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
## Docker build

```shell
docker build . -f Dockerfile-hpu -t text-embeddings-inference:hpu-0.6
```

## Docker run

```shell
model=sentence-transformers/all-MiniLM-L6-v2
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --ipc=host -p 8080:80 -v $volume:/data text-embeddings-inference:hpu-0.6 --model-id $model --pooling cls
```
40 changes: 40 additions & 0 deletions text-embeddings-inference/backends/python/server/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[tool.poetry]
name = "text-embeddings-server"
version = "0.1.0"
description = "Text Embeddings Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

[tool.poetry.scripts]
python-text-embeddings-server = 'text_embeddings_server.cli:app'

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = "^4.21.7"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
safetensors = "^0.3.2"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"

[tool.poetry.extras]

[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.51.1"
pytest = "^7.3.0"

[[tool.poetry.source]]
name = "pytorch-gpu-src"
url = "https://download.pytorch.org/whl/cu118"
priority = "explicit"

[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
43 changes: 43 additions & 0 deletions text-embeddings-inference/backends/python/server/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.35.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import torch

from loguru import logger
from pathlib import Path
from typing import Optional
from transformers import AutoConfig
from transformers.models.bert import BertConfig

from text_embeddings_server.models.model import Model
from text_embeddings_server.models.default_model import DefaultModel

__all__ = ["Model"]

HTCORE_AVAILABLE = True
try:
import habana_frameworks.torch.core as htcore
except ImportError as e:
logger.warning(f"Could not import htcore: {e}")
HTCORE_AVAILABLE = False

# Disable gradients
torch.set_grad_enabled(False)

FLASH_ATTENTION = True
try:
from text_embeddings_server.models.flash_bert import FlashBert
except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
FLASH_ATTENTION = False

if FLASH_ATTENTION:
__all__.append(FlashBert)


def get_model(model_path: Path, dtype: Optional[str]):
if dtype == "float32":
dtype = torch.float32
elif dtype == "float16":
dtype = torch.float16
elif dtype == "bfloat16":
dtype = torch.bfloat16
else:
raise RuntimeError(f"Unknown dtype {dtype}")

if torch.cuda.is_available():
device = torch.device("cuda")
elif HTCORE_AVAILABLE and torch.hpu.is_available():
device = torch.device("hpu")
else:
if dtype != torch.float32:
raise ValueError("CPU device only supports float32 dtype")
device = torch.device("cpu")

config = AutoConfig.from_pretrained(model_path)

if config.model_type == "bert":
config: BertConfig
if (
device.type == "cuda"
and config.position_embedding_type == "absolute"
and dtype in [torch.float16, torch.bfloat16]
and FLASH_ATTENTION
):
return FlashBert(model_path, device, dtype)
else:
return DefaultModel(model_path, device, dtype)

raise NotImplementedError