
[Model] Add PaddleOCR-VL Model Support #42178

Merged

zucchini-nlp merged 24 commits into huggingface:main from zhang-prog:feat/paddleocr_vl on Dec 11, 2025

Conversation


@zhang-prog zhang-prog commented Nov 13, 2025

What does this PR do?

This PR adds PaddleOCR-VL model to Hugging Face Transformers from PaddleOCR.

Relevant Links:

PaddleOCR
https://huggingface.co/PaddlePaddle/PaddleOCR-VL

Usage

Use a pipeline

from transformers import pipeline

pipe = pipeline("image-text-to-text", model="PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
result = pipe(text=messages)
print(result)

Load model directly

from transformers import AutoProcessor, AutoModelForImageTextToText

processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=100)
result = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:-1])
print(result)

@zucchini-nlp zucchini-nlp self-requested a review November 13, 2025 09:07
@zucchini-nlp zucchini-nlp left a comment

hey @zhang-prog , thanks for the PR! Great model to have in transformers!

The main thing to fix first is the naming, it should clearly include "PaddlePaddleOCR" and follow the usual pattern depending on the modality. The config format also isn’t right; it needs to be fully nested, with text and vision configs inside. Additionally there are no tests or docs, several files are missing. You can run transformers add-new-model-like which would generate a placeholder with the necessary files. I also left some smaller comments here and there. Let me know if you hit any issues

@zhang-prog

@zucchini-nlp
We have refactored the code to address the issues you mentioned in your comments.
Please review the code again when you have time.
Thank you for your efforts!!!

@zucchini-nlp zucchini-nlp left a comment

@zhang-prog thanks for iterating!

There are a couple of major comments that were not addressed and could block merging.

  1. The model doesn't seem to support batched inference in its current state. We need to enable batching before merging if possible; it should not be hard, given that the image tower is quite similar to existing models
  2. We also need tests to make sure everything actually works, plus a documentation page. These files are usually auto-prefilled as empty files when you run transformers add-new-model-like
  3. Let the modular copy automatically when possible. There are a few more modules that can be copied from similar models; if you struggle to find a similar model, you can try out a modular detector

@zhang-prog zhang-prog commented Nov 26, 2025

> @zhang-prog thanks for iterating!
>
> There are a couple major comments which were not addressed and can be a blocker for merging.
>
>   1. The model seems to not support batched inference in current state. We need to enable batching before merging if possible. Should not be hard I think given that the image tower is quite similar to existing models
>   2. We also need tests to make sure everything actually works and a documentation page. These files are usually auto-prefilled with empty files when you run transformers add-new-model-like
>   3. Let the modular copy automatically when possible. I think there are a few more modules which can be copied from similar models. If you struggle with finding a similar model, you can try out a modular detector

@zucchini-nlp Thank you for your valuable insights! We’ve carefully addressed all comments and responded to your overall recommendations.

  1. We support bs > 1, like this:
from transformers import AutoProcessor, AutoModelForImageTextToText

processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL")
model = AutoModelForImageTextToText.from_pretrained("PaddlePaddle/PaddleOCR-VL", dtype="bfloat16")
messages1 = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
messages2 = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg"},
            {"type": "text", "text": "OCR:"},
        ]
    }
]
batch_messages = [messages1, messages2]
inputs = processor.apply_chat_template(
    batch_messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    padding=True,
    padding_side="left",
).to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=1024)
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
result = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(result)
  2. We still have some issues to discuss. I replied to your comment and will generate the final version of the document once it’s completed.

  3. We also added the PaddleOCRVisionConfig and PaddleOCRTextConfig into modular.

Thank you for your efforts. ❤️
PTAL.
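The per-sample trimming in the batched snippet above can be illustrated with a plain-Python sketch (the token ids below are invented for illustration): with left padding, every row of input_ids has the same length, so slicing off len(in_ids) tokens from each generated row keeps only the newly generated part.

```python
# Toy sketch of the generated-id trimming above; ids are made up, 0 marks left padding.
input_ids = [[0, 0, 5, 6], [1, 2, 3, 4]]            # left-padded prompts
generated = [[0, 0, 5, 6, 9, 9], [1, 2, 3, 4, 7]]   # prompts + newly generated tokens
trimmed = [out[len(inp):] for inp, out in zip(input_ids, generated)]
print(trimmed)  # [[9, 9], [7]]
```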

@zhang-prog

@zucchini-nlp How do I properly add documentation pages and unit tests? I tried to use transformers add-new-model-like, which generates the new modular_xxx.py files, but this process might not be the right approach.

@zhang-prog

@zucchini-nlp PTAL. Thanks❤️

@zucchini-nlp

Sorry, taking a look. Got lost in my notifications

@zucchini-nlp zucchini-nlp left a comment

Nice, only a few comments and replied to your questions above

For the docs and the tests, they need to be in docs/source/en/model_doc and in the tests folder. You can take a look at a recently merged model for an example: https://github.com/huggingface/transformers/pull/41112/files#diff-857421affc3c877bca95377cbb6adb3a8374b149fcbdcc6c759ea0408fa96897

@zhang-prog

@zucchini-nlp

PTAL.❤️

_checkpoint_conversion_mapping and ignore_keys_at_rope_validation need to be discussed.

I am working on the docs and tests.

@zucchini-nlp

Great, looking good already. We can keep the conversion mapping as is, no issue for us! There are also a few unresolved comments from past iterations, if you can take a look.

Ping me when the docs/tests are added and the CI shows ✅

@zhang-prog zhang-prog commented Dec 5, 2025

Don't merge. Working.....

@zhang-prog

@zucchini-nlp
Why did it crash here?

[screenshot: CI worker crash]

In my environment, the test passed:

[screenshot: tests passing locally]

@zucchini-nlp

@zhang-prog a worker crash means the tests might be using too much RAM. I see that the image sizes in the tests are quite high (600x400). Let's make the dummy inputs and model as tiny as possible

Reviewing now
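As a hypothetical sketch of what "as tiny as possible" might mean here (the field names follow common transformers test conventions and are assumptions, not taken from this PR's actual config):

```python
# Hypothetical tiny test-config values; field names are assumptions, not from the PR.
tiny_vision_kwargs = {
    "hidden_size": 32,
    "intermediate_size": 64,
    "num_hidden_layers": 2,
    "num_attention_heads": 4,
    "image_size": 30,   # small dummy images instead of 600x400
    "patch_size": 15,
}
tiny_text_kwargs = {
    "hidden_size": 32,
    "intermediate_size": 64,
    "num_hidden_layers": 2,
    "num_attention_heads": 4,
    "max_position_embeddings": 128,
}
# Sanity checks: the head dimension and patch grid must divide evenly.
assert tiny_text_kwargs["hidden_size"] % tiny_text_kwargs["num_attention_heads"] == 0
assert tiny_vision_kwargs["image_size"] % tiny_vision_kwargs["patch_size"] == 0
```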

### Usage tips

> [!IMPORTANT]
> We currently recommend using the [PaddleOCR official method for inference](https://www.paddleocr.ai/latest/en/version3.x/pipeline_usage/PaddleOCR-VL.html), as it is faster and supports page-level document parsing.
Member

curious if we plan to support page-level document parsing in transformers in the future. Let us know if you need help with it

Contributor Author

This is one of our goals as well. We aim to resolve this, but we anticipate some engineering challenges, such as managing the sequential logic between the two models, which is quite complex.

In fact, we plan to submit a PR for PP-DocLayoutV2 soon and hope you can help review it.

"expert_layer_offset",
"expert_layer_period",
],
"PaddleOCRTextConfig": ["tie_word_embeddings"],
Member

interesting, tie_word_embeddings is a universal attribute and I wouldn't expect CI to complain. Will check 👁️

Member

Ah, I see. It has a comment saying # Allow if the default value in the configuration class is different from the one in PreTrainedConfig, and there are more models skipping tie_word_embeddings explicitly

@zhang-prog

@zucchini-nlp

Now, CI shows ✅
PTAL. ❤️

@zucchini-nlp

run-slow: paddleocr_vl

@github-actions

This comment contains run-slow, running the specified jobs:

models: ["models/paddleocr_vl"]
quantizations: []

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions

CI Results

Workflow Run ⚙️

Model CI Report

❌ Failed tests

  • paddleocr_vl:
    tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py::PaddleOCRVLModelTest::test_flex_attention_with_grads
    tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py::PaddleOCRVLIntegrationTest::test_small_model_integration_test
    tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py::PaddleOCRVLIntegrationTest::test_small_model_integration_test_batch

@zucchini-nlp zucchini-nlp left a comment

Thanks a lot, huge work! The last bit will be adjusting the slow tests; the comment above shows the failing cases.

I will ping core maintainers in the meanwhile for review

@zhang-prog

zhang-prog commented Dec 11, 2025

@zucchini-nlp fixed, please try the slow tests again. btw, can some conversations be marked as resolved?

@zucchini-nlp

run-slow: paddleocr_vl

@github-actions

This comment contains run-slow, running the specified jobs:

models: ["models/paddleocr_vl"]
quantizations: []

@ArthurZucker ArthurZucker left a comment

Very nice! 🤗

("ovis2", "Qwen2TokenizerFast" if is_tokenizers_available() else None),
("owlv2", "CLIPTokenizerFast" if is_tokenizers_available() else None),
("owlvit", "CLIPTokenizerFast" if is_tokenizers_available() else None),
("paddleocr_vl", "LlamaTokenizer" if is_tokenizers_available() else None),
Collaborator

Suggested change
("paddleocr_vl", "LlamaTokenizer" if is_tokenizers_available() else None),
("paddleocr_vl", "TokenizersBackend" if is_tokenizers_available() else None),

From looking at the tokenizer.json on the Hub, it's not a llama!

"normalizer": {
    "type": "Sequence",
    "normalizers": [
      {
        "type": "Replace",
        "pattern": {
          "String": " "
        },
        "content": "▁"
      }
    ]
  },
  "pre_tokenizer": null,

while llama would initialize:

        self._tokenizer.normalizer = None
        self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
            replacement="▁", prepend_scheme=_get_prepend_scheme(self.add_prefix_space, self), split=False
        )
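The difference is easy to see with a stdlib sketch: the hub tokenizer.json applies a Replace normalizer that turns every literal space into "▁" (U+2581) before tokenization, whereas the Llama fast tokenizer sets no normalizer and uses a Metaspace pre-tokenizer instead.

```python
# Stdlib sketch of the Replace normalizer from the hub tokenizer.json:
# every space becomes the metaspace character "▁" (U+2581).
text = "Hello OCR world"
normalized = text.replace(" ", "\u2581")
print(normalized)  # Hello▁OCR▁world
```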

Contributor Author

Done

rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
Collaborator

this in general should not be needed

Contributor Author

Removed

@github-actions

CI Results

Workflow Run ⚙️

Model CI Report

❌ Failed tests

  • paddleocr_vl:
    tests/models/paddleocr_vl/test_modeling_paddleocr_vl.py::PaddleOCRVLModelTest::test_flex_attention_with_grads

@zhang-prog

@zucchini-nlp uh, OOM on test_flex_attention_with_grads again. How can I get this test to pass by changing the parameters?

@zucchini-nlp

@zhang-prog the test changes the hidden sizes of the model to be a multiple of 16, iirc; that is probably why it is OOM'ing. If the model causes OOM and can't be made smaller, we can skip the test imo

@zhang-prog

@zucchini-nlp ok, I've reduced hidden_size, intermediate_size, and num_attention_heads. Maybe this attempt will pass the test. Let's try again!

@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, paddleocr_vl

@zucchini-nlp

run-slow: paddleocr_vl

@github-actions

This comment contains run-slow, running the specified jobs:

models: ["models/paddleocr_vl"]
quantizations: []

@github-actions

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@zucchini-nlp

Great, CI is green now and we can merge 🚀

@zucchini-nlp zucchini-nlp merged commit 8c84144 into huggingface:main Dec 11, 2025
26 checks passed
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
* init

* refactor

* update

* update

* fix unresolved problems

* fix how position_ids work with flash_attn_2

* add tests and fix code

* add model_doc

* update model_doc

* fix ci

* update docstring

* add tests

* update

* add **kwargs

* update

* update

* update

* reduce max_position_embeddings in tests

* update


4 participants