Merged
17 changes: 11 additions & 6 deletions csrc/quantization.cpp
@@ -1,11 +1,14 @@
#include <torch/extension.h>

torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);

#ifndef USE_ROCM
torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);
#endif

void squeezellm_gemm(
  torch::Tensor vec,
@@ -14,6 +17,8 @@ void squeezellm_gemm(
  torch::Tensor lookup_table);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifndef USE_ROCM
  m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
#endif
  m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
}
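Because `awq_gemm` is registered only when `USE_ROCM` is not defined, Python code that runs on a ROCm build cannot assume the symbol exists on the extension module. The following is a minimal sketch of one way a caller might guard against that. The helper name `get_optional_op` and the stand-in module objects are hypothetical and are not part of this PR.

```python
from types import SimpleNamespace
from typing import Callable, Optional


def get_optional_op(ops_module: object, name: str) -> Optional[Callable]:
    """Return the bound op if the extension module exposes it, else None."""
    op = getattr(ops_module, name, None)
    return op if callable(op) else None


# Stand-ins for the compiled extension module (hypothetical, for illustration only).
# A CUDA build binds both ops; a ROCm build skips awq_gemm.
cuda_build = SimpleNamespace(
    awq_gemm=lambda *args: "awq",
    squeezellm_gemm=lambda *args: None,
)
rocm_build = SimpleNamespace(squeezellm_gemm=lambda *args: None)

for label, module in (("CUDA", cuda_build), ("ROCm", rocm_build)):
    available = get_optional_op(module, "awq_gemm") is not None
    print(f"{label} build exposes awq_gemm: {available}")
```

A caller could use the `None` result to raise a clear "AWQ is not supported on ROCm" error instead of an `AttributeError` at call time.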
104 changes: 104 additions & 0 deletions docs/source/getting_started/amd-installation.rst
@@ -0,0 +1,104 @@
.. _installation:

Installation with ROCm
======================

vLLM-ROCm is here! It currently supports Llama-2.

Requirements
------------

* OS: Linux
* Python: 3.8 -- 3.11 (3.10 is recommended, as it is the version that has been tested.)
* GPU: MI210
* PyTorch: 2.0.1 or 2.1.1
* ROCm: 5.7


Install with pip
----------------

You can install vLLM using pip:

.. code-block:: console

    $ # (Optional) Create a new conda environment.
    $ conda create -n myenv python=3.8 -y
    $ conda activate myenv

    $ # Install vLLM with CUDA 12.1.
    $ pip install vllm

.. note::

    As of now, vLLM's binaries are compiled on CUDA 12.1 by default.
    However, you can install vLLM with CUDA 11.8 by running:

    .. code-block:: console

        $ # Install vLLM with CUDA 11.8.
        $ # Replace `cp310` with your Python version (e.g., `cp38`, `cp39`, `cp311`).
        $ pip install https://github.com/vllm-project/vllm/releases/download/v0.2.2/vllm-0.2.2+cu118-cp310-cp310-manylinux1_x86_64.whl

        $ # Re-install PyTorch with CUDA 11.8.
        $ pip uninstall torch -y
        $ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118


.. _build_from_source:

Build from source with Docker
-----------------------------

You can also build and install vLLM from source.

First, build a Docker image from ``rocm.Dockerfile`` and launch a container:

.. code-block:: console

    $ docker build -f rocm.Dockerfile -t vllm-rocm .
    $ docker run -it \
       --network=host \
       --group-add=video \
       --ipc=host \
       --cap-add=SYS_PTRACE \
       --security-opt seccomp=unconfined \
       --shm-size 8G \
       --device /dev/kfd \
       --device /dev/dri \
       -v <path/to/model>:/app/hf_model \
       vllm-rocm \
       bash

To set up vLLM in a new PyTorch + ROCm 5.7 Docker container, follow these steps:

1. Install flash-attention-2 for ROCm.

   If you are using PyTorch 2.0.1 + ROCm 5.7, install flash-attention-2 (v2.0.4) by following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_.

   If you are using PyTorch 2.1.x + ROCm 5.7 or PyTorch 2.2.x + ROCm 5.7, you do not need to apply ``hipify_python.patch``.
   You can build flash-attention-2 directly.

   .. code-block:: console

      $ bash patch_torch211_flash_attn2.rocm.sh

   .. note::

      - Flash-attention-2 (v2.0.4) does not support sliding window attention.
      - You might need to downgrade ``ninja`` to version 1.10 if it is not used when compiling flash-attention-2 (e.g., ``pip install ninja==1.10.2.4``).

2. Set up xformers==0.0.22.post7 without dependencies, and apply the patches.

   .. code-block:: console

      $ pip install xformers==0.0.22.post7 --no-deps
      $ bash patch_xformers-0.0.22.post7.rocm.sh

3. Build vLLM.

   .. code-block:: console

      $ cd vllm
      $ python setup.py install # This may take 5-10 minutes.
22 changes: 22 additions & 0 deletions patch_xformers-0.0.22.post7.rocm.sh
@@ -0,0 +1,22 @@
#!/bin/bash
export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')

echo "$XFORMERS_FMHA_FLASH_PATH"
echo "$XFORMERS_FMHA_COMMON_PATH"

if ! patch -R -p0 -s -f --dry-run "$XFORMERS_FMHA_FLASH_PATH" "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then
    echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
    patch -p0 "$XFORMERS_FMHA_FLASH_PATH" "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"
    echo "Successfully patched ${XFORMERS_FMHA_FLASH_PATH}"
else
    echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
fi

if ! patch -R -p0 -s -f --dry-run "$XFORMERS_FMHA_COMMON_PATH" "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then
    echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
    patch -p0 "$XFORMERS_FMHA_COMMON_PATH" "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"
    echo "Successfully patched ${XFORMERS_FMHA_COMMON_PATH}"
else
    echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"
fi
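The script above relies on a reverse dry run (`patch -R --dry-run`) to decide whether a patch is already applied: the reverse run succeeds only if the file already contains the patched content. The same check-then-apply pattern can be sketched in Python as follows. The function name `apply_patch_once` is hypothetical, and the sketch assumes the `patch` command is available on the PATH.

```python
import subprocess


def apply_patch_once(target: str, patch_file: str) -> bool:
    """Apply patch_file to target unless it is already applied.

    Returns True if the patch was applied now, False if it was applied earlier.
    """
    # Reverse dry run: exits with status 0 only when the patch can be undone
    # cleanly, which means the target already contains the patched content.
    probe = subprocess.run(
        ["patch", "-R", "-p0", "-s", "-f", "--dry-run", target, patch_file],
        stdin=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    if probe.returncode == 0:
        return False

    # Forward apply; check=True raises CalledProcessError if the patch fails.
    subprocess.run(
        ["patch", "-p0", target, patch_file],
        check=True,
        stdin=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
    )
    return True
```

Calling the function twice on the same file applies the patch the first time and leaves the file untouched the second time, which is what makes the shell script safe to re-run.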
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -3,7 +3,7 @@ requires = [
    "ninja",
    "packaging",
    "setuptools",
    "torch >= 2.1.0",
    # "torch >= 2.1.0", # commented out to accommodate ROCm
    "wheel",
]
build-backend = "setuptools.build_meta"
16 changes: 16 additions & 0 deletions requirements-rocm.txt
@@ -0,0 +1,16 @@
ninja # For faster builds.
typing-extensions>=4.8.0
starlette
psutil
ray >= 2.5.1
pandas # Required for Ray data.
pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer.
numpy
tokenizers>=0.15.0
huggingface_hub<0.18,>=0.16.4
einops # Required for phi-1_5
transformers >= 4.34.0 # Required for Mistral.
fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.
64 changes: 64 additions & 0 deletions rocm.Dockerfile
@@ -0,0 +1,64 @@
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1

# Install Python and pip
RUN apt-get update && apt-get install python3 python3-pip -y

# Install some basic utilities
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    git \
    bzip2 \
    libx11-6 \
    build-essential \
    wget \
    unzip \
    nvidia-cuda-toolkit \
    tmux \
 && rm -rf /var/lib/apt/lists/*

### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers

ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
ENV PYTORCH_ROCM_ARCH=gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1101

# Install ROCm flash-attention
RUN mkdir libs \
    && cd libs \
    && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
    && cd flash-attention \
    && git submodule update --init \
    && sed -i -e "s/--offload-arch=native/--offload-arch=$(/opt/rocm/llvm/bin/amdgpu-offload-arch)/g" setup.py \
    && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
    && python3 setup.py install \
    && cd ..

COPY ./ /app/vllm-rocm/

# RUN cd /app \
# && cd vllm-rocm \
# && git checkout v0.2.1.post1-rocm \
# && python3 setup.py install \
# && cd ..

# RUN cd /app \
# && mkdir dataset \
# && cd ..

# COPY ./benchmark_throughput.sh /app/benchmark_throughput.sh

RUN python3 -m pip install --upgrade pip
# RUN python3 -m pip install --no-cache-dir ray[all]

CMD ["/bin/bash"]
13 changes: 13 additions & 0 deletions rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch
@@ -0,0 +1,13 @@
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000
+++ common.py 2023-11-28 16:14:19.846233146 +0000
@@ -298,8 +298,8 @@
dtype = d.query.dtype
if device_type not in cls.SUPPORTED_DEVICES:
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
- if device_type == "cuda" and not _built_with_cuda:
- reasons.append("xFormers wasn't build with CUDA support")
+ #if device_type == "cuda" and not _built_with_cuda:
+ # reasons.append("xFormers wasn't build with CUDA support")
if device_type == "cuda":
device_capability = torch.cuda.get_device_capability(d.device)
if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
134 changes: 134 additions & 0 deletions rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch
@@ -0,0 +1,134 @@
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000
+++ flash.py 2023-11-28 16:14:25.206128903 +0000
@@ -31,39 +31,39 @@

FLASH_VERSION = "0.0.0"
try:
- try:
- from ... import _C_flashattention # type: ignore[attr-defined]
- from ..._cpp_lib import _build_metadata
-
- if _build_metadata is not None:
- FLASH_VERSION = _build_metadata.flash_version
- except ImportError:
- import flash_attn
- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
- FLASH_VERSION = flash_attn.__version__
- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
- if flash_ver_parsed < (2, 3):
- raise ImportError("Requires 2.3 for sliding window support")
+ #try:
+ # from ... import _C_flashattention # type: ignore[attr-defined]
+ # from ..._cpp_lib import _build_metadata
+
+ # if _build_metadata is not None:
+ # FLASH_VERSION = _build_metadata.flash_version
+ #except ImportError:
+ import flash_attn
+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+ FLASH_VERSION = flash_attn.__version__
+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
+ # if flash_ver_parsed < (2, 3):
+ # raise ImportError("Requires 2.3 for sliding window support")

# create library so that flash-attn goes through the PyTorch Dispatcher
- _flash_lib = torch.library.Library("xformers_flash", "DEF")
+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")

- _flash_lib.define(
- "flash_fwd(Tensor query, Tensor key, Tensor value, "
- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, "
- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
- )
-
- _flash_lib.define(
- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
- )
+ #_flash_lib.define(
+ # "flash_fwd(Tensor query, Tensor key, Tensor value, "
+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, "
+ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+ #)
+
+ #_flash_lib.define(
+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+ #)

def _flash_fwd(
query,
@@ -98,8 +98,8 @@
p,
softmax_scale,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
return_softmax,
None, # rng
)
@@ -127,8 +127,8 @@
softmax_scale,
False,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
return_softmax,
None,
)
@@ -169,8 +169,8 @@
p,
softmax_scale,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
None,
rng_state,
)
@@ -193,15 +193,15 @@
softmax_scale,
False, # zero_tensors
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
None,
rng_state,
)
return dq, dk, dv

- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
pass

@@ -348,7 +348,7 @@
implementation.
"""

- OPERATOR = get_operator("xformers_flash", "flash_fwd")
+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
SUPPORTED_DEVICES: Set[str] = {"cuda"}
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
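The patch above comments out the version gate in `flash.py`, presumably because the ROCm build of flash-attention-2 (v2.0.4) would not pass it. The short sketch below reproduces the comparison that the original code performed, to show why v2.0.4 is rejected. The helper name `parse_major_minor` and the constant `MIN_SLIDING_WINDOW` are hypothetical; the parsing expression mirrors the commented-out line.

```python
from typing import Tuple

# Minimum (major, minor) version that the original check required.
MIN_SLIDING_WINDOW = (2, 3)


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Parse the first two dot-separated components of a version string."""
    major, minor = (int(s) for s in version.split(".")[:2])
    return major, minor


for version in ("2.0.4", "2.3.0"):
    passes = parse_major_minor(version) >= MIN_SLIDING_WINDOW
    print(f"flash_attn {version} passes the sliding-window gate: {passes}")
```

Tuples compare element by element, so `(2, 0)` sorts below `(2, 3)` and the unpatched code would raise `ImportError` for v2.0.4.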