use updated version of prebuilt wheels for flash attention for cu130 #3342

Merged
winglian merged 4 commits into main from flash-attn-cu130
Jan 5, 2026
@winglian (Collaborator) commented Jan 5, 2026

Description

Motivation and Context

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • Chores
    • Added support for CUDA 13.0 and Python 3.11 in build pipelines.
    • Updated the base Docker image defaults to Python 3.11 and CUDA 12.8, and added prebuilt flash-attention wheel installation for CUDA 12.8 and CUDA 13.0.

@coderabbitai (Bot, Contributor) commented Jan 5, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

Activates a new CUDA 13.0 / Python 3.11 / PyTorch 2.9.1 matrix configuration in GitHub Actions workflows and updates the Dockerfile base image to support flash-attention wheel installation for both CUDA 12.8 and 13.0 architectures via conditional logic.

Changes

| Cohort / File(s) | Summary |
|---|---|
| GitHub Actions Workflow Configuration: .github/workflows/main.yml | Uncomments and enables a matrix entry for CUDA 130, Python 3.11, and PyTorch 2.9.1 across build-axolotl and build-axolotl-cloud job pipelines |
| Docker Base Image: docker/Dockerfile-base | Bumps base Python to 3.11 and CUDA to 128; refactors flash-attention installation from single conditional to case/esac structure supporting CUDA 128 (flash_attn-2.8.3+cu128torch2.9) and CUDA 130 (flash_attn-2.8.3+cu130torch2.9) wheels |
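
The case/esac selection described in the Changes summary above can be sketched as a standalone shell function. This is only an illustration under stated assumptions: select_flash_attn_wheel is a hypothetical helper that is not part of the PR, it prints the wheel URL instead of downloading and installing it, and it uses a nested case on CUDA rather than the if-blocks in the Dockerfile. The release URL and wheel names are copied from the diff in this PR.

```shell
#!/bin/sh
# Release URL copied from the diff in this PR.
BASE_URL="https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4"

# select_flash_attn_wheel PYTORCH_VERSION CUDA
# Hypothetical helper (not in the PR): prints the prebuilt wheel URL for a
# supported PyTorch/CUDA combination and prints nothing otherwise.
select_flash_attn_wheel() {
  case "$1" in
    2.9.[0-9]*)
      case "$2" in
        128|130)
          echo "$BASE_URL/flash_attn-2.8.3+cu${2}torch2.9-cp311-cp311-linux_x86_64.whl"
          ;;
      esac
      ;;
  esac
}

# Dry run for the new CUDA 13.0 matrix entry.
select_flash_attn_wheel 2.9.1 130
# No 2.8 branch exists in this sketch, so this prints nothing.
select_flash_attn_wheel 2.8.0 128
```

Keeping a single ";;" per branch, as here, leaves no place for a second conditional to fall outside its pattern.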

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

Suggested labels

ready to merge

Suggested reviewers

  • SalmanMohammadi
  • djsaunde
  • NanoCode012

Pre-merge checks

✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately reflects the main change: enabling and updating flash attention prebuilt wheels for CUDA 13.0. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |

@coderabbitai (Bot, Contributor) left a comment

Actionable comments posted: 1

Fix all issues with AI Agents 🤖
In @docker/Dockerfile-base:
- Around lines 54-66: In the RUN case on PYTORCH_VERSION, the first "fi ;;"
closes the "2.9.[0-9]*)" branch, so the CUDA 130 if-block lands where the shell
expects the next pattern and causes a shell syntax error. Move the CUDA 130
block into the same branch: end each of the two if-blocks (if [ "$CUDA" = "128" ]
and if [ "$CUDA" = "130" ]) with "fi;" and close the branch with a single ";;"
before esac.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b26ba3a and e16a562.

📒 Files selected for processing (2)
  • .github/workflows/main.yml
  • docker/Dockerfile-base
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (14)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.9.1)
  • GitHub Check: PyTest (3.11, 2.9.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.9.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.9.1)
  • GitHub Check: build-base-uv (128, 12.8.1, 3.11, 2.9.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
  • GitHub Check: build-base (128, 12.8.1, 3.11, 2.9.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
  • GitHub Check: build-base-uv (128, 12.8.1, 3.11, 2.8.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
  • GitHub Check: build-base (130, 13.0.0, 3.11, 2.9.1, 9.0+PTX, Dockerfile-base)
  • GitHub Check: build-base-uv (128, 12.8.1, 3.11, 2.9.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
  • GitHub Check: build-base (128, 12.8.1, 3.11, 2.9.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
  • GitHub Check: build-base-uv (130, 13.0.0, 3.11, 2.9.1, 9.0+PTX, Dockerfile-uv-base)
  • GitHub Check: build-base (128, 12.8.1, 3.11, 2.8.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
🔇 Additional comments (5)
.github/workflows/main.yml (3)

34-38: LGTM! New CUDA 13.0 matrix entry added correctly.

The new matrix configuration for CUDA 13.0 with PyTorch 2.9.1 and Python 3.11 is properly structured and consistent with existing entries.


101-105: LGTM! Consistent CUDA 13.0 configuration in build-axolotl-cloud.

The matrix entry matches the configuration added to the build-axolotl job.


140-191: Verify whether CUDA 13.0 support is needed for build-axolotl-cloud-no-tmux.

The build-axolotl-cloud-no-tmux job does not include a CUDA 13.0 matrix entry, while both build-axolotl and build-axolotl-cloud jobs have been updated with CUDA 13.0 configurations. Please confirm whether this omission is intentional or if a CUDA 13.0 entry should be added here as well.

docker/Dockerfile-base (2)

10-10: LGTM! Python version bump aligns with workflow changes.

Bumping Python from 3.10 to 3.11 is consistent with the matrix configuration in .github/workflows/main.yml.


12-12: LGTM! CUDA version default updated correctly.

The CUDA default version bump from 118 to 128 aligns with the workflow matrix changes.

Comment thread docker/Dockerfile-base Outdated
Comment on lines +54 to +66
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$CUDA" = "128" ]; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi ;; \
if [ "$CUDA" = "130" ]; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi ;; \
esac \
⚠️ Potential issue | 🔴 Critical

Critical shell syntax error in case statement.

The if [ "$CUDA" = "130" ] block (lines 61-65) is placed outside the case pattern. Line 60 ends with fi ;; which closes the if statement and terminates the 2.9.[0-9]*) case branch, leaving lines 61-65 orphaned outside any case pattern. This will cause a shell syntax error during the Docker build.

🔎 Proposed fix

Move the CUDA 130 block inside the case pattern:

 RUN case "$PYTORCH_VERSION" in \
         2.9.[0-9]*) \
             if [ "$CUDA" = "128" ]; then \
                 wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
                 pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
                 rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-            fi ;; \
+            fi; \
             if [ "$CUDA" = "130" ]; then \
                 wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
                 pip3 install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
                 rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
-            fi ;; \
+            fi; \
+            ;; \
     esac \
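
The difference between the two shapes can be checked without a Docker build. The sketch below is a reduced stand-in, not the Dockerfile itself: echo replaces the wget/pip3/rm commands, check_syntax is a hypothetical helper, and it assumes that sh -n (parse only, no execution) is enough to surface the error.

```shell
#!/bin/sh
# Shape from the PR: the first ";;" ends the branch, so the second "if"
# sits where the shell expects a case pattern.
broken='case "$PYTORCH_VERSION" in
  2.9.[0-9]*)
    if [ "$CUDA" = "128" ]; then echo cu128; fi ;;
    if [ "$CUDA" = "130" ]; then echo cu130; fi ;;
esac'

# Shape from the proposed fix: both if-blocks end in "fi;" and a single
# ";;" closes the branch.
fixed='case "$PYTORCH_VERSION" in
  2.9.[0-9]*)
    if [ "$CUDA" = "128" ]; then echo cu128; fi;
    if [ "$CUDA" = "130" ]; then echo cu130; fi;
    ;;
esac'

# check_syntax SCRIPT (hypothetical helper): parse the script text with
# "sh -n" without executing it; the return status is that of sh.
check_syntax() {
  printf '%s\n' "$1" | sh -n 2>/dev/null
}

if check_syntax "$broken"; then broken_status=ok; else broken_status=error; fi
if check_syntax "$fixed"; then fixed_status=ok; else fixed_status=error; fi

echo "broken: $broken_status"
echo "fixed: $fixed_status"
```

With dash or bash as sh, this should report an error for the first shape and ok for the second.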
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$CUDA" = "128" ]; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi; \
if [ "$CUDA" = "130" ]; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi; \
;; \
esac \
🤖 Prompt for AI Agents
In @docker/Dockerfile-base around lines 54-66: in the RUN case on
PYTORCH_VERSION, the first "fi ;;" closes the "2.9.[0-9]*)" branch, so the CUDA
130 if-block lands where the shell expects the next pattern and causes a shell
syntax error. Move the CUDA 130 block into the same branch: end each of the two
if-blocks (if [ "$CUDA" = "128" ] and if [ "$CUDA" = "130" ]) with "fi;" and
close the branch with a single ";;" before esac.

@codecov (Bot) commented Jan 5, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

@winglian winglian merged commit 4e61b8a into main Jan 5, 2026
10 of 17 checks passed
@winglian winglian deleted the flash-attn-cu130 branch January 5, 2026 18:48