use updated version of prebuilt wheels for flash attention for cu130 #3342
📝 Walkthrough
Activates a new CUDA 13.0 / Python 3.11 / PyTorch 2.9.1 matrix configuration in the GitHub Actions workflows and updates the base Dockerfile to install the prebuilt flash-attention wheel for either CUDA 12.8 or CUDA 13.0 via conditional logic.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks: ✅ 3 passed
Actionable comments posted: 1
Fix all issues with AI Agents 🤖
In @docker/Dockerfile-base:
- Around line 54-66: The case branch for pattern "2.9.[0-9]*)" closes
prematurely: the first fi ;; ends the branch, so the CUDA 130 if-block is left
outside any pattern and causes a shell syntax error. Edit the RUN case on
PYTORCH_VERSION so both CUDA checks (the if [ "$CUDA" = "128" ] and the
if [ "$CUDA" = "130" ]) sit inside the same "2.9.[0-9]*)" pattern: end each
if-block with fi; and terminate the branch with a single ;; after the second
block.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- .github/workflows/main.yml
- docker/Dockerfile-base
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (14)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.9.1)
- GitHub Check: PyTest (3.11, 2.9.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.9.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.9.1)
- GitHub Check: build-base-uv (128, 12.8.1, 3.11, 2.9.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
- GitHub Check: build-base (128, 12.8.1, 3.11, 2.9.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
- GitHub Check: build-base-uv (128, 12.8.1, 3.11, 2.8.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
- GitHub Check: build-base (130, 13.0.0, 3.11, 2.9.1, 9.0+PTX, Dockerfile-base)
- GitHub Check: build-base-uv (128, 12.8.1, 3.11, 2.9.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
- GitHub Check: build-base (128, 12.8.1, 3.11, 2.9.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
- GitHub Check: build-base-uv (130, 13.0.0, 3.11, 2.9.1, 9.0+PTX, Dockerfile-uv-base)
- GitHub Check: build-base (128, 12.8.1, 3.11, 2.8.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
🔇 Additional comments (5)
.github/workflows/main.yml (3)
34-38: LGTM! New CUDA 13.0 matrix entry added correctly.
The new matrix configuration for CUDA 13.0 with PyTorch 2.9.1 and Python 3.11 is properly structured and consistent with existing entries.
101-105: LGTM! Consistent CUDA 13.0 configuration in build-axolotl-cloud.
The matrix entry matches the configuration added to the build-axolotl job.
140-191: Verify whether CUDA 13.0 support is needed for build-axolotl-cloud-no-tmux.
The build-axolotl-cloud-no-tmux job does not include a CUDA 13.0 matrix entry, while both the build-axolotl and build-axolotl-cloud jobs have been updated with CUDA 13.0 configurations. Please confirm whether this omission is intentional or if a CUDA 13.0 entry should be added here as well.
docker/Dockerfile-base (2)
10-10: LGTM! Python version bump aligns with workflow changes.
Bumping Python from 3.10 to 3.11 is consistent with the matrix configuration in .github/workflows/main.yml.
12-12: LGTM! CUDA version default updated correctly.
The CUDA default version bump from 118 to 128 aligns with the workflow matrix changes.
RUN case "$PYTORCH_VERSION" in \
    2.9.[0-9]*) \
        if [ "$CUDA" = "128" ]; then \
            wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
            pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
            rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
        fi ;; \
        if [ "$CUDA" = "130" ]; then \
            wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
            pip3 install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
            rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
        fi ;; \
    esac \
Critical shell syntax error in case statement.
The if [ "$CUDA" = "130" ] block (lines 61-65) is placed outside the case pattern. Line 60 ends with fi ;; which closes the if statement and terminates the 2.9.[0-9]*) case branch, leaving lines 61-65 orphaned outside any case pattern. This will cause a shell syntax error during the Docker build.
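The parse failure can be reproduced without building the image. The following is a minimal sketch, assuming a POSIX `sh` is available; `check_syntax`, `BROKEN`, and `FIXED` are hypothetical names introduced here, and `echo` placeholders stand in for the wget, pip3 install, and rm steps. It relies on `sh -n`, which parses a script without executing it.

```shell
#!/bin/sh
# Sketch only: placeholders replace the wget / pip3 install / rm steps.
# `sh -n` parses its input without executing it, so nothing is downloaded.

check_syntax() {
    # Parse the script supplied on stdin; print "ok" or "syntax error".
    if sh -n 2>/dev/null; then
        echo "ok"
    else
        echo "syntax error"
    fi
}

# Layout from the PR: the first `;;` ends the 2.9.x branch, so the shell
# expects a new pattern next and cannot parse `if [ ... ]` as one.
BROKEN='case "$PYTORCH_VERSION" in
    2.9.[0-9]*)
        if [ "$CUDA" = "128" ]; then echo "install cu128 wheel"; fi ;;
        if [ "$CUDA" = "130" ]; then echo "install cu130 wheel"; fi ;;
esac'

# Layout from the proposed fix: each if-block ends with `fi;` and a single
# `;;` closes the branch after both blocks.
FIXED='case "$PYTORCH_VERSION" in
    2.9.[0-9]*)
        if [ "$CUDA" = "128" ]; then echo "install cu128 wheel"; fi;
        if [ "$CUDA" = "130" ]; then echo "install cu130 wheel"; fi;
        ;;
esac'

BROKEN_RESULT="$(printf '%s\n' "$BROKEN" | check_syntax)"
FIXED_RESULT="$(printf '%s\n' "$FIXED" | check_syntax)"

echo "PR layout:    $BROKEN_RESULT"
echo "fixed layout: $FIXED_RESULT"
```

Running the same `sh -n` check against the real RUN body, with the Dockerfile line continuations removed, would be one way to catch this class of error before a full Docker build.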
🔎 Proposed fix
Move the CUDA 130 block inside the case pattern:
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$CUDA" = "128" ]; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
- fi ;; \
+ fi; \
if [ "$CUDA" = "130" ]; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
- fi ;; \
+ fi; \
+ ;; \
esac \
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
RUN case "$PYTORCH_VERSION" in \
    2.9.[0-9]*) \
        if [ "$CUDA" = "128" ]; then \
            wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
            pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
            rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
        fi; \
        if [ "$CUDA" = "130" ]; then \
            wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
            pip3 install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
            rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
        fi; \
        ;; \
    esac \
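The two if-blocks differ only in the CUDA tag embedded in the wheel filename, so the name could be derived from `$CUDA` instead of being repeated. This is a sketch of that idea, not what the PR does: `flash_attn_wheel_name` and `flash_attn_wheel_url` are hypothetical helpers, and the sketch assumes that only the cu128 and cu130 wheels shown above exist under the v0.5.4 release, rejecting any other value.

```shell
#!/bin/sh
# Hypothetical helpers; not part of docker/Dockerfile-base.
# Only the two CUDA tags that appear in this PR are accepted.

flash_attn_wheel_name() {
    # $1: CUDA tag as used by the build matrix, e.g. 128 or 130.
    cuda="$1"
    case "$cuda" in
        128|130)
            echo "flash_attn-2.8.3+cu${cuda}torch2.9-cp311-cp311-linux_x86_64.whl"
            ;;
        *)
            # Unknown tag: no wheel is assumed to exist, so signal failure.
            return 1
            ;;
    esac
}

flash_attn_wheel_url() {
    # Prefix the wheel name with the release URL used in the Dockerfile.
    name="$(flash_attn_wheel_name "$1")" || return 1
    echo "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/${name}"
}

# Example usage: print the URL for each tag in the build matrix.
for tag in 128 130; do
    flash_attn_wheel_url "$tag"
done
```

With a helper like this, the 2.9.x branch would need one wget, pip3 install, and rm sequence rather than one per CUDA version, which also removes the opportunity to misplace a `;;` between the blocks.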
Codecov Report
✅ All modified and coverable lines are covered by tests.