
Commit 23c696d

Merge pull request vllm-project#20 from ROCm/Dockerfile_multistage_refactor
Dockerfile improvements: multistage
2 parents d4db2f9 + 34174f8 commit 23c696d
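
The refactor splits Dockerfile.rocm into per-component build stages (hipBLASLt, RCCL, flash-attention, CuPy, Triton, and vLLM/gradlib itself), each paired with a small export stage that holds only the built packages. Which components end up in the final image is controlled by build args that select between a populated and an empty export stage. A minimal sketch of that selection pattern, reusing the flash-attention stage names from the diff below (everything else in the sketch is illustrative only):

    # Declared before the first FROM so it is visible in FROM lines
    ARG BUILD_FA="1"
    # ...a build stage named build_flash_attn produces a wheel under dist/...
    FROM scratch AS export_flash_attn_1
    COPY --from=build_flash_attn /app/flash-attention/dist/*.whl /
    FROM scratch AS export_flash_attn_0
    # The alias resolves to the _1 (populated) or _0 (empty) stage via the build arg
    FROM export_flash_attn_${BUILD_FA} AS export_flash_attn

The final stage then bind-mounts each export_* alias and installs whatever .deb or .whl files it finds there, so a disabled component simply contributes nothing.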

File tree

1 file changed: +147 -133 lines changed

Dockerfile.rocm

Lines changed: 147 additions & 133 deletions
@@ -1,13 +1,22 @@
 # default base image
 ARG BASE_IMAGE="rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.1.2"

-FROM $BASE_IMAGE
+ARG COMMON_WORKDIR=/app
+ARG BUILD_HIPBLASLT="1"
+ARG BUILD_RCCL="1"
+ARG BUILD_FA="1"
+ARG BUILD_CUPY="0"
+ARG BUILD_TRITON="1"
+
+# -----------------------
+# vLLM base image
+FROM $BASE_IMAGE AS base
 USER root

 # Import BASE_IMAGE arg from pre-FROM
 ARG BASE_IMAGE
 RUN echo "Base image is $BASE_IMAGE"
-
+ARG COMMON_WORKDIR
 # Used as ARCHes for all components
 ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942"
 RUN echo "PYTORCH_ROCM_ARCH is $PYTORCH_ROCM_ARCH"
@@ -17,167 +26,172 @@ RUN apt-get update && apt-get install python3 python3-pip -
 RUN apt-get update && apt-get install -y \
     sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev

-### Mount Point ###
-# When launching the container, mount the code directory to /app
-ARG APP_MOUNT=/app
-VOLUME [ ${APP_MOUNT} ]
-WORKDIR ${APP_MOUNT}
+ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
+ENV PATH=$PATH:/opt/rocm/bin:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/bin:
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib:
+ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include/torch/csrc/api/include/:/opt/rocm/include/:

+WORKDIR ${COMMON_WORKDIR}

-ARG BUILD_HIPBLASLT="1"
+# -----------------------
+# hipBLASLt build stages
+FROM base AS build_hipblaslt
 ARG HIPBLASLT_BRANCH="ee51a9d1"
-
-RUN if [ "$BUILD_HIPBLASLT" = "1" ]; then \
-    echo "HIPBLASLT_BRANCH is $HIPBLASLT_BRANCH"; \
-fi
-# Build HipblasLt
-RUN if [ "$BUILD_HIPBLASLT" = "1" ] ; then \
-    apt-get purge -y hipblaslt \
-    && mkdir -p libs \
-    && cd libs \
-    && git clone https://github.com/ROCm/hipBLASLt \
+RUN git clone https://github.com/ROCm/hipBLASLt \
     && cd hipBLASLt \
     && git checkout ${HIPBLASLT_BRANCH} \
-    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh -i --architecture ${PYTORCH_ROCM_ARCH} \
-    && cd .. && rm -rf hipBLASLt \
-    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
-    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status \
-    && cd ..; \
-    fi
-
-
-RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
-
-
-ARG BUILD_RCCL="1"
+    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
+    && cd build/release \
+    && make package
+FROM scratch AS export_hipblaslt_1
+ARG COMMON_WORKDIR
+COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
+FROM scratch AS export_hipblaslt_0
+
+# -----------------------
+# RCCL build stages
+FROM base AS build_rccl
 ARG RCCL_BRANCH="eeea3b6"
-
-RUN if [ "$BUILD_RCCL" = "1" ]; then \
-    echo "RCCL_BRANCH is $RCCL_BRANCH"; \
-fi
-# Install RCCL
-RUN if [ "$BUILD_RCCL" = "1" ]; then \
-    mkdir -p libs \
-    && cd libs \
-    && git clone https://github.com/ROCm/rccl \
+RUN git clone https://github.com/ROCm/rccl \
     && cd rccl \
     && git checkout ${RCCL_BRANCH} \
-    && ./install.sh -i --amdgpu_targets ${PYTORCH_ROCM_ARCH} \
-    && cd .. \
-    && rm -r rccl \
-    && cd ..; \
-    fi
-
-
-ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
-ENV PATH=$PATH:/opt/rocm/bin:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/bin:
-ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib:
-ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include/torch/csrc/api/include/:/opt/rocm/include/:
-
-
-# whether to build flash-attention
-# if 0, will not build flash attention
-# this is useful for gfx target where flash-attention is not supported
-# In that case, we need to use the python reference attention implementation in vllm
-ARG BUILD_FA="1"
+    && ./install.sh --amdgpu_targets ${PYTORCH_ROCM_ARCH} \
+    && cd build/release \
+    && make package
+FROM scratch AS export_rccl_1
+ARG COMMON_WORKDIR
+COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
+FROM scratch AS export_rccl_0
+
+# -----------------------
+# flash attn build stages
+FROM base AS build_flash_attn
 ARG FA_BRANCH="ae7928c"
-
-RUN if [ "$BUILD_FA" = "1" ]; then \
-    echo "FA_BRANCH is $FA_BRANCH"; \
-fi
-# Install ROCm flash-attention
-RUN if [ "$BUILD_FA" = "1" ]; then \
-    mkdir -p libs \
-    && cd libs \
-    && git clone https://github.com/ROCm/flash-attention.git \
+RUN git clone https://github.com/ROCm/flash-attention.git \
     && cd flash-attention \
     && git checkout ${FA_BRANCH} \
     && git submodule update --init \
-    && GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py install \
-    && cd .. \
-    && rm -rf flash-attention \
-    && cd ..; \
-    fi
+    && GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch AS export_flash_attn_1
+ARG COMMON_WORKDIR
+COPY --from=build_flash_attn ${COMMON_WORKDIR}/flash-attention/dist/*.whl /
+FROM scratch AS export_flash_attn_0
+
+# -----------------------
+# CuPy build stages
+FROM base AS build_cupy
+ARG CUPY_BRANCH="hipgraph_enablement"
+RUN git clone https://github.com/ROCm/cupy.git \
+    && cd cupy \
+    && git checkout $CUPY_BRANCH \
+    && git submodule update --init --recursive \
+    && pip install mpi4py-mpich scipy==1.9.3 cython==0.29.* \
+    && CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \
+    && CUPY_INSTALL_USE_HIP=1 ROCM_HOME=/opt/rocm HCC_AMDGPU_TARGET=${PYTORCH_ROCM_ARCH} \
+        python3 setup.py bdist_wheel --dist-dir=dist
+FROM build_cupy AS export_cupy_1
+ARG COMMON_WORKDIR
+COPY --from=build_cupy ${COMMON_WORKDIR}/cupy/dist/*.whl /
+FROM scratch AS export_cupy_0
+
+# -----------------------
+# Triton build stages
+FROM base AS build_triton
+ARG TRITON_BRANCH="main"
+RUN git clone https://github.com/OpenAI/triton.git \
+    && cd triton \
+    && git checkout ${TRITON_BRANCH} \
+    && cd python \
+    && python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch AS export_triton_1
+ARG COMMON_WORKDIR
+COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
+FROM scratch AS export_triton_0
+
+# -----------------------
+# vLLM (and gradlib) build stages
+FROM base AS build_vllm
+ARG COMMON_WORKDIR
+# To consider: Obtain vLLM via git clone
+COPY ./ ${COMMON_WORKDIR}/vllm
+# Build vLLM
+RUN cd vllm \
+    && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist
+# Build gradlib
+RUN cd vllm/gradlib \
+    && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist
+FROM scratch AS export_vllm
+ARG COMMON_WORKDIR
+COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl /
+COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/gradlib/dist/*.whl /
+COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/rocm_patch /rocm_patch
+COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt /
+COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/patch_xformers.rocm.sh /
+
+# -----------------------
+# Aliases to ensure we only use enabled components
+FROM export_hipblaslt_${BUILD_HIPBLASLT} AS export_hipblaslt
+FROM export_rccl_${BUILD_RCCL} AS export_rccl
+FROM export_flash_attn_${BUILD_FA} AS export_flash_attn
+FROM export_cupy_${BUILD_CUPY} AS export_cupy
+FROM export_triton_${BUILD_TRITON} AS export_triton
+
+# -----------------------
+# Final vLLM image
+FROM base AS final
+ARG BASE_IMAGE
+ARG BUILD_FA

+RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
 # Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
 # Manually removed it so that later steps of numpy upgrade can continue
 RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.1.2" ]; then \
     rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

-
-# Whether to build CuPy. 0.3.3 <= vLLM < 0.4.0 might need it for HIPgraph.
-ARG BUILD_CUPY="0"
-ARG CUPY_BRANCH="hipgraph_enablement"
-
-RUN if [ "$BUILD_CUPY" = "1" ]; then \
-    echo "CUPY_BRANCH is $CUPY_BRANCH"; \
+RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
+    if ls /install/*.deb; then \
+    apt-get purge -y hipblaslt \
+    && dpkg -i /install/*.deb \
+    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
     fi
-# Build cupy
-RUN if [ "$BUILD_CUPY" = "1" ]; then \
-    mkdir -p libs \
-    && cd libs \
-    && git clone $CUPY_BRANCH --recursive https://github.com/ROCm/cupy.git \
-    && cd cupy \
-    && pip install mpi4py-mpich scipy==1.9.3 cython==0.29.* \
-    && CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \
-    && CUPY_INSTALL_USE_HIP=1 ROCM_HOME=/opt/rocm HCC_AMDGPU_TARGET=${PYTORCH_ROCM_ARCH} pip install . \
-    && cd .. \
-    && rm -rf cupy \
-    && cd ..; \
-    fi
-
-
-# whether to build triton on rocm
-ARG BUILD_TRITON="1"
-ARG TRITON_BRANCH="main"

-RUN if [ "$BUILD_TRITON" = "1" ]; then \
-    echo "TRITON_BRANCH is $TRITON_BRANCH"; \
+RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
+    if ls /install/*.deb; then \
+    dpkg -i /install/*.deb \
+    && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
+    && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
     fi
-# build triton
-RUN if [ "$BUILD_TRITON" = "1" ]; then \
-    mkdir -p libs \
-    && cd libs \
-    && pip uninstall -y triton \
-    && git clone https://github.com/OpenAI/triton.git \
-    && cd triton \
-    && git checkout ${TRITON_BRANCH} \
-    && cd python \
-    && pip install . \
-    && cd ../.. \
-    && rm -rf triton \
-    && cd ..; \
+
+RUN --mount=type=bind,from=export_flash_attn,src=/,target=/install \
+    if ls /install/*.whl; then \
+    pip install /install/*.whl; \
     fi

+RUN --mount=type=bind,from=export_cupy,src=/,target=/install \
+    if ls /install/*.whl; then \
+    pip install /install/*.whl; \
+    fi

-COPY ./ /app/vllm
-# Fix HIP runtime on ROCm 6.1
-RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.1.2" ]; then \
-    cp /app/vllm/rocm_patch/libamdhip64.so.6 /opt/rocm-6.1.0/lib/libamdhip64.so.6; fi
+RUN --mount=type=bind,from=export_triton,src=/,target=/install \
+    if ls /install/*.whl; then \
+    pip install /install/*.whl; \
+    fi

-RUN python3 -m pip install --upgrade pip numba
+RUN python3 -m pip install --upgrade numba
 RUN python3 -m pip install xformers==0.0.23 --no-deps

-# Install vLLM
-ARG VLLM_BUILD_MODE="install"
-# developer might choose to use "develop" mode. But for end-users, we should do an install mode.
-# the current "develop" mode has issues with ImportError: cannot import name '_custom_C' from 'vllm' (/app/vllm/vllm/__init__.py)
-RUN cd /app \
-    && cd vllm \
+# Install vLLM (and gradlib)
+RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
+    cd /install \
     && pip install -U -r requirements-rocm.txt \
     && if [ "$BUILD_FA" = "1" ]; then \
-    bash patch_xformers.rocm.sh; fi \
+        bash patch_xformers.rocm.sh; fi \
     && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
-    patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch; fi \
-    && python3 setup.py clean --all && python3 setup.py $VLLM_BUILD_MODE \
-    && cd ..
-
-
-# Install gradlib
-RUN cd /app/vllm/gradlib \
-    && pip install . \
-    && cd ../..
-
+        patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch; fi \
+    && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.1.2" ]; then \
+        cp rocm_patch/libamdhip64.so.6 /opt/rocm-6.1.0/lib/libamdhip64.so.6; fi \
+    && pip install *.whl

 # Update Ray to latest version + set environment variable to ensure it works on TP > 1
 RUN python3 -m pip install --no-cache-dir 'ray[all]>=2.10.0'
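
As a usage sketch (not part of the commit; the image tag and output directory below are placeholders, and BuildKit is assumed since the final stage relies on RUN --mount bind mounts):

    # Build the final image, e.g. with flash-attention and Triton disabled
    DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm \
        --build-arg BUILD_FA=0 --build-arg BUILD_TRITON=0 \
        -t vllm-rocm .

    # Or stop at an export stage to copy the built vLLM/gradlib wheels to the host
    DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm \
        --target export_vllm --output type=local,dest=./vllm_artifacts .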
