
Commit 09eac16

[aoti-et] Enable multimodal runner for Voxtral on CUDA (#14980)
This pull request introduces changes to the CUDA workflow, model artifact handling, and multimodal runner logic. The main changes include restructuring the GitHub Actions workflow to separate model export, benchmarking, and end-to-end testing for the Voxtral CUDA pipeline, improving artifact management and reproducibility. Additionally, the multimodal runner now supports automatic conversion of audio tensors to bfloat16, ensuring compatibility with expected input types. There are also enhancements to caching and symbol registration in the CUDA backend, and build system updates to support linking the CUDA backend.

**Workflow and Artifact Management Improvements:**

* Refactored `.github/workflows/cuda.yml` to split the Voxtral CUDA pipeline into three jobs: `export-voxtral-cuda-artifact` (exports and stores model artifacts), `benchmark-voxtral-cuda` (benchmarks using the exported artifacts), and `test-voxtral-cuda-e2e` (runs a full end-to-end test with artifact download and audio input). Improved artifact handling and reproducibility, and added explicit checks for required files.

**Multimodal Runner Logic:**

* Added automatic conversion of audio tensors to bfloat16 in `MultimodalPrefiller::prefill` and implemented a helper function `convert_to_bfloat16` in `util.h` to support this. This ensures that audio inputs match the dtype expected by the encoder, improving robustness for multimodal inference.

**CUDA Backend and Caching Enhancements:**

* Improved caching logic in `common_shims.cpp` for tensor strides and sizes by validating cached values and updating them when necessary. This prevents stale-cache issues and ensures correct tensor metadata.
* Added dynamic symbol re-registration in `CudaBackend` to handle multiple shared objects in the same process, ensuring correct execution when switching between models.
* Removed redundant logging statements in the CUDA backend for cleaner output.

**Build System Updates:**

* Updated `CMakeLists.txt` and `executorch-config.cmake` to include and link the CUDA backend (`aoti_cuda`) when building Voxtral and other components, improving build flexibility and CUDA support.

**Debugging and Tuning Options:**

* Added support for enabling debug compilation in `cuda_backend.py` via the `DEBUG` environment variable, allowing easier troubleshooting and development.
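The `util.h` and `MultimodalPrefiller::prefill` hunks are not among the file diffs shown below. As a rough, standalone sketch of the kind of float-to-bfloat16 conversion described above (not the actual ExecuTorch helper; the function names and buffer types here are invented for illustration):

```cpp
// Illustrative only: a standalone sketch of float -> bfloat16 conversion,
// NOT the helper added in util.h (that diff is not reproduced on this page).
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// bfloat16 keeps the upper 16 bits of an IEEE-754 float (sign, exponent,
// and the top 7 mantissa bits), so conversion is a truncating shift with
// round-to-nearest-even applied first.
static uint16_t float_to_bf16(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  uint32_t rounding_bias = 0x7FFF + ((bits >> 16) & 1);  // round to nearest even
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}

// Hypothetical bulk conversion over an audio feature buffer.
static std::vector<uint16_t> convert_to_bfloat16(const std::vector<float>& src) {
  std::vector<uint16_t> dst(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    dst[i] = float_to_bf16(src[i]);
  }
  return dst;
}

int main() {
  std::vector<float> mel = {0.0f, 1.0f, -2.5f, 3.14159f};
  for (uint16_t v : convert_to_bfloat16(mel)) {
    std::cout << std::hex << v << '\n';  // raw bf16 bit patterns
  }
  return 0;
}
```

bfloat16 preserves the float32 sign and exponent, so a conversion like this keeps the dynamic range of the mel-spectrogram features while matching the bfloat16 encoder inputs the commit message describes.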
1 parent 7533df6 commit 09eac16

File tree

11 files changed (+374, -30 lines)


.github/workflows/cuda.yml

Lines changed: 130 additions & 20 deletions
@@ -87,8 +87,8 @@ jobs:
         export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
         PYTHON_EXECUTABLE=python source .ci/scripts/test_model.sh "${{ matrix.model }}" cmake cuda
 
-  test-voxtral-cuda-e2e:
-    name: test-voxtral-cuda-e2e
+  export-voxtral-cuda-artifact:
+    name: export-voxtral-cuda-artifact
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     permissions:
       id-token: write
@@ -104,6 +104,7 @@
       gpu-arch-version: 12.6
       use-custom-docker-registry: false
       submodules: recursive
+      upload-artifact: voxtral-cuda-export
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       script: |
         set -eux
@@ -118,6 +119,7 @@
         OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt)
         pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}
         pip install mistral-common librosa
+        pip list
         echo "::endgroup::"
 
         echo "::group::Export Voxtral"
@@ -129,9 +131,58 @@
           --device cuda \
           --max_seq_len 1024 \
           --output_dir ./
+        python -m executorch.extension.audio.mel_spectrogram \
+          --feature_size 128 \
+          --stack_output \
+          --max_audio_len 300 \
+          --output_file voxtral_preprocessor.pte
+
+        test -f model.pte
+        test -f aoti_cuda_blob.ptd
+        test -f voxtral_preprocessor.pte
         echo "::endgroup::"
 
-        echo "::group::Build Voxtral Runner"
+        echo "::group::Store Voxtral Artifacts"
+        mkdir -p "${RUNNER_ARTIFACT_DIR}"
+        cp model.pte "${RUNNER_ARTIFACT_DIR}/"
+        cp aoti_cuda_blob.ptd "${RUNNER_ARTIFACT_DIR}/"
+        cp voxtral_preprocessor.pte "${RUNNER_ARTIFACT_DIR}/"
+        ls -al "${RUNNER_ARTIFACT_DIR}"
+        echo "::endgroup::"
+
+  benchmark-voxtral-cuda:
+    name: benchmark-voxtral-cuda
+    needs: export-voxtral-cuda-artifact
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
+    strategy:
+      fail-fast: false
+    with:
+      timeout: 90
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: 12.6
+      use-custom-docker-registry: false
+      submodules: recursive
+      download-artifact: voxtral-cuda-export
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      script: |
+        set -eux
+
+        echo "::group::Setup ExecuTorch Requirements"
+        CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_requirements.sh
+        pip list
+        echo "::endgroup::"
+
+        echo "::group::Prepare Voxtral Artifacts"
+        cp "${RUNNER_ARTIFACT_DIR}/model.pte" .
+        cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" .
+        ls -al model.pte aoti_cuda_blob.ptd
+        echo "::endgroup::"
+
+        echo "::group::Build Voxtral Benchmark"
         cmake -DCMAKE_BUILD_TYPE=Release \
           -DEXECUTORCH_BUILD_CUDA=ON \
           -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
@@ -142,31 +193,90 @@
         cmake --build cmake-out -j$(( $(nproc) - 1 )) --target voxtral_runner
         echo "::endgroup::"
 
+        echo "::group::Run Voxtral Benchmark"
+
+        export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
+        cmake-out/backends/cuda/voxtral_runner model.pte aoti_cuda_blob.ptd
+
+        echo "::endgroup::"
+
+  test-voxtral-cuda-e2e:
+    name: test-voxtral-cuda-e2e
+    needs: export-voxtral-cuda-artifact
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
+    strategy:
+      fail-fast: false
+    with:
+      timeout: 90
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: 12.6
+      use-custom-docker-registry: false
+      submodules: recursive
+      download-artifact: voxtral-cuda-export
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      script: |
+        set -eux
+
+        echo "::group::Setup ExecuTorch Requirements"
+        CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_requirements.sh
+        pip list
+        echo "::endgroup::"
+
+        echo "::group::Prepare Voxtral Artifacts"
+        cp "${RUNNER_ARTIFACT_DIR}/model.pte" .
+        cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" .
+        cp "${RUNNER_ARTIFACT_DIR}/voxtral_preprocessor.pte" .
+        TOKENIZER_URL="https://huggingface.co/mistralai/Voxtral-Mini-3B-2507/resolve/main/tekken.json"
+        curl -L $TOKENIZER_URL -o tekken.json
+        ls -al model.pte aoti_cuda_blob.ptd voxtral_preprocessor.pte tekken.json
+        echo "::endgroup::"
+
+        echo "::group::Download Test Audio File"
+        AUDIO_URL="https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav"
+        curl -L $AUDIO_URL -o poem.wav
+        echo "::endgroup::"
+
+        echo "::group::Build Voxtral Runner"
+        cmake --preset llm \
+          -DEXECUTORCH_BUILD_CUDA=ON \
+          -DCMAKE_INSTALL_PREFIX=cmake-out \
+          -DCMAKE_BUILD_TYPE=Release \
+          -Bcmake-out -S.
+        cmake --build cmake-out -j$(( $(nproc) - 1 )) --target install --config Release
+
+        cmake -DEXECUTORCH_BUILD_CUDA=ON \
+          -DCMAKE_BUILD_TYPE=Release \
+          -Sexamples/models/voxtral \
+          -Bcmake-out/examples/models/voxtral/
+        cmake --build cmake-out/examples/models/voxtral --target voxtral_runner --config Release
+        echo "::endgroup::"
+
         echo "::group::Run Voxtral Runner"
-        # Capture output and allow exit code 139 if we have the expected printout
         set +e
         export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
-        OUTPUT=$(cmake-out/backends/cuda/voxtral_runner model.pte aoti_cuda_blob.ptd 2>&1)
+        OUTPUT=$(cmake-out/examples/models/voxtral/voxtral_runner \
+          --model_path model.pte \
+          --data_path aoti_cuda_blob.ptd \
+          --tokenizer_path tekken.json \
+          --audio_path poem.wav \
+          --processor_path voxtral_preprocessor.pte \
+          --temperature 0 2>&1)
         EXIT_CODE=$?
         set -e
 
         echo "$OUTPUT"
 
-        # Check if the output contains "Run latency (ms):"
-        if echo "$OUTPUT" | grep -q "Run latency (ms):"; then
-          echo "Found expected output: 'Run latency (ms):'"
-          if [ $EXIT_CODE -eq 139 ]; then
-            echo "Exit code 139 (segfault) detected, but passing since we have the expected output"
-            exit 0
-          elif [ $EXIT_CODE -ne 0 ]; then
-            echo "Unexpected exit code: $EXIT_CODE"
-            exit $EXIT_CODE
-          else
-            echo "Command succeeded with exit code 0"
-            exit 0
-          fi
-        else
-          echo "Expected output 'Run latency (ms):' not found in output"
+        if ! echo "$OUTPUT" | grep -iq "poem"; then
+          echo "Expected output 'poem' not found in output"
           exit 1
         fi
+
+        if [ $EXIT_CODE -ne 0 ]; then
+          echo "Unexpected exit code: $EXIT_CODE"
+          exit $EXIT_CODE
+        fi
         echo "::endgroup::"

backends/aoti/common_shims.cpp

Lines changed: 39 additions & 2 deletions
@@ -51,13 +51,32 @@ AOTITorchError aoti_torch_get_storage_offset(
 
 AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
   auto it = internal::tensor_to_strides.find(tensor);
+  bool needs_update = false;
+
   if (it == internal::tensor_to_strides.end()) {
+    needs_update = true;
+  } else {
+    // CRITICAL: Multimodal models reuse tensors with different shapes across
+    // executions (e.g., variable-length audio). We MUST validate cached
+    // metadata matches current tensor state, or CUDA kernels will receive
+    // incorrect shapes leading to memory corruption and segfaults.
+    auto tensor_strides = tensor->strides();
+    needs_update = !std::equal(
+        it->second.begin(),
+        it->second.end(),
+        tensor_strides.begin(),
+        tensor_strides.end());
+  }
+
+  if (needs_update) {
     std::vector<int64_t> strides(tensor->dim());
     auto tensor_strides = tensor->strides();
     for (int i = 0; i < tensor->dim(); i++) {
       strides[i] = tensor_strides[i];
     }
-    it = internal::tensor_to_strides.emplace(tensor, std::move(strides)).first;
+    it =
+        internal::tensor_to_strides.insert_or_assign(tensor, std::move(strides))
+            .first;
   }
 
   // For 0D tensors, data() returns nullptr on empty vectors, but we need to
@@ -80,13 +99,31 @@ AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
 
 AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
   auto it = internal::tensor_to_sizes.find(tensor);
+  bool needs_update = false;
+
   if (it == internal::tensor_to_sizes.end()) {
+    needs_update = true;
+  } else {
+    // CRITICAL: Multimodal models reuse tensors with different shapes across
+    // executions (e.g., variable-length audio). We MUST validate cached
+    // metadata matches current tensor state, or CUDA kernels will receive
+    // incorrect shapes leading to memory corruption and segfaults.
+    auto tensor_sizes = tensor->sizes();
+    needs_update = !std::equal(
+        it->second.begin(),
+        it->second.end(),
+        tensor_sizes.begin(),
+        tensor_sizes.end());
+  }
+
+  if (needs_update) {
     std::vector<int64_t> sizes(tensor->dim());
     auto tensor_sizes = tensor->sizes();
     for (int i = 0; i < tensor->dim(); i++) {
       sizes[i] = tensor_sizes[i];
     }
-    it = internal::tensor_to_sizes.emplace(tensor, std::move(sizes)).first;
+    it = internal::tensor_to_sizes.insert_or_assign(tensor, std::move(sizes))
+             .first;
   }
 
   // For 0D tensors, data() returns nullptr on empty vectors, but we need to
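
For readers skimming the hunk above, the refresh-on-mismatch idea (compare the cached vector with the live tensor metadata and call `insert_or_assign` when they diverge) can be shown in isolation. This is a standalone toy with an invented `FakeTensor` type and cache map, not ExecuTorch code:

```cpp
// Standalone illustration of the refresh-on-mismatch caching pattern;
// FakeTensor and size_cache are hypothetical stand-ins, not ExecuTorch APIs.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

struct FakeTensor {
  std::vector<int64_t> sizes;  // may change between executions
};

std::unordered_map<const FakeTensor*, std::vector<int64_t>> size_cache;

const std::vector<int64_t>& cached_sizes(const FakeTensor& t) {
  auto it = size_cache.find(&t);
  // Refresh when the entry is missing or no longer matches the live tensor,
  // mirroring the needs_update check in aoti_torch_get_sizes above.
  bool needs_update = (it == size_cache.end()) ||
      !std::equal(it->second.begin(), it->second.end(),
                  t.sizes.begin(), t.sizes.end());
  if (needs_update) {
    it = size_cache.insert_or_assign(&t, t.sizes).first;
  }
  return it->second;
}

int main() {
  FakeTensor audio{{1, 128, 3000}};
  std::cout << cached_sizes(audio)[2] << '\n';  // 3000
  audio.sizes = {1, 128, 1500};                 // tensor reused with a new shape
  std::cout << cached_sizes(audio)[2] << '\n';  // 1500, not a stale 3000
  return 0;
}
```

The second lookup returns the refreshed sizes rather than the stale entry, which is the failure mode the commit guards against for variable-length audio inputs.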

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 8 additions & 2 deletions
@@ -165,6 +165,14 @@ class ET_EXPERIMENTAL CudaBackend final
       Span<EValue*> args) const override {
     AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
 
+    // Need to re-register all the symbols from the so_handle hosted by this
+    // CudaBackend instance. The reason is that these symbols are
+    // static/singleton across the whole process. When we share multiple methods
+    // (meaning multiple so_handle) in the same process, we need to re-register
+    // the symbols from the so_handle that is being used in this execution.
+    ET_CHECK_OK_OR_RETURN_ERROR(
+        register_shared_library_functions(handle->so_handle));
+
     size_t n_inputs;
     AOTInductorModelContainerGetNumInputs(handle->container_handle, &n_inputs);
 
@@ -223,7 +231,6 @@
           "Failed to copy input %d from CPU to GPU",
           i);
     }
-    ET_LOG(Info, "Inputs copied to GPU");
     // Process output tensors: create GPU counterparts for ExecuTorch CPU
     // tensors
     for (int i = 0; i < n_outputs; i++) {
@@ -253,7 +260,6 @@
 
       gpu_outputs[i] = gpu_output_handle;
     }
-    ET_LOG(Info, "Outputs created on GPU");
     // Run AOTI container with GPU tensors
     AOTIRuntimeError error = AOTInductorModelContainerRun(
         handle->container_handle,

examples/models/voxtral/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
@@ -86,6 +86,13 @@ list(
   extension_flat_tensor
 )
 
+# Link CUDA backend
+if(EXECUTORCH_BUILD_CUDA)
+  find_package(CUDAToolkit REQUIRED)
+  list(APPEND link_libraries aoti_cuda)
+  executorch_target_link_options_shared_lib(aoti_cuda)
+endif()
+
 # Add tokenizers
 list(APPEND link_libraries tokenizers::tokenizers)
examples/models/voxtral/README.md

Lines changed: 51 additions & 0 deletions
@@ -36,6 +36,29 @@ optimum-cli export executorch \
 
 This exports Voxtral with XNNPack backend acceleration and 4-bit weight/8-bit activation linear quantization.
 
+## CUDA Support
+If your environment has CUDA support, you can enable the runner to run on CUDA for improved performance. Follow the export and runtime commands below:
+
+**Note:** We are currently working on quantization support for CUDA. Currently, only bfloat16 dtype is supported for CUDA execution.
+
+### Exporting with CUDA
+```
+optimum-cli export executorch \
+  --model "mistralai/Voxtral-Mini-3B-2507" \
+  --task "multimodal-text-to-text" \
+  --recipe "cuda" \
+  --dtype bfloat16 \
+  --device cuda \
+  --max_seq_len 1024 \
+  --output_dir="voxtral"
+```
+
+This will generate:
+- `model.pte` - The exported model
+- `aoti_cuda_blob.ptd` - The CUDA kernel blob required for runtime
+
+See the "Building the multimodal runner" section below for instructions on building with CUDA support, and the "Running the model" section for runtime instructions.
+
 # Running the model
 To run the model, we will use the Voxtral runner, which utilizes ExecuTorch's MultiModal runner API.
 The Voxtral runner will do the following things:
@@ -56,6 +79,8 @@ python -m executorch.extension.audio.mel_spectrogram --feature_size 128 --stack_
 ```
 
 ## Building the multimodal runner
+
+### Building for CPU (XNNPack)
 ```
 # Build and install ExecuTorch
 cmake --preset llm -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=cmake-out -DEXECUTORCH_ENABLE_LOGGING=ON && cmake --build cmake-out -j16 --target install --config Release
@@ -64,6 +89,26 @@ cmake --preset llm -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=cmake-out -
 cmake -DCMAKE_INSTALL_PREFIX=cmake-out -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Release -Bcmake-out/examples/models/voxtral examples/models/voxtral && cmake --build cmake-out/examples/models/voxtral -j16 --config Release
 ```
 
+### Building for CUDA
+```
+# Install ExecuTorch with CUDA support
+CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_executorch.sh
+
+# Build the multimodal runner with CUDA
+cmake --preset llm \
+  -DEXECUTORCH_BUILD_CUDA=ON \
+  -DCMAKE_INSTALL_PREFIX=cmake-out \
+  -DCMAKE_BUILD_TYPE=Release \
+  -Bcmake-out -S.
+cmake --build cmake-out -j16 --target install --config Release
+
+cmake -DEXECUTORCH_BUILD_CUDA=ON \
+  -DCMAKE_BUILD_TYPE=Release \
+  -Sexamples/models/voxtral \
+  -Bcmake-out/examples/models/voxtral/
+cmake --build cmake-out/examples/models/voxtral --target voxtral_runner --config Release
+```
+
 ## Running the model
 You can download the `tekken.json` tokenizer from [Voxtral's HuggingFace repo](https://huggingface.co/mistralai/Voxtral-Mini-3B-2507).
 
@@ -88,6 +133,12 @@ If you already have a preprocessed mel spectrogram saved as a `.bin` file, you c
   --audio_path path/to/preprocessed_audio.bin
 ```
 
+
+**For CUDA:** Add the `--data_path` argument to provide the CUDA kernel blob to the commands above:
+```
+--data_path path/to/aoti_cuda_blob.ptd
+```
+
 Example output:
 ```
 The speaker in this audio seems to be talking about their concerns about a device called the model or maybe they're just talking about the model in general. They mention that the model was trained with the speaker for inference, which suggests that
