FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
WORKDIR /usr/src

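# Fetch the crates.io index over the sparse HTTP protocol instead of cloning the
# full git index. Sparse is the default since Rust 1.70, so on rust-1.75 this is
# effectively a no-op, but it documents intent and pins behaviour for older toolchains.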
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

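# cargo-chef two-phase build: the planner stage computes a dependency-only
# "recipe" (recipe.json) from the workspace, and the builder stage cooks
# (compiles) just those dependencies first, so the expensive dependency layer
# stays cached in Docker as long as the dependency set is unchanged.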
FROM chef AS planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

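# protoc is needed at build time to compile the gRPC/protobuf definitions in
# `proto` (presumably consumed by tonic/prost code generation in the Rust build
# scripts; treat the exact consumer as an assumption).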
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json

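# Dependencies are now cached; copying the sources below only invalidates the
# final `cargo build` layer.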
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo build --release

# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install

ARG PYTORCH_VERSION=2.1.1
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml`
ARG CUDA_VERSION=12.1
ARG MAMBA_VERSION=23.3.1-1
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx
ARG TARGETPLATFORM

ENV PATH=/opt/conda/bin:$PATH

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    ca-certificates \
    ccache \
    curl \
    git && \
    rm -rf /var/lib/apt/lists/*

# Install conda
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
    *) MAMBA_ARCH=x86_64 ;; \
    esac && \
    curl -fsSL -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh

# Install pytorch
# arm64 is unsupported, so we fail fast with an error code
RUN case ${TARGETPLATFORM} in \
    "linux/arm64") exit 1 ;; \
    *) /opt/conda/bin/conda update -y conda && \
       /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya

# CUDA kernels builder image
FROM pytorch-install AS kernel-builder

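# MAX_JOBS caps parallel compile jobs (honoured as an environment variable by
# PyTorch's cpp_extension builds); note that an ARG is scoped to the stage that
# declares it, so derived builder stages only see it if they redeclare it.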
ARG MAX_JOBS=8

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    ninja-build cmake \
    && rm -rf /var/lib/apt/lists/*

# Build Flash Attention CUDA kernels
FROM kernel-builder AS flash-att-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att Makefile

# Build specific version of flash attention
RUN make build-flash-attention

# Build Flash Attention v2 CUDA kernels
FROM kernel-builder AS flash-att-v2-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2-cuda

# Build Transformers exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .

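# TORCH_CUDA_ARCH_LIST targets Ampere GPUs (8.0 = A100, 8.6 = A10G/RTX 30xx);
# +PTX also embeds PTX so newer architectures can JIT-compile the kernels.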
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

# Build Transformers exllamav2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .

# Build specific version of exllamav2 kernels
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

# Build AWQ kernels
FROM kernel-builder AS awq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-awq Makefile
# Build specific version of AWQ kernels
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq

# Build eetq kernels
FROM kernel-builder AS eetq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-eetq Makefile
# Build specific version of eetq kernels
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq

# Build Transformers custom CUDA kernels
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
# Build custom kernels
RUN python setup.py build

# Build vllm CUDA kernels
FROM kernel-builder AS vllm-builder

WORKDIR /usr/src

COPY server/Makefile-vllm Makefile

# Build specific version of vllm
RUN make build-vllm-cuda

# Build mamba kernels
FROM kernel-builder AS mamba-builder
WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
RUN make build-all

# Text Generation Inference base image
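# The runtime image starts from the minimal CUDA "base" flavour (cudart only,
# no nvcc or toolchain), keeping the final image small; everything compiled
# above is copied in from the builder stages.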
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base

# Conda env
ENV PATH=/opt/conda/bin:$PATH \
    CONDA_PREFIX=/opt/conda

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/tmp \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80
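# HF_HUB_ENABLE_HF_TRANSFER=1 switches hub downloads to the Rust hf_transfer
# backend for higher throughput; pointing the hub cache at /tmp sends model
# weights to ephemeral storage by default.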

WORKDIR /usr/src

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    libssl-dev \
    ca-certificates \
    make \
    unzip \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy conda with PyTorch installed
COPY --from=pytorch-install /opt/conda /opt/conda

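# Each kernel builder leaves its setuptools output under
# build/lib.linux-x86_64-cpython-310; copying that directory straight into
# site-packages installs the extension without pip. The path is tied to
# Python 3.10 and must move in lockstep with PYTHON_VERSION above.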
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages

# Install vllm/flash-attention dependencies
RUN pip install einops --no-cache-dir

# Install server
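# `make gen-server` is assumed (from the Makefile target name) to generate the
# Python gRPC stubs from the `proto` definitions before the server package and
# its extras are installed.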
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_cuda.txt && \
    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

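# Compilers stay in the runtime image, presumably for extensions that are
# JIT-compiled at run time (e.g. Triton/torch kernels); treat the exact
# consumer as an assumption.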
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# AWS Sagemaker compatible image
FROM base AS sagemaker

COPY --chmod=775 sagemaker-entrypoint.sh entrypoint.sh

RUN HOME_DIR=/root && \
    pip install requests && \
    curl -o ${HOME_DIR}/oss_compliance.zip https://aws-dlinfra-utilities.s3.amazonaws.com/oss_compliance.zip && \
    unzip ${HOME_DIR}/oss_compliance.zip -d ${HOME_DIR}/ && \
    cp ${HOME_DIR}/oss_compliance/test/testOSSCompliance /usr/local/bin/testOSSCompliance && \
    chmod +x /usr/local/bin/testOSSCompliance && \
    chmod +x ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh && \
    ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} python && \
    rm -rf ${HOME_DIR}/oss_compliance*
COPY /huggingface/pytorch/tgi/docker/2.0.1/THIRD-PARTY-LICENSES /root/THIRD-PARTY-LICENSES

RUN /opt/conda/bin/conda clean -p -y

ENTRYPOINT ["./entrypoint.sh"]
CMD ["--json-output"]

LABEL dlc_major_version="2"
LABEL com.amazonaws.ml.engines.sagemaker.dlc.framework.huggingface.tgi="true"
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port="true"
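
# Example build of this stage (illustrative; the tag name is a placeholder):
#   docker build --target sagemaker -t tgi-sagemaker .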