[core] Faster and thread-safe check_model_inputs implementation #43765
Cyrilvallez merged 17 commits into main
ArthurZucker
left a comment
Nice, we have a big decision to make here tho.
One thing that I would love to reach is having 0 hooks / a normal forward when we are not recording to avoid having the unreadable stacktrace.
Other than that, let's try to reduce the overhead for inference-heavy workloads that never use capturing.
Currently, every hooked module calls `_active_keys.get()` on every forward pass, even when capturing is disabled. A `ContextVar` lookup is probably not free.
We can add a fast-path guard using a simple global ref count of the active recorders.
"CLIPTextTransformer",  # was not a PreTrainedModel originally but needs to be
"CLIPVisionTransformer",  # was not a PreTrainedModel originally but needs to be
"MetaClip2TextTransformer",  # was not a PreTrainedModel originally but needs to be
"MetaClip2VisionTransformer",  # was not a PreTrainedModel originally but needs to be
"MLCDVisionTransformer",  # was not a PreTrainedModel originally but needs to be
why do they need to be?
Due to the super weird/bad structure of the other composite models using them...
Basically, the already-public PreTrainedModels were just very thin wrappers around them doing nothing more, so the composite models use the following in `__init__`:
text_model = CLIPTextModel._from_config(text_config)
self.text_model = text_model.text_model  # <-- this child is `CLIPTextTransformer`
so they use the submodel of the public wrapper CLIPTextModel (which is CLIPTextTransformer) directly. However, since CLIPTextTransformer needs to capture outputs with its own `_can_record_outputs`, it needs to be a PreTrainedModel. I have absolutely no clue why CLIPTextModel would even exist originally, since other models use it to instantiate but then bypass it by using the child...
cc @zucchini-nlp @molbap this is known on the vision side, it's super awkward and is why it needs a refactor
yep, a known issue and easily fixable by re-ordering modules. @yonigozlan had a plan to work on it :)
Yes I agree! I woke up this morning with basically the same idea: I want to defer hook installation, as most of the time we don't want any output recording - I believe we can do so with simple locks to avoid installing twice. As for the
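A minimal sketch of that deferred installation with a double-checked lock (a toy illustration, assuming a made-up `Capturer` class standing in for the model, and a counter in place of actually registering hooks):

```python
# Hypothetical sketch: hooks are installed lazily, on the first capturing call,
# and the double-checked lock guarantees they are never installed twice even
# when several threads start capturing at the same time.
import threading

class Capturer:
    """Toy stand-in for a model whose capture hooks must install at most once."""
    def __init__(self):
        self._hooks_installed = False
        self._install_lock = threading.Lock()
        self.install_count = 0  # side effect we want to happen exactly once

    def _install_hooks(self):
        # Real code would walk the module tree and call register_forward_hook.
        self.install_count += 1

    def ensure_hooks(self):
        if self._hooks_installed:           # fast path: no lock taken
            return
        with self._install_lock:            # slow path: serialized
            if self._hooks_installed:       # another thread won the race
                return
            self._install_hooks()
            self._hooks_installed = True
```

The cheap unlocked check keeps the common (already-installed) path lock-free; the second check inside the lock is what prevents double installation.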
Force-pushed from 8a66f19 to a8b8611
Actually I was wrong, and it's not the hooks themselves that were slowing down the forward - they basically have no visible cost. Anyway, I benchmarked everything, and this PR makes everything slightly faster in both cases.
[For maintainers] Suggested jobs to run (before merge): run-slow: aria, aya_vision, blip_2, blt, clip, cohere2_vision, conditional_detr, deformable_detr, depth_pro, doge, edgetam, edgetam_video, ernie4_5_moe, ernie4_5_vl_moe, esm, evolla
ArthurZucker
left a comment
Should we have pytest-benchmark on tests specific to this?
…ggingface#43765)

* poc
* cleaner
* style
* forgot name
* fix all
* more fixes
* oupsi
* fix
* use only 1 contextvar
* lazy hook installation
* capture with lists instead of tuples
* more bad-behaved model fixes
* style
* move to dedicated file
* doc
* style
@Cyrilvallez Was this fix backported into the v4 releases of Transformers? We aren't ready to migrate to v5 for quite a while, I suspect (OneTrainer, I help with Dxzqb). It would be a big boon for us if it was, but I can understand why it may not have been.
It was not, but we can push for that if required! (gonna be a bit hard)
Yep, not sure how worth it it is, and it's gonna be a mess to pick that commit back to v4
@O-J1 what can we do to help transition to v5?
At the moment not much 🤣 we are swamped with refactors, PyTorch bugs and functionality additions. Once we have the majority of our big-ticket PR items merged, I think there will be more headspace for us to transition. Based on what you have both communicated, I don't wanna make your codebase a mess, so let's ignore my request. We will bite the bullet eventually
What does this PR do?
The problem
Currently, `check_model_inputs` needs to iterate over all modules and monkey-patch every needed submodule's `forward` on-the-fly, before restoring them afterwards. This brings 2 big issues:

- every `forward` needs to traverse the whole model graph to monkey-patch every module, which is slow
- the patching is not thread-safe, since the patched methods are shared between threads

Note that thread-safety is an important aspect for most simple servers, i.e. gradio spaces etc., as the model then runs on different threads. So even if the `output_xxx` kwargs are somewhat niche, I believe it's an important issue that we should not overlook.

The proposal

We can instead use forward hooks, and only activate them when needed. Coupled with `ContextVar`s for thread-safety on the `collected_outputs`, this brings a much safer, cleaner and more efficient implementation. The idea is the following:

- only activate the hooks when one of the `output_xxx` kwargs is requested (instead of monkey-patching `forward`, of course)
- use a `ContextVar` inside the hook, so the collected outputs are thread-safe (`ContextVar` basically behaves as a thread-safe global variable, and it is blazingly fast)
- since `torch.compile` does not support tracing the `get` of a `ContextVar`, use a simple trick with a plain global variable in those cases. This unfortunately means that `torch.compile` + `return_xxx` is not thread-safe, but it works on a single thread

TLDR - What it solves
Reproduction code
memory_leak.py
The above script shows the issue mentioned and the associated memory leak. On current main, memory keeps increasing without bounds over time. With this PR, it is solved and memory is stable.
Moreover, the outputs gathered are now correct, compared to being fully random before.
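For illustration, the race underlying the wrong outputs can be reduced to a toy (this is a hypothetical sketch, not the collapsed `memory_leak.py` script; the interleaving is forced with events to make it deterministic):

```python
# Toy demonstration of why on-the-fly monkey-patching is racy: thread B's plain
# forward runs while thread A has patched the shared method, so B's output
# unintentionally leaks into A's collected list.
import threading

class Model:
    def forward(self, x):
        return x * 2

model = Model()  # one model instance shared by every request thread
a_outputs = []
ready, done = threading.Event(), threading.Event()

def thread_a():
    original = Model.forward
    def patched(self, x):         # the patch is class-wide, so B sees it too
        out = original(self, x)
        a_outputs.append(out)     # meant to collect only A's outputs
        return out
    Model.forward = patched
    ready.set()                   # let B run while the patch is live
    done.wait()
    Model.forward = original      # restore afterwards, as on main

def thread_b():
    ready.wait()
    model.forward(21)             # B never asked for any recording...
    done.set()

ta, tb = threading.Thread(target=thread_a), threading.Thread(target=thread_b)
ta.start(); tb.start()
ta.join(); tb.join()
# ...yet A's list now contains B's output: a_outputs == [42]
```

The hook + `ContextVar` design avoids this entirely, since nothing shared between threads is ever mutated.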
benchmark_speed.py
The above script was used to benchmark the speed of both approaches, with and without `output_xxx`. Results on 1 small forward of Llama3 8B (input size (2, 56)):
So this PR is slightly faster in both cases.
Note
Some models did not use the decorator correctly, or used it very redundantly; that's why I had to perform some minimal changes on a few.