Skip to content

Fix OOM in CI by reducing batch size in VLM SFT tests#5687

Merged
albertvillanova merged 1 commit into
mainfrom
pfix-5207-per_device_train_batch_size
Apr 30, 2026
Merged

Fix OOM in CI by reducing batch size in VLM SFT tests#5687
albertvillanova merged 1 commit into
mainfrom
pfix-5207-per_device_train_batch_size

Conversation

@albertvillanova

@albertvillanova albertvillanova commented Apr 30, 2026

Copy link
Copy Markdown
Member

Fix OOM in CI by reducing batch size in VLM SFT tests.

Partial fix for:

Motivation

VLM training tests in test_sft_trainer.py were running with the default per_device_train_batch_size=8. For Gemma3, with vocab_size=262208 (production-scale, never reduced for tiny models) and mm_tokens_per_image=256, each training step computes logits of shape [8, 279, 262208].

PyTorch needs several float32 copies of this tensor for log-softmax and its gradient, pushing peak GPU memory to ~9 GiB per worker. With 4 parallel pytest-xdist workers this caused CUDA out-of-memory errors for other concurrent tests.

Solution

Set per_device_train_batch_size=1 in test_train_vlm, test_train_vlm_multi_image, and test_train_vlm_prompt_completion, following the pattern already used in test_train_vlm_gemma_3n. This drops peak GPU memory to ~1.1 GiB per worker for Gemma3, leaving ample headroom for parallel execution.


Note

Low Risk
Test-only change that lowers batch size to avoid CI OOMs; no production code paths are modified.

Overview
Reduces GPU memory pressure in vision-language SFT integration tests by explicitly setting per_device_train_batch_size=1 in test_train_vlm, test_train_vlm_multi_image, and test_train_vlm_prompt_completion.

This prevents CI CUDA OOMs during parallel test execution while keeping the VLM-specific max_length=None behavior unchanged.

Reviewed by Cursor Bugbot for commit 0480d77. Bugbot is set up for automated code reviews on this repo. Configure here.

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec

Copy link
Copy Markdown
Member

It seems to work!

@albertvillanova albertvillanova merged commit 32bec88 into main Apr 30, 2026
13 of 14 checks passed
@albertvillanova albertvillanova deleted the pfix-5207-per_device_train_batch_size branch April 30, 2026 13:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants