diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp index dfe17a496c78..764188597fd3 100644 --- a/csrc/quantization.cpp +++ b/csrc/quantization.cpp @@ -1,11 +1,14 @@ #include -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters); + +#ifndef USE_ROCM + torch::Tensor awq_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters); +#endif void squeezellm_gemm( torch::Tensor vec, @@ -14,6 +17,8 @@ void squeezellm_gemm( torch::Tensor lookup_table); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +#ifndef USE_ROCM m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); +#endif m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); } diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst new file mode 100644 index 000000000000..862b8a1d6e89 --- /dev/null +++ b/docs/source/getting_started/amd-installation.rst @@ -0,0 +1,104 @@ +.. _installation: + +Installation with ROCm +============ + +vLLM-ROCm is here! Currently it is supporting llama-2. + +Requirements +------------ + +* OS: Linux +* Python: 3.8 -- 3.11 (Recommended 3.10 as this is the version that has been tested on.) +* GPU: MI210 +* Pytorch 2.0.1/2.1.1 +* ROCm 5.7 + + +Install with pip +---------------- + +You can install vLLM using pip: + +.. code-block:: console + + $ # (Optional) Create a new conda environment. + $ conda create -n myenv python=3.8 -y + $ conda activate myenv + + $ # Install vLLM with CUDA 12.1. + $ pip install vllm + +.. note:: + + As of now, vLLM's binaries are compiled on CUDA 12.1 by default. + However, you can install vLLM with CUDA 11.8 by running: + + .. code-block:: console + + $ # Install vLLM with CUDA 11.8. + $ # Replace `cp310` with your Python version (e.g., `cp38`, `cp39`, `cp311`). + $ pip install https://github.com/vllm-project/vllm/releases/download/v0.2.2/vllm-0.2.2+cu118-cp310-cp310-manylinux1_x86_64.whl + + $ # Re-install PyTorch with CUDA 11.8. + $ pip uninstall torch -y + $ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118 + + +.. _build_from_source: + +Build from source with docker +----------------- + +You can also build and install vLLM from source: + +Build a docker image from `rocm.Dockerfile`, and launch a docker container. + +.. code-block:: console + + $ docker build -f rocm.Dockerfile -t vllm-rocm . + $ docker run -it \ + --network=host \ + --group-add=video \ + --ipc=host \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --shm-size 8G \ + --device /dev/kfd \ + --device /dev/dri \ + -v :/app/hf_model \ + vllm-rocm \ + bash + +If you are going to setup on new pytorch+rocm5.7 docker container, you can follow the following steps. + +1. Install flash-attention-2-rocm + + If you are using Pytorch-2.0.1+rocm5.7. + + Install flash-attention-2 (v2.0.4) following the instruction from [ROCmSoftwarePlatform/flash-attention](https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm) + + + If you are using Pytorch-2.1.x+rocm5.7 or Pytorch-2.2.x+rocm5.7, you don't need to apply the `hipify_python.patch`. + You can directly build the flash-attention-2. + + .. code-block:: console + + $ bash patch_torch211_flash_attn2.rocm.sh + + .. note:: + - Flash-attention-2 (v2.0.4) does not support sliding windows attention. + - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) + +2. Setup xformers==0.0.22.post7 without dependencies, and apply patches + + .. code-block:: console + + $ pip install xformers==0.0.22.post7 --no-deps + $ bash patch_xformers-0.0.22.post7.rocm.sh + +3. Build vllm. + + .. code-block:: console + $ cd vllm + $ python setup.py install # This may take 5-10 minutes. diff --git a/patch_xformers-0.0.22.post7.rocm.sh b/patch_xformers-0.0.22.post7.rocm.sh new file mode 100644 index 000000000000..c8e58f721ae8 --- /dev/null +++ b/patch_xformers-0.0.22.post7.rocm.sh @@ -0,0 +1,22 @@ +#!/bin/bash +export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)') +export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)') + +echo $XFORMERS_FMHA_FLASH_PATH +echo $XFORMERS_FMHA_COMMON_PATH + +if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then + echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}" + patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch" + echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}" +else + echo "${XFORMERS_FMHA_FLASH_PATH} was patched before" +fi + +if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then + echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}" + patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch" + echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}" +else + echo "${XFORMERS_FMHA_COMMON_PATH} was patched before" +fi \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e3e3e389f789..f9390ed8c52a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = [ "ninja", "packaging", "setuptools", - "torch >= 2.1.0", + # "torch >= 2.1.0", # commented out to accommodate ROCm "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-rocm.txt b/requirements-rocm.txt new file mode 100644 index 000000000000..53fd3ea24d92 --- /dev/null +++ b/requirements-rocm.txt @@ -0,0 +1,16 @@ +ninja # For faster builds. +typing-extensions>=4.8.0 +starlette +psutil +ray >= 2.5.1 +pandas # Required for Ray data. +pyarrow # Required for Ray data. +sentencepiece # Required for LLaMA tokenizer. +numpy +tokenizers>=0.15.0 +huggingface_hub<0.18,>=0.16.4 +einops # Required for phi-1_5 +transformers >= 4.34.0 # Required for Mistral. +fastapi +uvicorn[standard] +pydantic == 1.10.13 # Required for OpenAI server. diff --git a/rocm.Dockerfile b/rocm.Dockerfile new file mode 100644 index 000000000000..27b57097740d --- /dev/null +++ b/rocm.Dockerfile @@ -0,0 +1,64 @@ +FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1 + +# Install some basic utilities +RUN apt-get update && apt-get install python3 python3-pip -y + +# Install some basic utilities +RUN apt-get update && apt-get install -y \ + curl \ + ca-certificates \ + sudo \ + git \ + bzip2 \ + libx11-6 \ + build-essential \ + wget \ + unzip \ + nvidia-cuda-toolkit \ + tmux \ + && rm -rf /var/lib/apt/lists/* + +### Mount Point ### +# When launching the container, mount the code directory to /app +ARG APP_MOUNT=/app +VOLUME [ ${APP_MOUNT} ] +WORKDIR ${APP_MOUNT} + +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers + +ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer +ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: +ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: +ENV PYTORCH_ROCM_ARCH=gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1101 + +# Install ROCm flash-attention +RUN mkdir libs \ + && cd libs \ + && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \ + && cd flash-attention \ + && git submodule update --init \ + && sed -i -e "s/--offload-arch=native/--offload-arch=$(/opt/rocm/llvm/bin/amdgpu-offload-arch)/g" setup.py \ + && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \ + && python3 setup.py install \ + && cd .. + +COPY ./ /app/vllm-rocm/ + +# RUN cd /app \ +# && cd vllm-rocm \ +# && git checkout v0.2.1.post1-rocm \ +# && python3 setup.py install \ +# && cd .. + +# RUN cd /app \ +# && mkdir dataset \ +# && cd .. + +# COPY ./benchmark_throughput.sh /app/benchmark_throughput.sh + +RUN python3 -m pip install --upgrade pip +# RUN python3 -m pip install --no-cache-dir ray[all] + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch b/rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch new file mode 100644 index 000000000000..4d7495cf13e1 --- /dev/null +++ b/rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch @@ -0,0 +1,13 @@ +--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000 ++++ common.py 2023-11-28 16:14:19.846233146 +0000 +@@ -298,8 +298,8 @@ + dtype = d.query.dtype + if device_type not in cls.SUPPORTED_DEVICES: + reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})") +- if device_type == "cuda" and not _built_with_cuda: +- reasons.append("xFormers wasn't build with CUDA support") ++ #if device_type == "cuda" and not _built_with_cuda: ++ # reasons.append("xFormers wasn't build with CUDA support") + if device_type == "cuda": + device_capability = torch.cuda.get_device_capability(d.device) + if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY: diff --git a/rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch b/rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch new file mode 100644 index 000000000000..4798f1efd461 --- /dev/null +++ b/rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch @@ -0,0 +1,134 @@ +--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000 ++++ flash.py 2023-11-28 16:14:25.206128903 +0000 +@@ -31,39 +31,39 @@ + + FLASH_VERSION = "0.0.0" + try: +- try: +- from ... import _C_flashattention # type: ignore[attr-defined] +- from ..._cpp_lib import _build_metadata +- +- if _build_metadata is not None: +- FLASH_VERSION = _build_metadata.flash_version +- except ImportError: +- import flash_attn +- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention +- +- FLASH_VERSION = flash_attn.__version__ +- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2]) +- if flash_ver_parsed < (2, 3): +- raise ImportError("Requires 2.3 for sliding window support") ++ #try: ++ # from ... import _C_flashattention # type: ignore[attr-defined] ++ # from ..._cpp_lib import _build_metadata ++ ++ # if _build_metadata is not None: ++ # FLASH_VERSION = _build_metadata.flash_version ++ #except ImportError: ++ import flash_attn ++ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention ++ ++ FLASH_VERSION = flash_attn.__version__ ++ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2]) ++ # if flash_ver_parsed < (2, 3): ++ # raise ImportError("Requires 2.3 for sliding window support") + + # create library so that flash-attn goes through the PyTorch Dispatcher +- _flash_lib = torch.library.Library("xformers_flash", "DEF") ++ #_flash_lib = torch.library.Library("xformers_flash", "DEF") + +- _flash_lib.define( +- "flash_fwd(Tensor query, Tensor key, Tensor value, " +- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, " +- "int max_seqlen_q, int max_seqlen_k, " +- "float p, float softmax_scale, " +- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)" +- ) +- +- _flash_lib.define( +- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " +- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " +- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " +- "int max_seqlen_q, int max_seqlen_k, " +- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)" +- ) ++ #_flash_lib.define( ++ # "flash_fwd(Tensor query, Tensor key, Tensor value, " ++ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, " ++ # "int max_seqlen_q, int max_seqlen_k, " ++ # "float p, float softmax_scale, " ++ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)" ++ #) ++ ++ #_flash_lib.define( ++ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " ++ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " ++ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " ++ # "int max_seqlen_q, int max_seqlen_k, " ++ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)" ++ #) + + def _flash_fwd( + query, +@@ -98,8 +98,8 @@ + p, + softmax_scale, + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + return_softmax, + None, # rng + ) +@@ -127,8 +127,8 @@ + softmax_scale, + False, + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + return_softmax, + None, + ) +@@ -169,8 +169,8 @@ + p, + softmax_scale, + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + None, + rng_state, + ) +@@ -193,15 +193,15 @@ + softmax_scale, + False, # zero_tensors + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + None, + rng_state, + ) + return dq, dk, dv + +- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") +- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") ++ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") ++ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") + except ImportError: + pass + +@@ -348,7 +348,7 @@ + implementation. + """ + +- OPERATOR = get_operator("xformers_flash", "flash_fwd") ++ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd") + SUPPORTED_DEVICES: Set[str] = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} diff --git a/setup.py b/setup.py index 0e28b9360277..b75d0912a9bb 100644 --- a/setup.py +++ b/setup.py @@ -16,30 +16,64 @@ # Supported NVIDIA GPU architectures. NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} -ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030","gfx1100"} -SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) +ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"} +# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) # Compiler flags. CXX_FLAGS = ["-g", "-O2", "-std=c++17"] # TODO(woosuk): Should we use -O3? NVCC_FLAGS = ["-O2", "-std=c++17"] -if torch.version.hip: - if ROCM_HOME is not None: - NVCC_FLAGS += [f"-DUSE_ROCM"] +if torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None: + NVCC_FLAGS += ["-DUSE_ROCM"] -if not torch.version.hip: - if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") +if torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is None: + raise RuntimeError( + "Cannot find CUDA_HOME. CUDA must be available to build the package.") ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] -if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") + +def get_amdgpu_offload_arch(): + error_message = "" + command = "/opt/rocm/llvm/bin/amdgpu-offload-arch" + try: + output = subprocess.check_output([command]) + return output.decode('utf-8').strip() + except subprocess.CalledProcessError as e: + error_message = f"Error: {e}" + except FileNotFoundError: + # If the command is not found, print an error message + error_message = f"The command {command} was not found." + + if error_message: + raise RuntimeError(error_message) + + return None + + +def get_hipcc_rocm_version(): + # Run the hipcc --version command + result = subprocess.run(['hipcc', '--version'], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True) + + # Check if the command was executed successfully + if result.returncode != 0: + print("Error running 'hipcc --version'") + return None + + # Extract the version using a regular expression + match = re.search(r'HIP version: (\S+)', result.stdout) + if match: + # Return the version string + return match.group(1) + else: + print("Could not find HIP version in the output") + return None def get_nvcc_cuda_version(cuda_dir: str) -> Version: @@ -72,7 +106,9 @@ def get_torch_arch_list() -> Set[str]: return set() # Filter out the invalid architectures and print a warning. - valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in NVIDIA_SUPPORTED_ARCHS}) + valid_archs = NVIDIA_SUPPORTED_ARCHS.union( + {s + "+PTX" + for s in NVIDIA_SUPPORTED_ARCHS}) arch_list = torch_arch_list.intersection(valid_archs) # If none of the specified architectures are valid, raise an error. if not arch_list: @@ -93,24 +129,24 @@ def get_torch_arch_list() -> Set[str]: # First, check the TORCH_CUDA_ARCH_LIST environment variable. compute_capabilities = get_torch_arch_list() -if not torch.version.hip: - if not compute_capabilities: - # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available - # GPUs on the current machine. - device_count = torch.cuda.device_count() - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 7: - raise RuntimeError( - "GPUs with compute capability below 7.0 are not supported.") - compute_capabilities.add(f"{major}.{minor}") - -if not torch.version.hip: +if torch.cuda.is_available( +) and torch.version.cuda and not compute_capabilities: + # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available + # GPUs on the current machine. + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability below 7.0 are not supported.") + compute_capabilities.add(f"{major}.{minor}") + +if torch.cuda.is_available() and torch.version.cuda: nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) if not compute_capabilities: # If no GPU is specified nor available, add all supported architectures # based on the NVCC CUDA version. - compute_capabilities = SUPPORTED_ARCHS.copy() + compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy() if nvcc_cuda_version < Version("11.1"): compute_capabilities.remove("8.6") if nvcc_cuda_version < Version("11.8"): @@ -118,7 +154,8 @@ def get_torch_arch_list() -> Set[str]: compute_capabilities.remove("9.0") # Validate the NVCC CUDA version. if nvcc_cuda_version < Version("11.0"): - raise RuntimeError("CUDA 11.0 or higher is required to build the package.") + raise RuntimeError( + "CUDA 11.0 or higher is required to build the package.") if (nvcc_cuda_version < Version("11.1") and any(cc.startswith("8.6") for cc in compute_capabilities)): raise RuntimeError( @@ -135,7 +172,7 @@ def get_torch_arch_list() -> Set[str]: "Targeting compute capability 8.0 instead.", stacklevel=2) compute_capabilities = set(cc for cc in compute_capabilities - if not cc.startswith("8.9")) + if not cc.startswith("8.9")) compute_capabilities.add("8.0+PTX") if any(cc.startswith("9.0") for cc in compute_capabilities): raise RuntimeError( @@ -146,13 +183,22 @@ def get_torch_arch_list() -> Set[str]: num = capability[0] + capability[2] NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] if capability.endswith("+PTX"): - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] + NVCC_FLAGS += [ + "-gencode", f"arch=compute_{num},code=compute_{num}" + ] # Use NVCC threads to parallelize the build. if nvcc_cuda_version >= Version("11.2"): num_threads = min(os.cpu_count(), 8) NVCC_FLAGS += ["--threads", str(num_threads)] +elif torch.cuda.is_available() and torch.version.hip: + amd_arch = get_amdgpu_offload_arch() + if amd_arch not in ROCM_SUPPORTED_ARCHS: + raise RuntimeError( + f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}" + f"amdgpu_arch_found: {amd_arch}") + ext_modules = [] # Cache operations. @@ -211,7 +257,7 @@ def get_torch_arch_list() -> Set[str]: ext_modules.append(activation_extension) # Quantization kernels. -if not torch.version.hip: +if torch.cuda.is_available() and torch.version.cuda: quantization_extension = CUDAExtension( name="vllm.quantization_ops", sources=[ @@ -224,7 +270,7 @@ def get_torch_arch_list() -> Set[str]: "nvcc": NVCC_FLAGS, }, ) -else: +elif torch.cuda.is_available() and torch.version.hip: quantization_extension = CUDAExtension( name="vllm.quantization_ops", sources=[ @@ -269,10 +315,20 @@ def find_version(filepath: str) -> str: def get_vllm_version() -> str: version = find_version(get_path("vllm", "__init__.py")) - # cuda_version = str(nvcc_cuda_version) - # if cuda_version != MAIN_CUDA_VERSION: - # cuda_version_str = cuda_version.replace(".", "")[:3] - # version += f"+cu{cuda_version_str}" + + if torch.cuda.is_available() and torch.version.cuda: + cuda_version = str(nvcc_cuda_version) + if cuda_version != MAIN_CUDA_VERSION: + cuda_version_str = cuda_version.replace(".", "")[:3] + version += f"+cu{cuda_version_str}" + + elif torch.cuda.is_available() and torch.version.hip: + # Get the HIP version + hipcc_version = get_hipcc_rocm_version() + if hipcc_version != MAIN_CUDA_VERSION: + rocm_version_str = hipcc_version.replace(".", "")[:3] + version += f"+rocm{rocm_version_str}" + return version @@ -287,8 +343,14 @@ def read_readme() -> str: def get_requirements() -> List[str]: """Get Python package dependencies from requirements.txt.""" - with open(get_path("requirements.txt")) as f: - requirements = f.read().strip().split("\n") + if torch.cuda.is_available() and torch.version.hip: + with open(get_path("requirements-rocm.txt")) as f: + requirements = f.read().strip().split("\n") + elif torch.cuda.is_available() and torch.version.cuda: + with open(get_path("requirements.txt")) as f: + requirements = f.read().strip().split("\n") + print("requirements: ", requirements) + # exit() return requirements diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c7e476c70474..c7612b3ac407 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from typing import Optional, Tuple +import torch + from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -83,32 +85,52 @@ def add_cli_args( help='directory to download and load the weights, ' 'default to the default cache dir of ' 'huggingface') - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], - help='The format of the model weights to load. ' - '"auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available. ' - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading. ' - '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') - parser.add_argument( - '--dtype', - type=str, - default=EngineArgs.dtype, - choices=[ - 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' - ], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') + if torch.cuda.is_available() and torch.version.hip: + # do something specific for HIP + parser.add_argument( + '--load-format', + type=str, + default='pt', + choices=['pt'], + help='The format of the model weights to load. ' + '"pt" will load the weights in the pytorch bin format. ') + parser.add_argument( + '--dtype', + type=str, + default='half', + choices=['half', 'float16', 'bfloat16'], + help='data type for model weights and activations. ' + 'The default option is FP16 precision ' + 'Supports FP16 and BF16 ') + elif torch.cuda.is_available() and torch.version.cuda: + # do something specific for CUDA + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], + help='The format of the model weights to load. ' + '"auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available. ' + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading. ' + '"dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.') + parser.add_argument( + '--dtype', + type=str, + default=EngineArgs.dtype, + choices=[ + 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' + ], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--max-model-len', type=int, default=None, @@ -171,13 +193,23 @@ def add_cli_args( parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') - # Quantization settings. - parser.add_argument('--quantization', - '-q', - type=str, - choices=['awq', 'squeezellm', None], - default=None, - help='Method used to quantize the weights') + if torch.cuda.is_available() and torch.version.hip: + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=str, + choices=['squeezellm', None], + default=None, + help='Method used to quantize the weights') + + elif torch.cuda.is_available() and torch.version.cuda: + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=str, + choices=['awq', 'squeezellm', None], + default=None, + help='Method used to quantize the weights') return parser @classmethod diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index ee58b8b9074a..6bff8153e2a5 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -73,7 +73,9 @@ def initialize_cluster( "Ray is not installed. Please install Ray to use distributed " "serving.") # Connect to a ray cluster. - ray.init(address=ray_address, ignore_reinit_error=True) + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) if not parallel_config.worker_use_ray: # Initialize cluster locally. diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index e51bb311decd..2e042721d9a2 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -113,6 +113,8 @@ def multi_query_kv_attention( attn_bias=input_metadata.attn_bias, p=0.0, scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (torch.cuda.is_available() and torch.version.hip) else None, ) # TODO(woosuk): Unnecessary copy. Optimize. output.copy_(out.view_as(output)) @@ -451,6 +453,8 @@ def multi_query_kv_attention( attn_bias=input_metadata.attn_bias, p=0.0, scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (torch.cuda.is_available() and torch.version.hip) else None, ) # TODO(woosuk): Unnecessary copy. Optimize. output.copy_(out.view_as(output)) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 3d937ba64f9f..f4d25566cf59 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,14 +1,16 @@ from typing import Type - -from vllm.model_executor.layers.quantization.awq import AWQConfig +import torch from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig _QUANTIZATION_CONFIG_REGISTRY = { - "awq": AWQConfig, "squeezellm": SqueezeLLMConfig, } +if torch.cuda.is_available() and torch.version.cuda: + from vllm.model_executor.layers.quantization.awq import AWQConfig + _QUANTIZATION_CONFIG_REGISTRY["awq"] = AWQConfig + def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: if quantization not in _QUANTIZATION_CONFIG_REGISTRY: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 0ab5819d930a..5561a9309fc3 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -2,8 +2,12 @@ import torch from torch.nn.parameter import Parameter +if torch.cuda.is_available() and torch.version.hip: + # do something specific for HIP + print("Warning: vLLM does not support AWQ on ROCm.") +elif torch.cuda.is_available() and torch.version.cuda: + from vllm import quantization_ops -from vllm import quantization_ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 61ec8b79b6dd..be318b2ef205 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -114,10 +114,19 @@ def apply_weights(self, lookup_table = weights["lookup_table"] out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) - # NOTE: The output tensor should be zero-initialized. - out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, qweight, out, - lookup_table) + if torch.cuda.is_available() and torch.version.hip: + out_float = torch.zeros(out_shape, + device="cuda", + dtype=torch.float) + quantization_ops.squeezellm_gemm(reshaped_x, qweight, out_float, + lookup_table) + out = out_float.to(dtype=torch.float16) + # do something specific for HIP + elif torch.cuda.is_available() and torch.version.cuda: + # NOTE: The output tensor should be zero-initialized. + out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) + quantization_ops.squeezellm_gemm(reshaped_x, qweight, out, + lookup_table) if bias is not None: out = out + bias