Skip to content

feat: multi-arch CUDA Dockerfile and sm_121 (DGX Spark GB10)#840

Merged
alvarobartt merged 1 commit intohuggingface:mainfrom
nazq:feat/arm64-cuda-blackwell
Mar 31, 2026
Merged

feat: multi-arch CUDA Dockerfile and sm_121 (DGX Spark GB10)#840
alvarobartt merged 1 commit intohuggingface:mainfrom
nazq:feat/arm64-cuda-blackwell

Conversation

@nazq
Copy link
Copy Markdown
Contributor

@nazq nazq commented Mar 4, 2026

Summary

Builds on #827 (ARM64 CPU Dockerfile) by extending CUDA support to ARM64 and adding the DGX Spark GB10's sm_121 compute capability. Also adds the CI matrix entries and README updates needed to ship ARM64 images.

Changes

Dockerfile-cuda (multi-arch)

  • Use TARGETARCH to select correct sccache binary (x86_64 or aarch64)
  • Use TARGETARCH to select correct protoc binary (x86_64 or aarch_64)
  • Add sm_121 to nvprune section for DGX Spark GB10

compute_cap.rs

  • (120..=121, 120) => true — sm_121 runtime is compatible with sm_120 compiled binaries
  • (121, 121) => true — exact match for native sm_121 builds
  • Full test coverage for sm_121 compatibility matrix

flash_attn.rs

  • Allow runtime_compute_cap == 121 to use flash attention v2 (same arch family as sm_120)

build.yaml

  • Use matrix.platforms with fallback to linux/amd64 — enables per-variant platform selection without breaking existing entries

matrix.json

  • Add blackwell-121 entry (linux/arm64, CUDA_COMPUTE_CAP=121) for DGX Spark GB10
  • Add cpu-arm64 entry (linux/arm64, Dockerfile-arm64) for ARM64 CPU-only hosts

README.md

  • Add Platform column to Docker Images table
  • Add cpu-arm64-1.9 and 121-1.9 image entries
  • Replace Apple-only ARM64 section with comprehensive aarch64 docs covering CPU-only and CUDA build paths (DGX Spark, Jetson)
  • Add sm_121 to CUDA compute capability examples

Motivation

The NVIDIA DGX Spark uses the GB10 SoC with compute capability 12.1 (sm_121). This is a Blackwell-family chip (Grace + Blackwell GPU) on ARM64. Without these changes, TEI cannot run on the DGX Spark with CUDA acceleration.

Testing

  • docker build -f Dockerfile-cuda --build-arg CUDA_COMPUTE_CAP=121 --platform linux/arm64 .
  • Unit tests pass for compute_cap_matching with sm_121
  • CI matrix produces 121-{version}-grpc and cpu-arm64-{version}-grpc images

Closes #769

@nazq nazq force-pushed the feat/arm64-cuda-blackwell branch 3 times, most recently from 44f1190 to 8cf4772 Compare March 4, 2026 16:38
@alvarobartt alvarobartt self-requested a review March 6, 2026 09:18
alvarobartt
alvarobartt previously approved these changes Mar 6, 2026
Copy link
Copy Markdown
Member

@alvarobartt alvarobartt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the PR @nazq, looks really clean!

Could you review and update also the table with the different images at https://github.com/huggingface/text-embeddings-inference/blob/main/docs/source/en/supported_models.md? Then I'll merge and validate that the CI is working as expected, hoping to release v1.9.3 next week.

And thanks for building on top of @z4y4ts PR and keeping them as co-author, much appreciated 🤗

@nazq
Copy link
Copy Markdown
Contributor Author

nazq commented Mar 6, 2026

Thanks a lot for the PR @nazq, looks really clean!

Could you review and update also the table with the different images at https://github.com/huggingface/text-embeddings-inference/blob/main/docs/source/en/supported_models.md? Then I'll merge and validate that the CI is working as expected, hoping to release v1.9.3 next week.

And thanks for building on top of @z4y4ts PR and keeping them as co-author, much appreciated 🤗

Updated supported_models.md. I did update the CI too but I've not run it so all done by inspection.

@stefan-it
Copy link
Copy Markdown

Hi @nazq thanks so much for that PR!

I tested the PR on my Spark and I got a build failure:

41.46 error[E0521]: borrowed data escapes outside of associated function                                                                                                                  
41.46   --> /root/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/metrics-0.23.0/src/recorder/mod.rs:77:9                                                                            
41.46    |                                                                                                                                                                                
41.46 71 |       fn new(recorder: &dyn Recorder) -> Self {                                                                                                                                
41.46    |              --------  - let's call the lifetime of this reference `'1`                                                                                                        
41.46    |              |                                                                    
41.46    |              `recorder` is a reference that is only valid in the associated function body                                                                                      
41.46 ...                                                                                                                                                                                 
41.46 77 | /         LOCAL_RECORDER.with(|local_recorder| {                                                                                                                               
41.46 78 | |             local_recorder.set(Some(recorder_ptr));                                                                                                                          
41.46 79 | |         });                                                                                                                                                                  
41.46    | |          ^                                                                                                                                                                   
41.46    | |          |                                                                                                                                                                   
41.46    | |__________`recorder` escapes the associated function body here                                                                                                                
41.46    |            argument requires that `'1` must outlive `'static`                                                                                                                  
41.46    |                                                                                                                                                                                
41.46 note: raw pointer casts of trait objects cannot extend lifetimes                                                                                                                    
41.46   --> /root/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/metrics-0.23.0/src/recorder/mod.rs:75:60                                                                           
41.46    |                                                                                                                                                                                
41.46 75 |         let recorder_ptr = unsafe { NonNull::new_unchecked(recorder as *const _ as *mut _) };                                                                                  
41.46    |                                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                      
41.46    = note: this was previously accepted by the compiler but was changed recently                                                                                                    
41.46    = help: see <https://github.com/rust-lang/rust/issues/141402> for more information                                                                                               
41.46                                                                                                                                                                                     
41.46 For more information about this error, try `rustc --explain E0521`.                                                                                                                 
41.47 error: could not compile `metrics` (lib) due to 1 previous error                                                                                                                    
41.47 warning: build failed, waiting for other jobs to finish...                                                                                                                          332.4                                                                                                                                                                                     332.4 thread 'main' (7) panicked at /root/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/cargo-chef-0.1.73/src/recipe.rs:218:27:                                                    
332.4 Exited with status code: 101                                                                                                                                                        
332.4 note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace                                                                                                       
------                                                                                                                                                                                    
Dockerfile-cuda:82                                                                                                                                                                        
--------------------                                                                                                                                                                      
  81 |                                                                                                                                                                                    
  82 | >>> RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \                                                                                                       
  83 | >>>     --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \                                                                                                   
  84 | >>>     if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ]; \                                                                                                         
  85 | >>>     then \                                                                                                                                                                     
  86 | >>>     cargo chef cook --release --features candle-cuda-turing --features static-linking --no-default-features --recipe-path recipe.json && sccache -s; \                         
  87 | >>>     else \                                                                                                                                                                     
  88 | >>>     cargo chef cook --release --features candle-cuda --features static-linking --no-default-features --recipe-path recipe.json && sccache -s; \                                
  89 | >>>     fi;                                                                                                                                                                        
  90 |                                                                                                                                                                                    
--------------------                                                                                                                                                                      
ERROR: failed to build: failed to solve: process "/bin/sh -c if [ ${CUDA_COMPUTE_CAP} -ge 75 -a ${CUDA_COMPUTE_CAP} -lt 80 ];     then     cargo chef cook --release --features candle-cud
a-turing --features static-linking --no-default-features --recipe-path recipe.json && sccache -s;     else     cargo chef cook --release --features candle-cuda --features static-linking 
--no-default-features --recipe-path recipe.json && sccache -s;     fi;" did not complete successfully: exit code: 101

After searching a bit, I found out that this #842 PR should fix it. So I applied these changes and the build finished without any errors. So I guess only a rebase is needed.

@nazq
Copy link
Copy Markdown
Contributor Author

nazq commented Mar 15, 2026

Great. Thanks for this i didn't buy a Spark till i knew we could get this PR in. Happy to rebase it

@nazq nazq force-pushed the feat/arm64-cuda-blackwell branch from a9395f8 to ad55ed2 Compare March 15, 2026 13:24
@nazq
Copy link
Copy Markdown
Contributor Author

nazq commented Mar 15, 2026

Hey @stefan-it — rebased onto upstream main, which now includes #842. Should fix the metrics crate build failure you hit. Let me know if it works on your Spark!

@stefan-it
Copy link
Copy Markdown

Hi @nazq many thanks! I did a fresh clone of the rebased branch and built it with:

docker build . -f Dockerfile-cuda --no-cache   --build-arg CUDA_COMPUTE_CAP=121   --platform linux/arm64 -t text-embeddings-inference:121-1.9-pr

result was:

[+] Building 895.2s (32/32) FINISHED                                                                                                                                       docker:default
 => [internal] load build definition from Dockerfile-cuda                                                                                                                            0.0s
 => => transferring dockerfile: 6.46kB                                                                                                                                               0.0s
 => [internal] load metadata for docker.io/nvidia/cuda:12.9.1-runtime-ubuntu24.04                                                                                                    0.2s
 => [internal] load metadata for docker.io/nvidia/cuda:12.9.1-devel-ubuntu24.04                                                                                                      0.2s
 => [internal] load .dockerignore                                                                                                                                                    0.0s
 => => transferring context: 53B                                                                                                                                                     0.0s
 => [internal] load build context                                                                                                                                                    0.0s
 => => transferring context: 17.28kB                                                                                                                                                 0.0s
 => CACHED [base-builder 1/6] FROM docker.io/nvidia/cuda:12.9.1-devel-ubuntu24.04@sha256:020bc241a628776338f4d4053fed4c38f6f7f3d7eb5919fecb8de313bb8ba47c                            0.0s
 => CACHED [base 1/3] FROM docker.io/nvidia/cuda:12.9.1-runtime-ubuntu24.04@sha256:1287141d283b8f06f45681b56a48a85791398c615888b1f96bfb9fc981392d98                                  0.0s
 => [base-builder 2/6] RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends     curl     libssl-dev     pkg-config     && rm -rf /var/l  22.1s
 => [base 2/3] RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends     ca-certificates     libssl-dev     curl     cuda-compat-12-9     19.6s
 => [base 3/3] COPY --chmod=775 cuda-entrypoint.sh entrypoint.sh                                                                                                                     0.0s
 => [base-builder 3/6] RUN case "arm64" in     "amd64") SCCACHE_ARCH=x86_64-unknown-linux-musl ;;     "arm64") SCCACHE_ARCH=aarch64-unknown-linux-musl ;;     *) echo "Unsupported   2.9s 
 => [base-builder 4/6] COPY rust-toolchain.toml rust-toolchain.toml                                                                                                                  0.0s 
 => [base-builder 5/6] RUN curl https://sh.rustup.rs -sSf | bash -s -- -y                                                                                                           32.5s 
 => [base-builder 6/6] RUN cargo install cargo-chef --version 0.1.73 --locked                                                                                                       49.9s 
 => [planner 1/7] WORKDIR /usr/src                                                                                                                                                   0.0s 
 => [builder 2/9] RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL     --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN     if [ 121 -g  0.6s 
 => [planner 2/7] COPY backends backends                                                                                                                                             0.1s 
 => [planner 3/7] COPY core core                                                                                                                                                     0.1s 
 => [planner 4/7] COPY router router                                                                                                                                                 0.1s 
 => [planner 5/7] COPY Cargo.toml ./                                                                                                                                                 0.1s 
 => [planner 6/7] COPY Cargo.lock ./                                                                                                                                                 0.1s 
 => [planner 7/7] RUN cargo chef prepare  --recipe-path recipe.json                                                                                                                  0.2s
 => [builder 3/9] COPY --from=planner /usr/src/recipe.json recipe.json                                                                                                               0.1s
 => [builder 4/9] RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL     --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN     if [ 121   360.7s
 => [builder 5/9] COPY backends backends                                                                                                                                             0.1s 
 => [builder 6/9] COPY core core                                                                                                                                                     0.1s 
 => [builder 7/9] COPY router router                                                                                                                                                 0.1s 
 => [builder 8/9] COPY Cargo.toml ./                                                                                                                                                 0.1s 
 => [builder 9/9] COPY Cargo.lock ./                                                                                                                                                 0.1s 
 => [http-builder 1/1] RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL     --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN     if [  423.1s 
 => [stage-7 1/1] COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router                                                      0.4s 
 => exporting to image                                                                                                                                                               1.5s 
 => => exporting layers                                                                                                                                                              1.4s 
 => => writing image sha256:2018875deaebfac387abad481f0f2bb7979853ad2b607297aa8bdba5b1d67ef4                                                                                         0.0s 
 => => naming to docker.io/library/text-embeddings-inference:121-1.9-pr

So definitely working on a Spark 🥳

@nazq
Copy link
Copy Markdown
Contributor Author

nazq commented Mar 16, 2026

I'll put my order in then ;-)

@JCorners68
Copy link
Copy Markdown

JCorners68 commented Mar 29, 2026

Independent DGX Spark Validation

hardware: DGX Spark (spark-97dd), NVIDIA GB10
arch: aarch64
compute_cap: 12.1 (sm_121)
cuda_driver: 590.48.01
cuda_version: 13.1
os: Ubuntu 24.04.4 LTS (Noble Numbat)
kernel: 6.17.0-1008-nvidia

Build

branch: feat/arm64-cuda-blackwell (ad55ed2)
base: f016879 (1 commit behind main — only v1.9.3 version bump missing, no conflicts)
cmd: docker build . -f Dockerfile-cuda --build-arg CUDA_COMPUTE_CAP=121 -t tei:121-pr840-test
build_time: ~2300s total wall clock (994.8s final Rust compilation, 0% sccache hit)
result: SUCCESS
image_sha: d4c11dd2776d

Smoke Tests

test_1_model: BAAI/bge-small-en-v1.5
test_1_flash_attn: ON (default)
test_1_backend: FlashBert
test_1_endpoint: /embed
test_1_result: PASS (384-dim float array, HTTP 200)
test_1_latency: 0.104s

test_2_model: BAAI/bge-small-en-v1.5
test_2_flash_attn: OFF (USE_FLASH_ATTENTION=False)
test_2_backend: Bert (standard, non-flash)
test_2_endpoint: /embed
test_2_result: PASS (384-dim float array, HTTP 200)
test_2_latency: 0.108s

test_3_model: BAAI/bge-reranker-base
test_3_flash_attn: ON (default)
test_3_backend: FlashBert
test_3_endpoint: /rerank
test_3_result: PASS (correct ranking — "deep learning" scored 0.996 vs "weather" at 0.00004)
test_3_latency: 0.135s
test_3_gpu_mem: 827 MiB

Key Finding

Flash attention works on sm_121 (GB10) with this PR. This is an improvement
over @gpadiolleau's original workaround in #769 which required USE_FLASH_ATTENTION=false.
The flash_attn.rs change adding runtime_compute_cap == 121 is correct and validated.

Notes

  • Second independent validation after @stefan-it's successful build (895s, Mar 15)
  • All three tests used the PR's default settings — no workarounds needed
  • PR is 1 commit behind main (v1.9.3 version bump only), clean rebase expected
  • Previous review by @alvarobartt was auto-dismissed after rebase force-push

Tested by: JC (jonathan.corners@voxell.ai) on DGX Spark (spark-97dd)

@alvarobartt alvarobartt added this to the v1.10.0 milestone Mar 30, 2026
alvarobartt
alvarobartt previously approved these changes Mar 30, 2026
Copy link
Copy Markdown
Member

@alvarobartt alvarobartt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks again @nazq 🙏🏻

Left some minor comments, happy to merge afterwards!

Comment thread docs/source/en/supported_models.md Outdated
Comment on lines +93 to +94
| Blackwell 12.1 (DGX Spark GB10, ...) | ghcr.io/huggingface/text-embeddings-inference:121-1.9 (experimental) |
| CPU (ARM64 / aarch64) | ghcr.io/huggingface/text-embeddings-inference:cpu-arm64-1.9 |
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please align this table with the same table in the README.md, using it as reference.

Comment thread README.md
| Architecture | Platform | Image |
|----------------------------------------|----------|-------------------------------------------------------------------------|
| CPU | x86_64 | ghcr.io/huggingface/text-embeddings-inference:cpu-1.9 |
| CPU | aarch64 | ghcr.io/huggingface/text-embeddings-inference:cpu-arm64-1.9 |
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add (experimental) here too despite already being validated, at least until we run aarch64 for a couple of releases?

Comment thread README.md
Comment on lines +624 to +625
For ARM64 hosts without NVIDIA GPUs, use the CPU Dockerfile. Inference runs on CPU cores
only (no Metal/MPS support via Docker).
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
For ARM64 hosts without NVIDIA GPUs, use the CPU Dockerfile. Inference runs on CPU cores
only (no Metal/MPS support via Docker).
For ARM64 hosts without NVIDIA GPUs such as Apple Silicon, use the `Dockerfile` for CPU,
where inference will run without any accelerator, as Metal / MPS is not supported via Docker.

Comment thread README.md
Comment on lines +633 to +634
For ARM64 hosts with NVIDIA GPUs, build `Dockerfile-cuda` with the appropriate compute
capability and `--platform linux/arm64`:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
For ARM64 hosts with NVIDIA GPUs, build `Dockerfile-cuda` with the appropriate compute
capability and `--platform linux/arm64`:
For ARM64 hosts with NVIDIA GPUs, use / build the `Dockerfile-cuda` with `--platform linux/arm64`,
and also with the `--build-arg CUDA_COMPUTE_CAP` set to whatever your instance compute capability is (only required when building the image).

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

- Add Dockerfile-cuda supporting both x86_64 and ARM64 (aarch64)
- Add sm_121 compute capability for NVIDIA GB10 (DGX Spark)
- Add cpu-arm64 image variant
- Update supported hardware documentation

Co-Authored-By: z4y4ts <z4y4ts@users.noreply.github.com>
@nazq
Copy link
Copy Markdown
Contributor Author

nazq commented Mar 30, 2026

Hey @alvarobartt — rebased onto main and aligned the supported_models.md table with the README (added Platform column, matched row order). Two independent Spark validations have confirmed the build and flash attention on sm_121:

  • @stefan-it: successful build (895s)
  • @JCorners68: full validation — build + 3 smoke tests (embed, embed no-flash, rerank) all passing

Ready for your re-review when you get a chance.

@alvarobartt alvarobartt merged commit 2e690c2 into huggingface:main Mar 31, 2026
3 of 18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ARM64 Support

5 participants