# Rust builder
# cargo-chef lets us cache the dependency-compilation layer separately from app code.
FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
WORKDIR /usr/src

# Use the sparse crates.io index protocol (faster registry updates).
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

# Planner stage: compute the cargo-chef recipe (dependency graph) from the manifests.
FROM chef AS planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

# Install protoc (needed to compile the gRPC .proto files at build time).
# -f makes curl fail on HTTP errors instead of saving an error page as the zip.
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -fsSL -O https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

# Cook (build) dependencies from the recipe first, so this layer is reused
# whenever only application source changes.
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release

# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install

ARG PYTORCH_VERSION=2.3.0
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml`
ARG CUDA_VERSION=12.1
ARG MAMBA_VERSION=23.3.1-1
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx
ARG TARGETPLATFORM

ENV PATH=/opt/conda/bin:$PATH

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        ccache \
        curl \
        git && \
        rm -rf /var/lib/apt/lists/*

# Install conda
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
         "linux/arm64") MAMBA_ARCH=aarch64 ;; \
         *)             MAMBA_ARCH=x86_64  ;; \
    esac && \
    curl -fsSL -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh

# Install pytorch
# On arm64 we exit with an error code (no CUDA PyTorch build for that arch here)
RUN case ${TARGETPLATFORM} in \
         "linux/arm64") exit 1 ;; \
         *) /opt/conda/bin/conda update -y conda && \
            /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya

# CUDA kernels builder image: shared base for all kernel build stages below.
FROM pytorch-install AS kernel-builder

# Parallelism cap for ninja-driven kernel compilation.
ARG MAX_JOBS=8

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        cmake ninja-build \
        && rm -rf /var/lib/apt/lists/*

# Build Flash Attention CUDA kernels
FROM kernel-builder AS flash-att-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att Makefile

# Build specific version of flash attention (version pinned inside the Makefile)
RUN make build-flash-attention

# Build Flash Attention v2 CUDA kernels
FROM kernel-builder AS flash-att-v2-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2 (version pinned inside the Makefile)
RUN make build-flash-attention-v2-cuda

# Build exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .

# Restrict to Ampere arches; +PTX keeps forward compatibility with newer GPUs
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

# Build exllamav2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .

# Restrict to Ampere arches; +PTX keeps forward compatibility with newer GPUs
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

# Build awq kernels
FROM kernel-builder AS awq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-awq Makefile
# Build specific version of awq kernels (version pinned inside the Makefile)
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq

# Build eetq kernels
FROM kernel-builder AS eetq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-eetq Makefile
# Build specific version of eetq kernels (version pinned inside the Makefile)
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq

# Build custom CUDA kernels shipped with the server
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
# setup.py derives the arch list from the toolchain defaults here
RUN python setup.py build

# Build vllm CUDA kernels
FROM kernel-builder AS vllm-builder

WORKDIR /usr/src

# Broader arch list than the quantization kernels: vllm supports Volta and up
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"

COPY server/Makefile-vllm Makefile

# Build specific version of vllm (version pinned inside the Makefile)
RUN make build-vllm-cuda

# Build mamba kernels (selective-scan and causal-conv1d)
FROM kernel-builder AS mamba-builder
WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
RUN make build-all

# Text Generation Inference base image
# Runtime image: CUDA *base* (not devel) plus the conda env and prebuilt kernels.
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base

# Conda env
ENV PATH=/opt/conda/bin:$PATH \
    CONDA_PREFIX=/opt/conda

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/tmp \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

WORKDIR /usr/src

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        ca-certificates \
        curl \
        libssl-dev \
        make \
        unzip \
        && rm -rf /var/lib/apt/lists/*

# Copy conda with PyTorch installed
COPY --from=pytorch-install /opt/conda /opt/conda

# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_cuda.txt --no-cache-dir && \
    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

# Compilers needed at runtime (e.g. for torch extensions compiled on the fly)
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
        g++ \
        && rm -rf /var/lib/apt/lists/*

# AWS Sagemaker compatible image
FROM base AS sagemaker

COPY --chmod=775 sagemaker-entrypoint.sh entrypoint.sh

# Run the AWS DLC OSS-compliance scanner, then remove it, keeping only its output.
RUN HOME_DIR=/root && \
    pip install --no-cache-dir requests && \
    curl -fsSL -o ${HOME_DIR}/oss_compliance.zip https://aws-dlinfra-utilities.s3.amazonaws.com/oss_compliance.zip && \
    unzip ${HOME_DIR}/oss_compliance.zip -d ${HOME_DIR}/ && \
    cp ${HOME_DIR}/oss_compliance/test/testOSSCompliance /usr/local/bin/testOSSCompliance && \
    chmod +x /usr/local/bin/testOSSCompliance && \
    chmod +x ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh && \
    ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} python && \
    rm -rf ${HOME_DIR}/oss_compliance*
COPY /huggingface/pytorch/tgi/docker/2.0.2/THIRD-PARTY-LICENSES /root/THIRD-PARTY-LICENSES

# Drop unused conda packages and caches to shrink the final image
RUN /opt/conda/bin/conda clean -py

ENTRYPOINT ["./entrypoint.sh"]
CMD ["--json-output"]

LABEL dlc_major_version="2"
LABEL com.amazonaws.ml.engines.sagemaker.dlc.framework.huggingface.tgi="true"
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port="true"