11# default base image
22ARG BASE_IMAGE="rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.1.2"
33
4- FROM $BASE_IMAGE
4+ ARG COMMON_WORKDIR=/app
5+ ARG BUILD_HIPBLASLT="1"
6+ ARG BUILD_RCCL="1"
7+ ARG BUILD_FA="1"
8+ ARG BUILD_CUPY="0"
9+ ARG BUILD_TRITON="1"
10+
11+ # -----------------------
12+ # vLLM base image
13+ FROM $BASE_IMAGE AS base
514USER root
615
716# Import BASE_IMAGE arg from pre-FROM
817ARG BASE_IMAGE
918RUN echo "Base image is $BASE_IMAGE"
10-
19+ ARG COMMON_WORKDIR
1120# Used as ARCHes for all components
1221ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942"
1322RUN echo "PYTORCH_ROCM_ARCH is $PYTORCH_ROCM_ARCH"
@@ -17,167 +26,172 @@ RUN apt-get update && apt-get install python3 python3-pip -
1726RUN apt-get update && apt-get install -y \
1827 sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev
1928
20- ### Mount Point ###
21- # When launching the container, mount the code directory to /app
22- ARG APP_MOUNT=/app
23- VOLUME [ ${APP_MOUNT} ]
24- WORKDIR ${APP_MOUNT}
29+ ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
30+ ENV PATH=$PATH:/opt/rocm/bin:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/bin:
31+ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib:
32+ ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include/torch/csrc/api/include/:/opt/rocm/include/:
2533
34+ WORKDIR ${COMMON_WORKDIR}
2635
27- ARG BUILD_HIPBLASLT="1"
36+ # -----------------------
37+ # hipBLASLt build stages
38+ FROM base AS build_hipblaslt
2839ARG HIPBLASLT_BRANCH="ee51a9d1"
29-
30- RUN if [ "$BUILD_HIPBLASLT" = "1" ]; then \
31- echo "HIPBLASLT_BRANCH is $HIPBLASLT_BRANCH"; \
32- fi
33- # Build HipblasLt
34- RUN if [ "$BUILD_HIPBLASLT" = "1" ] ; then \
35- apt-get purge -y hipblaslt \
36- && mkdir -p libs \
37- && cd libs \
38- && git clone https://github.com/ROCm/hipBLASLt \
40+ RUN git clone https://github.com/ROCm/hipBLASLt \
3941 && cd hipBLASLt \
4042 && git checkout ${HIPBLASLT_BRANCH} \
41- && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh -i --architecture ${PYTORCH_ROCM_ARCH} \
42- && cd .. && rm -rf hipBLASLt \
43- && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
44- && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status \
45- && cd ..; \
46- fi
47-
48-
49- RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
50-
51-
52- ARG BUILD_RCCL="1"
43+ && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
44+ && cd build/release \
45+ && make package
46+ FROM scratch AS export_hipblaslt_1
47+ ARG COMMON_WORKDIR
48+ COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
49+ FROM scratch AS export_hipblaslt_0
50+
51+ # -----------------------
52+ # RCCL build stages
53+ FROM base AS build_rccl
5354ARG RCCL_BRANCH="eeea3b6"
54-
55- RUN if [ "$BUILD_RCCL" = "1" ]; then \
56- echo "RCCL_BRANCH is $RCCL_BRANCH"; \
57- fi
58- # Install RCCL
59- RUN if [ "$BUILD_RCCL" = "1" ]; then \
60- mkdir -p libs \
61- && cd libs \
62- && git clone https://github.com/ROCm/rccl \
55+ RUN git clone https://github.com/ROCm/rccl \
6356 && cd rccl \
6457 && git checkout ${RCCL_BRANCH} \
65- && ./install.sh -i --amdgpu_targets ${PYTORCH_ROCM_ARCH} \
66- && cd .. \
67- && rm -r rccl \
68- && cd ..; \
69- fi
70-
71-
72- ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
73- ENV PATH=$PATH:/opt/rocm/bin:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/bin:
74- ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib:
75- ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include/torch/csrc/api/include/:/opt/rocm/include/:
76-
77-
78- # whether to build flash-attention
79- # if 0, will not build flash attention
80- # this is useful for gfx target where flash-attention is not supported
81- # In that case, we need to use the python reference attention implementation in vllm
82- ARG BUILD_FA="1"
58+ && ./install.sh --amdgpu_targets ${PYTORCH_ROCM_ARCH} \
59+ && cd build/release \
60+ && make package
61+ FROM scratch AS export_rccl_1
62+ ARG COMMON_WORKDIR
63+ COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
64+ FROM scratch AS export_rccl_0
65+
66+ # -----------------------
67+ # flash attn build stages
68+ FROM base AS build_flash_attn
8369ARG FA_BRANCH="ae7928c"
84-
85- RUN if [ "$BUILD_FA" = "1" ]; then \
86- echo "FA_BRANCH is $FA_BRANCH"; \
87- fi
88- # Install ROCm flash-attention
89- RUN if [ "$BUILD_FA" = "1" ]; then \
90- mkdir -p libs \
91- && cd libs \
92- && git clone https://github.com/ROCm/flash-attention.git \
70+ RUN git clone https://github.com/ROCm/flash-attention.git \
9371 && cd flash-attention \
9472 && git checkout ${FA_BRANCH} \
9573 && git submodule update --init \
96- && GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py install \
97- && cd .. \
98- && rm -rf flash-attention \
99- && cd ..; \
100- fi
74+ && GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
75+ FROM scratch AS export_flash_attn_1
76+ ARG COMMON_WORKDIR
77+ COPY --from=build_flash_attn ${COMMON_WORKDIR}/flash-attention/dist/*.whl /
78+ FROM scratch AS export_flash_attn_0
79+
80+ # -----------------------
81+ # CuPy build stages
82+ FROM base AS build_cupy
83+ ARG CUPY_BRANCH="hipgraph_enablement"
84+ RUN git clone https://github.com/ROCm/cupy.git \
85+ && cd cupy \
86+ && git checkout $CUPY_BRANCH \
87+ && git submodule update --init --recursive \
88+ && pip install mpi4py-mpich scipy==1.9.3 cython==0.29.* \
89+ && CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \
90+ && CUPY_INSTALL_USE_HIP=1 ROCM_HOME=/opt/rocm HCC_AMDGPU_TARGET=${PYTORCH_ROCM_ARCH} \
91+ python3 setup.py bdist_wheel --dist-dir=dist
92+ FROM build_cupy AS export_cupy_1
93+ ARG COMMON_WORKDIR
94+ COPY --from=build_cupy ${COMMON_WORKDIR}/cupy/dist/*.whl /
95+ FROM scratch AS export_cupy_0
96+
97+ # -----------------------
98+ # Triton build stages
99+ FROM base AS build_triton
100+ ARG TRITON_BRANCH="main"
101+ RUN git clone https://github.com/OpenAI/triton.git \
102+ && cd triton \
103+ && git checkout ${TRITON_BRANCH} \
104+ && cd python \
105+ && python3 setup.py bdist_wheel --dist-dir=dist
106+ FROM scratch AS export_triton_1
107+ ARG COMMON_WORKDIR
108+ COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
109+ FROM scratch AS export_triton_0
110+
111+ # -----------------------
112+ # vLLM (and gradlib) build stages
113+ FROM base AS build_vllm
114+ ARG COMMON_WORKDIR
115+ # To consider: Obtain vLLM via git clone
116+ COPY ./ ${COMMON_WORKDIR}/vllm
117+ # Build vLLM
118+ RUN cd vllm \
119+ && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist
120+ # Build gradlib
121+ RUN cd vllm/gradlib \
122+ && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist
123+ FROM scratch AS export_vllm
124+ ARG COMMON_WORKDIR
125+ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl /
126+ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/gradlib/dist/*.whl /
127+ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/rocm_patch /rocm_patch
128+ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt /
129+ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/patch_xformers.rocm.sh /
130+
131+ # -----------------------
132+ # Aliases to ensure we only use enabled components
133+ FROM export_hipblaslt_${BUILD_HIPBLASLT} AS export_hipblaslt
134+ FROM export_rccl_${BUILD_RCCL} AS export_rccl
135+ FROM export_flash_attn_${BUILD_FA} AS export_flash_attn
136+ FROM export_cupy_${BUILD_CUPY} AS export_cupy
137+ FROM export_triton_${BUILD_TRITON} AS export_triton
138+
139+ # -----------------------
140+ # Final vLLM image
141+ FROM base AS final
142+ ARG BASE_IMAGE
143+ ARG BUILD_FA
101144
145+ RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
102146# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
103147# Manually removed it so that later steps of numpy upgrade can continue
104148RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.1.2" ]; then \
105149 rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
106150
107-
108- # Whether to build CuPy. 0.3.3 <= vLLM < 0.4.0 might need it for HIPgraph.
109- ARG BUILD_CUPY="0"
110- ARG CUPY_BRANCH="hipgraph_enablement"
111-
112- RUN if [ "$BUILD_CUPY" = "1" ]; then \
113- echo "CUPY_BRANCH is $CUPY_BRANCH"; \
151+ RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
152+ if ls /install/*.deb; then \
153+ apt-get purge -y hipblaslt \
154+ && dpkg -i /install/*.deb \
155+ && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
156+ && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
114157 fi
115- # Build cupy
116- RUN if [ "$BUILD_CUPY" = "1" ]; then \
117- mkdir -p libs \
118- && cd libs \
119- && git clone $CUPY_BRANCH --recursive https://github.com/ROCm/cupy.git \
120- && cd cupy \
121- && pip install mpi4py-mpich scipy==1.9.3 cython==0.29.* \
122- && CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \
123- && CUPY_INSTALL_USE_HIP=1 ROCM_HOME=/opt/rocm HCC_AMDGPU_TARGET=${PYTORCH_ROCM_ARCH} pip install . \
124- && cd .. \
125- && rm -rf cupy \
126- && cd ..; \
127- fi
128-
129-
130- # whether to build triton on rocm
131- ARG BUILD_TRITON="1"
132- ARG TRITON_BRANCH="main"
133158
134- RUN if [ "$BUILD_TRITON" = "1" ]; then \
135- echo "TRITON_BRANCH is $TRITON_BRANCH"; \
159+ RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
160+ if ls /install/*.deb; then \
161+ dpkg -i /install/*.deb \
162+ && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
163+ && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
136164 fi
137- # build triton
138- RUN if [ "$BUILD_TRITON" = "1" ]; then \
139- mkdir -p libs \
140- && cd libs \
141- && pip uninstall -y triton \
142- && git clone https://github.com/OpenAI/triton.git \
143- && cd triton \
144- && git checkout ${TRITON_BRANCH} \
145- && cd python \
146- && pip install . \
147- && cd ../.. \
148- && rm -rf triton \
149- && cd ..; \
165+
166+ RUN --mount=type=bind,from=export_flash_attn,src=/,target=/install \
167+ if ls /install/*.whl; then \
168+ pip install /install/*.whl; \
150169 fi
151170
171+ RUN --mount=type=bind,from=export_cupy,src=/,target=/install \
172+ if ls /install/*.whl; then \
173+ pip install /install/*.whl; \
174+ fi
152175
153- COPY ./ /app/vllm
154- # Fix HIP runtime on ROCm 6.1
155- RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.1.2" ]; then \
156- cp /app/vllm/rocm_patch/libamdhip64.so.6 /opt/rocm-6.1.0/lib/libamdhip64.so.6; fi
176+ RUN --mount=type=bind,from=export_triton,src=/,target=/install \
177+ if ls /install/*.whl; then \
178+ pip install /install/*.whl; \
179+ fi
157180
158- RUN python3 -m pip install --upgrade pip numba
181+ RUN python3 -m pip install --upgrade numba
159182RUN python3 -m pip install xformers==0.0.23 --no-deps
160183
161- # Install vLLM
162- ARG VLLM_BUILD_MODE="install"
163- # developer might choose to use "develop" mode. But for end-users, we should do an install mode.
164- # the current "develop" mode has issues with ImportError: cannot import name '_custom_C' from 'vllm' (/app/vllm/vllm/__init__.py)
165- RUN cd /app \
166- && cd vllm \
184+ # Install vLLM (and gradlib)
185+ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
186+ cd /install \
167187 && pip install -U -r requirements-rocm.txt \
168188 && if [ "$BUILD_FA" = "1" ]; then \
169- bash patch_xformers.rocm.sh; fi \
189+ bash patch_xformers.rocm.sh; fi \
170190 && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
171- patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch; fi \
172- && python3 setup.py clean --all && python3 setup.py $VLLM_BUILD_MODE \
173- && cd ..
174-
175-
176- # Install gradlib
177- RUN cd /app/vllm/gradlib \
178- && pip install . \
179- && cd ../..
180-
191+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch; fi \
192+ && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.1.2" ]; then \
193+ cp rocm_patch/libamdhip64.so.6 /opt/rocm-6.1.0/lib/libamdhip64.so.6; fi \
194+ && pip install *.whl
181195
182196# Update Ray to latest version + set environment variable to ensure it works on TP > 1
183197RUN python3 -m pip install --no-cache-dir 'ray[all]>=2.10.0'
0 commit comments