Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ RUN apt-get update --no-install-recommends && apt-get install -y nginx && mkdir
ENV HF_HUB_ENABLE_HF_TRANSFER=1
RUN pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu128 --no-cache-dir
RUN pip install packaging --no-cache-dir
RUN pip install flash-attn==2.8.0.post2 --no-build-isolation --no-cache-dir
RUN pip install flash-attn==2.8.0.post2 "flashinfer-python>=0.2.7.post1" --no-build-isolation --no-cache-dir
COPY requirements.txt .
RUN pip install -r requirements.txt --no-cache-dir
RUN python -m nltk.downloader punkt
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pip install --upgrade pip "setuptools<70.0.0" wheel
# TODO, unpin setuptools when this issue in flash attention is resolved
pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124
pip install packaging
pip install flash-attn==2.7.2.post1 --no-build-isolation
pip install flash-attn==2.7.2.post2 "flashinfer-python>=0.2.7.post1" --no-build-isolation
pip install -r requirements.txt
pip install -e .
python -m nltk.downloader punkt
Expand All @@ -62,7 +62,7 @@ python -m nltk.downloader punkt
* **Local installation with uv (preview)**: We are experimenting with using [uv](https://docs.astral.sh/uv/). You can install via
```bash
uv sync
uv sync --extra compile --extra liger # to install flash attention and liger-kernel
uv sync --extra compile --extra liger # to install flash attention, flash infer, and liger-kernel
```


Expand Down
3 changes: 3 additions & 0 deletions mason.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ def get_env_vars(pure_docker_mode: bool, cluster: List[str], beaker_secrets: Lis
whoami: str, resumable: bool, num_nodes: int, additional_env_vars: List[Dict[str, str]],
additional_secrets: List[Dict[str, str]]):
env_vars = []
if "VLLM_ATTENTION_BACKEND" not in additional_env_vars:
env_vars.append(beaker.EnvVar(name="VLLM_ATTENTION_BACKEND",
value="FLASHINFER"))
# Add user-specified environment variables first
for env_var in additional_env_vars:
env_vars.append(
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ explicit = true

# flash-attn related setups
[project.optional-dependencies]
compile = ["flash-attn>=2.8.0.post1"]
compile = ["flash-attn>=2.8.0.post2",
"flashinfer-python>=0.2.7.post1"]
liger = ["liger-kernel>=0.5.4"]
code = [
"fastapi>=0.100.0",
Expand Down
24 changes: 22 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ compressed-tensors==0.10.1
# via vllm
contourpy==1.3.2
# via matplotlib
cuda-bindings==12.9.0
# via cuda-python
cuda-python==12.9.0
# via flashinfer-python
cupy-cuda12x==13.4.1 ; sys_platform != 'darwin'
# via ray
cycler==0.12.1
Expand Down Expand Up @@ -134,6 +138,7 @@ docker-pycreds==0.4.0
einops==0.8.1
# via
# flash-attn
# flashinfer-python
# vllm
email-validator==2.2.0
# via fastapi
Expand Down Expand Up @@ -161,6 +166,8 @@ filelock==3.18.0
flake8==7.2.0
flash-attn==2.8.0.post2
# via open-instruct
flashinfer-python==0.2.7.post1
# via open-instruct
fonttools==4.58.1
# via matplotlib
frozenlist==1.6.2
Expand Down Expand Up @@ -340,6 +347,7 @@ networkx==3.5 ; python_full_version >= '3.11'
ninja==1.11.1.4
# via
# deepspeed
# flashinfer-python
# vllm
# xgrammar
nltk==3.9.1
Expand All @@ -354,6 +362,7 @@ numpy==1.26.4
# cupy-cuda12x
# datasets
# deepspeed
# flashinfer-python
# gguf
# matplotlib
# mistral-common
Expand Down Expand Up @@ -397,7 +406,9 @@ nvidia-cusparse-cu12==12.5.7.53 ; platform_machine == 'x86_64' and sys_platform
nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-ml-py==12.575.51
# via nvitop
# via
# nvitop
# pynvml
nvidia-nccl-cu12==2.26.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-nvjitlink-cu12==12.8.61 ; platform_machine == 'x86_64' and sys_platform == 'linux'
Expand All @@ -406,6 +417,8 @@ nvidia-nvjitlink-cu12==12.8.61 ; platform_machine == 'x86_64' and sys_platform =
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvshmem-cu12==3.3.9
# via flashinfer-python
nvidia-nvtx-cu12==12.8.55 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvitop==1.5.1
Expand Down Expand Up @@ -580,6 +593,8 @@ pygments==2.19.1
# rich
pymdown-extensions==10.15
# via mkdocs-material
pynvml==12.0.0
# via flashinfer-python
pyparsing==3.2.3
# via matplotlib
pytest==8.4.0
Expand All @@ -599,7 +614,9 @@ python-multipart==0.0.20
pytz==2025.2
# via pandas
pywin32==310 ; sys_platform == 'win32'
# via docker
# via
# cuda-bindings
# docker
pyyaml==6.0.2
# via
# accelerate
Expand Down Expand Up @@ -642,6 +659,7 @@ requests==2.32.3
# beaker-py
# datasets
# docker
# flashinfer-python
# google-api-core
# huggingface-hub
# mistral-common
Expand Down Expand Up @@ -743,6 +761,7 @@ torch==2.7.0 ; sys_platform == 'darwin'
# compressed-tensors
# deepspeed
# flash-attn
# flashinfer-python
# liger-kernel
# open-instruct
# outlines
Expand All @@ -758,6 +777,7 @@ torch==2.7.0+cu128 ; sys_platform != 'darwin'
# compressed-tensors
# deepspeed
# flash-attn
# flashinfer-python
# liger-kernel
# open-instruct
# outlines
Expand Down
Loading