
[core] Faster and thread-safe check_model_inputs implementation#43765

Merged
Cyrilvallez merged 17 commits into main from hooks-instead-of-monkey-patching
Feb 6, 2026

Conversation

Member

@Cyrilvallez Cyrilvallez commented Feb 5, 2026

What does this PR do?

The problem

Currently, check_model_inputs needs to iterate over all modules and monkey-patch every needed submodule's forward on the fly, then restore them afterwards. This brings two big issues:

  • It's NOT thread-safe (see also issue Qwen3ForCausalLM leaks VRAM if used in multiple dataloader threads #42673): due to the monkey-patching, two threads running the model's forward concurrently will each try to monkey-patch the global state, resulting in undefined behavior and memory leaks in the collected outputs
  • It's not efficient at all: every call to the model's forward needs to traverse the whole model graph to monkey-patch every module

Note that thread-safety is an important aspect for most simple servers (e.g. Gradio Spaces), as the model then runs on different threads. So even if the output_xxx kwargs are somewhat niche, I believe this is an important issue that we should not overlook.
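The race is easy to see in a minimal, self-contained sketch (plain Python, not the actual transformers code): both threads mutate the same `forward` attribute and append to the same collector, so a restore can clobber the other thread's patch and the captured outputs get mixed up.

```python
import threading

class Module:
    """Stand-in for an nn.Module; `forward` is what gets monkey-patched."""
    def forward(self, x):
        return x + 1

module = Module()   # shared across threads, like a loaded model
captured = []       # global collector: appends from both threads interleave

def patched_forward(x):
    out = Module.forward(module, x)
    captured.append(out)          # no way to tell which thread this belongs to
    return out

def run(x):
    original = module.forward             # may already be the patched version!
    module.forward = patched_forward      # mutates shared state
    module.forward(x)
    module.forward = original             # may clobber the other thread's patch

threads = [threading.Thread(target=run, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
# `captured` now interleaves both threads' outputs, and depending on the
# interleaving a call may even run unpatched (and thus never be captured).
```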

The proposal

We can instead use forward hooks, and only activate them when needed. Coupled with ContextVars for thread-safety of the collected_outputs, this yields a much safer, cleaner, and more efficient implementation. The idea is the following:

  • Install forward hooks ONCE, lazily, the first time we need output capture (so we only traverse the graph once) -> no monkey-patching of forwards. If we never need output capture (most of the time), the hooks are NOT installed
  • Once installed, these hooks are dormant, and are only awakened when the output_xxx kwargs are present
  • Since the hooks need a global variable to collect the outputs (because we don't want to modify the signature of every module's forward, of course), use a ContextVar inside the hook, so the collected outputs are thread-safe (a ContextVar basically behaves as a thread-safe global variable and is blazingly fast)
  • Since torch.compile does not support tracing ContextVar.get(), use a simple trick to fall back to a plain global variable in that case. This unfortunately means that torch.compile + return_xxx is not thread-safe, but it still works on a single thread
  • Collect outputs in lists instead of tuples, and simply append to them, to avoid creating new objects all the time
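The bullets above can be sketched roughly as follows. All names here (`TinyModel`, `_record_hook`, `_capture`) are illustrative, not the actual transformers implementation; this is just the dormant-hook + ContextVar pattern in miniature:

```python
from contextvars import ContextVar

import torch
from torch import nn

_capture = ContextVar("_capture", default=None)  # per-thread/task "global"

def _record_hook(module, args, output):
    collected = _capture.get()
    if collected is None:      # dormant fast path: nobody requested capture
        return
    # Append to a list instead of rebuilding a tuple at every layer.
    collected.setdefault("hidden_states", []).append(output)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
        self._hooks_installed = False

    def _maybe_install_hooks(self):
        # Lazy, one-time installation: the graph is traversed only once.
        if not self._hooks_installed:
            for layer in self.layers:
                layer.register_forward_hook(_record_hook)
            self._hooks_installed = True

    def forward(self, x, output_hidden_states=False):
        token = None
        if output_hidden_states:
            self._maybe_install_hooks()
            token = _capture.set({})   # activate capture for this context only
        for layer in self.layers:
            x = layer(x)
        if token is not None:
            collected = _capture.get()
            _capture.reset(token)      # deactivate; other threads never see it
            return x, tuple(collected["hidden_states"])
        return x

model = TinyModel()
x = torch.randn(2, 4)
plain = model(x)                                   # hooks absent or dormant
out, hidden = model(x, output_hidden_states=True)  # hooks record 3 outputs
```

A later plain `model(x)` still runs through the (now installed) hooks, but they hit the `None` fast path and return immediately.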

TLDR - What it solves

|           | No graph traversal on every call | Thread-safe without compile | Thread-safe with compile |
|-----------|----------------------------------|-----------------------------|--------------------------|
| Currently | 🚫                               | 🚫                          | 🚫                       |
| This PR   | ✅                               | ✅                          | 🚫                       |

Reproduction code

memory_leak.py
import threading
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"
device = 0
WARMUP = 3
TOKS = 100
THREADS = 2

model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device, attn_implementation="eager")
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [
    {"role": "user", "content": "What do you think about life?"},
]
inputs = tokenizer.apply_chat_template(messages, return_dict=True, return_tensors="pt").to(device)
input_size = inputs.input_ids.shape[-1]

generation_config = dict(
    min_new_tokens=TOKS,
    max_new_tokens=TOKS,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    output_attentions=True,
    output_hidden_states=True,
    return_dict_in_generate=True,
)

def print_memory_peak():
    peak = torch.cuda.max_memory_allocated(device) / 1024 ** 3
    print(f"Memory peak is {peak:.2f} GiB")

def run_a_few():
    t0 = time.time()
    for _ in range(WARMUP):
        output = model.generate(**inputs, **generation_config)
    dt = time.time() - t0
    print(f"Warmup took {dt:.2f} s")
    print_memory_peak()

def run_indefinitely():
    while True:
        output = model.generate(**inputs, **generation_config)
        print_memory_peak()

run_a_few()


threads = []
for _ in range(THREADS):
    threads.append(threading.Thread(target=run_indefinitely))

# Start each thread
for t in threads:
    t.start()

# Wait for all threads to finish
for t in threads:
    t.join()

The above script shows the issue mentioned and the associated memory leak. On current main, memory keeps increasing without bound over time. On this PR, the leak is solved and memory is stable.
Moreover, the outputs gathered are now correct, whereas before they were essentially random.

benchmark_speed.py
from transformers import AutoModelForCausalLM
import numpy as np
import torch
import time

model_id = "meta-llama/Llama-3.1-8B-Instruct"
device = 0
WARMUP = 10
EXPERIENCE = 1000
INPUT_SHAPE = (2, 56)


model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device, attn_implementation="eager")

torch.manual_seed(12)
input_ids = torch.randint(10, 1000, INPUT_SHAPE, device=device)

# Warmup
tot = []
for _ in range(WARMUP):
    torch.cuda.synchronize(device)
    t0 = time.time()
    with torch.no_grad():
        out = model(input_ids, return_dict=True)
    torch.cuda.synchronize(device)
    dt = time.time() - t0
    tot.append(dt)
assert out.attentions is None
print(f"Warmup took {np.mean(tot):.2e} s")

# Actual test
tot = []
for _ in range(EXPERIENCE):
    torch.cuda.synchronize(device)
    t0 = time.time()
    with torch.no_grad():
        out = model(input_ids, return_dict=True)
    torch.cuda.synchronize(device)
    dt = time.time() - t0
    tot.append(dt)
assert out.attentions is None
print(f"Forward took {np.mean(tot):.2e} s")


# Actual test with return_xxx
tot = []
for _ in range(EXPERIENCE):
    torch.cuda.synchronize(device)
    t0 = time.time()
    with torch.no_grad():
        out = model(input_ids, return_dict=True, output_attentions=True)
    torch.cuda.synchronize(device)
    dt = time.time() - t0
    tot.append(dt)
assert out.attentions is not None
print(f"Return attention forward took {np.mean(tot):.2e} s")

The above script was used to benchmark the speed of both approaches, with and without output_xxx.
Results on one small forward of Llama3 8B (input size (2, 56)):

|           | Without output capture   | With output capture      |
|-----------|--------------------------|--------------------------|
| Currently | 3.37e-02 s               | 3.53e-02 s               |
| This PR   | 3.33e-02 s (1.2% faster) | 3.41e-02 s (3.5% faster) |

So this PR is slightly faster in both cases.

Note

Some models did not use the decorator correctly, or used it very redundantly, which is why I had to make some minimal changes to a few of them.

@Cyrilvallez Cyrilvallez changed the title [core] Improve check_model_inputs mechanism [core] Faster and thread-safe check_model_inputs implementation Feb 5, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@ArthurZucker ArthurZucker left a comment


Nice, we have a big decision to make here tho.

One thing that I would love to reach is having 0 hooks / a normal forward when we are not recording, to avoid an unreadable stack trace.

Other than that, let's try to reduce the overhead for inference-heavy workloads that never use capturing. Currently, every hooked module calls _active_keys.get() on every forward pass, even when capturing is disabled, and a ContextVar lookup is probably not free.

We can add a fast-path guard using a simple global ref count of the active recorders.

Comment thread utils/check_repo.py
Comment on lines +113 to +117
"CLIPTextTransformer", # was not a PreTrainedModel originally but needs to be
"CLIPVisionTransformer", # was not a PreTrainedModel originally but needs to be
"MetaClip2TextTransformer", # was not a PreTrainedModel originally but needs to be
"MetaClip2VisionTransformer", # was not a PreTrainedModel originally but needs to be
"MLCDVisionTransformer", # was not a PreTrainedModel originally but needs to be
Collaborator


why do they need to be?

Member Author

@Cyrilvallez Cyrilvallez Feb 6, 2026


Due to the super weird/bad structure of the other composite models using them...
Basically, the already-public PreTrainedModels were just very thin wrappers around them doing nothing more, so the composite models do the following in __init__:

text_model = CLIPTextModel._from_config(text_config)
self.text_model = text_model.text_model   #  <-- this child is `"CLIPTextTransformer"`

so it uses the submodel of the public wrapper CLIPTextModel (which is CLIPTextTransformer) directly. However, since CLIPTextTransformer needs to capture outputs with its own _can_record_outputs, it needs to be a PreTrainedModel. I have absolutely no clue why CLIPTextModel would even exist originally, since other models use it to instantiate but then bypass it by using the child...

Contributor


cc @zucchini-nlp @molbap this is known on the vision side, it's super awkward and is why it needs a refactor

Member


yep, a known issue and easily fixable by re-ordering modules. @yonigozlan had a plan to work on it :)

Comment thread src/transformers/utils/generic.py Outdated
Comment thread src/transformers/utils/generic.py Outdated
@Cyrilvallez
Member Author

Cyrilvallez commented Feb 6, 2026

Yes, I agree! I woke up this morning with basically the same idea: I want to defer hook installation, as most of the time we don't want any output recording - I believe we can do so with simple locks to avoid installing twice.

As for ContextVar.get(), I benchmarked it carefully and it's actually completely negligible (~2e-6 s) compared to the forward of a module. However, calling the hook itself is not entirely negligible due to the PyTorch implementation (i.e. even an empty hook containing only pass creates a very tiny, measurable overhead).
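The order of magnitude is easy to sanity-check with `timeit`. This is not the author's benchmark, just a quick illustration of how cheap a dormant `ContextVar.get()` is; a single Llama-3 8B forward above takes ~3e-2 s, several orders of magnitude more:

```python
import timeit
from contextvars import ContextVar

# A ContextVar with no value set in the current context: the hooks' fast path.
_active = ContextVar("_active", default=None)

N = 1_000_000
per_call = timeit.timeit(_active.get, number=N) / N  # average over N calls
print(f"dormant ContextVar.get(): {per_call:.1e} s per call")
```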

@Cyrilvallez Cyrilvallez force-pushed the hooks-instead-of-monkey-patching branch from 8a66f19 to a8b8611 on February 6, 2026, 13:42
@Cyrilvallez
Member Author

Cyrilvallez commented Feb 6, 2026

Actually, I was wrong: it's not the hooks themselves that were slowing down the forward. They basically have no visible cost. Anyway, I benchmarked everything, and this PR makes everything slightly faster in both cases.
TLDR -> we go faster and we are muuuuch safer

@github-actions
Contributor

github-actions bot commented Feb 6, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: aria, aya_vision, blip_2, blt, clip, cohere2_vision, conditional_detr, deformable_detr, depth_pro, doge, edgetam, edgetam_video, ernie4_5_moe, ernie4_5_vl_moe, esm, evolla

Collaborator

@ArthurZucker ArthurZucker left a comment


Should we have pytest-benchmark on tests specific to this?

@Cyrilvallez Cyrilvallez merged commit f8f2834 into main Feb 6, 2026
26 checks passed
@Cyrilvallez Cyrilvallez deleted the hooks-instead-of-monkey-patching branch February 6, 2026 16:08
jiosephlee pushed a commit to jiosephlee/transformers_latest that referenced this pull request Feb 11, 2026
…ggingface#43765)

* poc

* cleaner

* style

* forgot name

* fix all

* more fixes

* oupsi

* fix

* use only 1 contextvar

* lazy hook installation

* capture with lists instead of tuples

* more bad-behaved model fixes

* style

* move to dedicated file

* doc

* style
@O-J1

O-J1 commented Feb 17, 2026

@Cyrilvallez Was this fix backported into the v4 releases of Transformers? We aren't ready to migrate to v5 for quite a while, I suspect (OneTrainer; I help with Dxzqb). It would be a big boon for us if it was, but I can understand why it may not have been.

@ArthurZucker
Collaborator

It was not, but we can push for that if required! (gonna be a bit hard)

@Cyrilvallez
Member Author

Yep; not sure how worthwhile it is, and it's gonna be a mess to cherry-pick that commit back to v4

@ArthurZucker
Collaborator

@O-J1 what can we do to help transition to v5?

@O-J1

O-J1 commented Feb 18, 2026

@O-J1 what can we do to help transition to v5?

@ArthurZucker

At the moment, not much 🤣 we are swamped with refactors, PyTorch bugs, and functionality additions. Once we have the majority of our big-ticket PR items merged, I think there will be more headspace for us to transition.

Based on what you have both communicated, I don't wanna make your codebase a mess, so let's ignore my request. We will bite the bullet eventually.

6 participants