Fix multi-threaded dataloader for Qwen3/Mistral text encoders #1346
Merged
dxqb merged 1 commit into Nerogar:merge on Mar 1, 2026
Conversation
Contributor (Author): Tested on Z-Image. Using 12 threads reduced caching of 110k files from 3.5 hours to 50 minutes.
Description
Enables `dataloader_threads > 1` for Z-Image and Flux2.Klein models by working around a thread-safety bug in the transformers library's `check_model_inputs` decorator (huggingface/transformers#42673).

Closes #1291
Problem
The `check_model_inputs` decorator in transformers v4 monkey-patches child module `.forward()` methods on every call to capture `output_hidden_states`, then restores them afterwards. When two dataloader threads call the same text encoder concurrently, they race on patching/restoring these methods, causing hidden states from different threads to bleed into each other.

Fix
Wraps the text encoder's `.forward()` with a per-instance `threading.Lock` to serialize concurrent calls, preventing the race condition. The lock is applied conditionally, only when `dataloader_threads > 1`, and is idempotent (safe if called multiple times).

Performance impact is negligible since GPU computation is already serialized on a single device. The benefit of multiple dataloader threads (pipelining CPU image loading/preprocessing against GPU encoding) is preserved.
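A minimal sketch of what such a per-instance lock wrapper can look like. The function name `apply_thread_safe_forward` matches the utility named in this PR, but the body below is illustrative, not the repository's actual implementation; the `ToyEncoder` stand-in and its concurrency counter are purely for demonstration:

```python
import threading
import time

def apply_thread_safe_forward(model):
    """Wrap model.forward with a per-instance lock (sketch only;
    the real utility in modules/util/thread_safety.py may differ)."""
    if getattr(model, "_thread_safe_forward", False):
        return  # idempotent: never double-wrap
    lock = threading.Lock()
    original_forward = model.forward
    def locked_forward(*args, **kwargs):
        with lock:  # serialize all concurrent callers
            return original_forward(*args, **kwargs)
    model.forward = locked_forward
    model._thread_safe_forward = True

class ToyEncoder:
    """Stand-in for the text encoder; counts concurrent forward() calls."""
    def __init__(self):
        self.active = 0
        self.max_concurrent = 0
    def forward(self, x):
        self.active += 1
        self.max_concurrent = max(self.max_concurrent, self.active)
        time.sleep(0.001)  # widen the race window
        self.active -= 1
        return x

encoder = ToyEncoder()
apply_thread_safe_forward(encoder)
apply_thread_safe_forward(encoder)  # second call is a no-op

threads = [threading.Thread(target=encoder.forward, args=(i,)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(encoder.max_concurrent)  # 1: calls were fully serialized
```

Without the wrapper, `max_concurrent` typically exceeds 1 under load, which is exactly the window in which `check_model_inputs` patch/restore calls can interleave.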
Also proactively applies the same fix to the Flux2.Dev (Mistral) path, which has the same underlying vulnerability via `MistralModel.forward()`.

The upstream fix (huggingface/transformers#43765) shipped in transformers v5 only. This workaround can be removed when upgrading to v5+.
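Since the upstream fix ships only in transformers v5, the removal condition boils down to a major-version check. A sketch of one way to gate the workaround (the helper name and version-string parsing here are assumptions, not code from this PR):

```python
def needs_forward_lock_workaround(transformers_version: str) -> bool:
    # check_model_inputs was made thread-safe upstream in transformers v5
    # (huggingface/transformers#43765); only v4 needs the per-instance lock.
    major = int(transformers_version.split(".", 1)[0])
    return major < 5

print(needs_forward_lock_workaround("4.57.1"))  # True
print(needs_forward_lock_workaround("5.0.0"))   # False
```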
Changes
- `modules/util/thread_safety.py`: new utility; `apply_thread_safe_forward()` wraps a model's forward with a per-instance lock
- `modules/dataLoader/ZImageBaseDataLoader.py`: replace `NotImplementedError` with the thread-safe forward patch
- `modules/dataLoader/Flux2BaseDataLoader.py`: replace `NotImplementedError` (Klein) and add the proactive fix (Dev)

Testing Notes
Verified the bug and fix with a tiny Qwen3ForCausalLM (CPU, random weights, 280K params).
Tested on Windows 11, Python 3.10.11.
Will run a full test on Z-Image either today or tomorrow.
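For readers without a GPU setup, the hidden-state bleed can also be simulated in pure Python by interleaving two capture patches the way `check_model_inputs` does under concurrency. This is toy code, not the actual transformers implementation; `TinyModule` and `capture_patch` are invented names for illustration:

```python
class TinyModule:
    """Stand-in for a child module of the text encoder."""
    def forward(self, x):
        return x * 2

def capture_patch(forward, collected):
    # Mimics check_model_inputs: wrap forward to record its output.
    def patched(x):
        out = forward(x)
        collected.append(out)  # "hidden state" meant for ONE caller
        return out
    return patched

m = TinyModule()
collect_a, collect_b = [], []

# Thread A patches first; thread B patches before A has restored,
# so B wraps A's already-patched forward instead of the original.
m.forward = capture_patch(m.forward, collect_a)
m.forward = capture_patch(m.forward, collect_b)

m.forward(3)  # a call made by "thread B"
print(collect_a, collect_b)  # [6] [6] -- B's output leaked into A's capture
```

The same interleaving under real threads is nondeterministic, which is why the PR serializes all callers rather than trying to make the patch/restore dance atomic.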