Skip to content

[Feature] VLMs support for GRPO#2752

Closed
GAD-cell wants to merge 1459 commits intounslothai:mainfrom
GAD-cell:VLM_GRPO
Closed

[Feature] VLMs support for GRPO#2752
GAD-cell wants to merge 1459 commits intounslothai:mainfrom
GAD-cell:VLM_GRPO

Conversation

@GAD-cell
Copy link
Copy Markdown
Contributor

This PR aims to add support for VLMs in GRPO, which is currently not supported by HF.

I've implemented a working version that does not yet include VLLM or video input support (mainly due to limited resources for testing video inputs haha).
I added a new variable, use_vision, to the GRPO config. Setting use_vision = True enables vision inputs, while use_vision = False keeps the default GRPO behavior. Default is False.
I also had to change a function in unsloth_zoo.peft_utils (requires_grad_post_hook) to make it work.
I've tested the implementation with Qwen 2.5 VL 7B for 250 steps, and training appears to proceed correctly (see TensorBoard screenshots for reference).

danielhanchen and others added 30 commits April 29, 2025 11:17
Update mapper.py to add Qwen3 base
* bug fix unslothai#2008 (unslothai#2039)

* fix (unslothai#2051)

* Update loader.py

* Update pyproject.toml

* Update pyproject.toml

* Update vision.py

* more prints

* Update loader.py

* LoRA 16bit fix

* Update vision.py

* Update vision.py

* Update _utils.py

* Update vision.py

* move forced float32

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* move print

* Update _utils.py

* disable bfloat16

* Fix forced float32

* move float32

* Ensure trust_remote_code propegates down to unsloth_compile_transformers (unslothai#2075)

* Update _utils.py

* Show both `peft_error` and `autoconfig_error`, not just `autoconfig_error` (unslothai#2080)

When loading a PEFT model fails, only the `autoconfig_error` is shown. Instead of the `peft_error`, which is what really matters when we're trying to load a PEFT adapter, the user will see something like this:

```
RuntimeError: Unrecognized model in my_model. Should have a `model_type` key in its config.json, or contain one of the following strings in its name: albert, align, altclip, ...
```

This PR just changes it so `autoconfig_error` and `peft_error` are both displayed.

* fix error message (unslothai#2046)

* Update vision.py

* Update _utils.py

* Update pyproject.toml

* Update __init__.py

* Update __init__.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update rl_replacements.py

* Update vision.py

* Update rl_replacements.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Remove double generate patch

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update mapper.py

* Update vision.py

* fix: config.torch_dtype in LlamaModel_fast_forward_inference (unslothai#2091)

* fix: config.torch_dtype in LlamaModel_fast_forward_inference

* Update llama.py

* update for consistency

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* versioning

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* model_type_arch

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update loader.py

* check

* Update _utils.py

* Update loader.py

* Update loader.py

* Remove prints

* Update README.md

typo

* Update _utils.py

* Update _utils.py

* versioning

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update vision.py

* HF Transfer

* fix(utils): add missing importlib import to fix NameError (unslothai#2134)

This commit fixes a NameError that occurs when `importlib` is referenced in _utils.py
without being imported, especially when UNSLOTH_USE_MODELSCOPE=1 is enabled.
By adding the missing import statement, the code will no longer throw a NameError.

* Add QLoRA Train and Merge16bit Test (unslothai#2130)

* add reference and unsloth lora merging tests

* add test / dataset printing to test scripts

* allow running tests from repo root

* add qlora test readme

* more readme edits

* ruff formatting

* additional readme comments

* forgot to add actual tests

* add apache license

* Update pyproject.toml

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update loader.py

* Update loader.py

* Revert

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Bug fix

* Update mapper.py

* check SDPA for Mistral 3, Pixtral

* Update vision.py

* Versioning

* Update rl_replacements.py

* Update README.md

* add model registry

* move hf hub utils to unsloth/utils

* refactor global model info dicts to dataclasses

* fix dataclass init

* fix llama registration

* remove deprecated key function

* start registry reog

* add llama vision

* quant types -> Enum

* remap literal quant types to QuantType Enum

* add llama model registration

* fix quant tag mapping

* add qwen2.5 models to registry

* add option to include original model in registry

* handle quant types per model size

* separate registration of base and instruct llama3.2

* add QwenQVQ to registry

* add gemma3 to registry

* add phi

* add deepseek v3

* add deepseek r1 base

* add deepseek r1 zero

* add deepseek distill llama

* add deepseek distill models

* remove redundant code when constructing model names

* add mistral small to registry

* rename model registration methods

* rename deepseek registration methods

* refactor naming for mistral and phi

* add global register models

* refactor model registration tests for new registry apis

* add model search method

* remove deprecated registration api

* add quant type test

* add registry readme

* make llama registration more specific

* clear registry when executing individual model registration file

* more registry readme updates

* Update _auto_install.py

* Llama4

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Synthetic data

* Update mapper.py

* Xet and Synthetic

* Update synthetic.py

* Update loader.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update pyproject.toml

* Delete .gitignore

---------

Co-authored-by: Mukkesh Ganesh <mukmckenzie@gmail.com>
Co-authored-by: Kareem <81531392+KareemMusleh@users.noreply.github.com>
Co-authored-by: Xander Hawthorne <167850078+CuppaXanax@users.noreply.github.com>
Co-authored-by: Isaac Breen <isaac.breen@icloud.com>
Co-authored-by: lurf21 <93976703+lurf21@users.noreply.github.com>
Co-authored-by: Jack Shi Wei Lun <87535974+jackswl@users.noreply.github.com>
Co-authored-by: naliazheli <nalia0316@gmail.com>
Co-authored-by: jeromeku <jerome.ku@gmail.com>
Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
* move float32

* Ensure trust_remote_code propegates down to unsloth_compile_transformers (unslothai#2075)

* Update _utils.py

* Show both `peft_error` and `autoconfig_error`, not just `autoconfig_error` (unslothai#2080)

When loading a PEFT model fails, only the `autoconfig_error` is shown. Instead of the `peft_error`, which is what really matters when we're trying to load a PEFT adapter, the user will see something like this:

```
RuntimeError: Unrecognized model in my_model. Should have a `model_type` key in its config.json, or contain one of the following strings in its name: albert, align, altclip, ...
```

This PR just changes it so `autoconfig_error` and `peft_error` are both displayed.

* fix error message (unslothai#2046)

* Update vision.py

* Update _utils.py

* Update pyproject.toml

* Update __init__.py

* Update __init__.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update rl_replacements.py

* Update vision.py

* Update rl_replacements.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Remove double generate patch

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update mapper.py

* Update vision.py

* fix: config.torch_dtype in LlamaModel_fast_forward_inference (unslothai#2091)

* fix: config.torch_dtype in LlamaModel_fast_forward_inference

* Update llama.py

* update for consistency

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* versioning

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* model_type_arch

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update loader.py

* check

* Update _utils.py

* Update loader.py

* Update loader.py

* Remove prints

* Update README.md

typo

* Update _utils.py

* Update _utils.py

* versioning

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update vision.py

* HF Transfer

* fix(utils): add missing importlib import to fix NameError (unslothai#2134)

This commit fixes a NameError that occurs when `importlib` is referenced in _utils.py
without being imported, especially when UNSLOTH_USE_MODELSCOPE=1 is enabled.
By adding the missing import statement, the code will no longer throw a NameError.

* Add QLoRA Train and Merge16bit Test (unslothai#2130)

* add reference and unsloth lora merging tests

* add test / dataset printing to test scripts

* allow running tests from repo root

* add qlora test readme

* more readme edits

* ruff formatting

* additional readme comments

* forgot to add actual tests

* add apache license

* Update pyproject.toml

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update loader.py

* Update loader.py

* Revert

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Bug fix

* Update mapper.py

* check SDPA for Mistral 3, Pixtral

* Update vision.py

* Versioning

* Update rl_replacements.py

* Update README.md

* add model registry

* move hf hub utils to unsloth/utils

* refactor global model info dicts to dataclasses

* fix dataclass init

* fix llama registration

* remove deprecated key function

* start registry reog

* add llama vision

* quant types -> Enum

* remap literal quant types to QuantType Enum

* add llama model registration

* fix quant tag mapping

* add qwen2.5 models to registry

* add option to include original model in registry

* handle quant types per model size

* separate registration of base and instruct llama3.2

* add QwenQVQ to registry

* add gemma3 to registry

* add phi

* add deepseek v3

* add deepseek r1 base

* add deepseek r1 zero

* add deepseek distill llama

* add deepseek distill models

* remove redundant code when constructing model names

* add mistral small to registry

* rename model registration methods

* rename deepseek registration methods

* refactor naming for mistral and phi

* add global register models

* refactor model registration tests for new registry apis

* add model search method

* remove deprecated registration api

* add quant type test

* add registry readme

* make llama registration more specific

* clear registry when executing individual model registration file

* more registry readme updates

* Update _auto_install.py

* Llama4

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Synthetic data

* Update mapper.py

* Xet and Synthetic

* Update synthetic.py

* Update loader.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update pyproject.toml

* Delete .gitignore

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update _utils.py

* Update pyproject.toml

* Update synthetic.py

* Update synthetic.py

---------

Co-authored-by: Xander Hawthorne <167850078+CuppaXanax@users.noreply.github.com>
Co-authored-by: Isaac Breen <isaac.breen@icloud.com>
Co-authored-by: Kareem <81531392+KareemMusleh@users.noreply.github.com>
Co-authored-by: lurf21 <93976703+lurf21@users.noreply.github.com>
Co-authored-by: Jack Shi Wei Lun <87535974+jackswl@users.noreply.github.com>
Co-authored-by: naliazheli <nalia0316@gmail.com>
Co-authored-by: jeromeku <jerome.ku@gmail.com>
Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
Qwen3 notebook
* add moe grouped gemm kernel

* add benchmark, README

* remove formatting from __init__.py
Datta0 and others added 13 commits July 9, 2025 14:07
* fix for casual mask

* use un_casual in sdpa

* add missing mask

* fix for type
* Move tensors to right devices

* fix multi gpu for non mistral models

* multi GPU RoPE for gemma2

* Finish up multi GPU inference

* Make multiGPU rope a list

* Remove unnecessary transfer to CPU

* Remove unnecessary move to CPU

* Donot move inputs to device yet

will be handled separately in another PR

* Move inputs to appropriate decoder device

* Make device count global variable

* Cleanup RoPE device code

* Fixup num_gpu to device count

* Cleanup device counts

* Use device index for RoPE get_cache

* Donot typecast

* Use tuple instead of list for tensors. Use device index directly

* fixup move to device logic
* rename deepseek registration methods

* refactor naming for mistral and phi

* add global register models

* refactor model registration tests for new registry apis

* add model search method

* remove deprecated registration api

* add quant type test

* add registry readme

* make llama registration more specific

* clear registry when executing individual model registration file

* more registry readme updates

* Update _auto_install.py

* Llama4

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Synthetic data

* Update mapper.py

* Xet and Synthetic

* Update synthetic.py

* Update loader.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update pyproject.toml

* Delete .gitignore

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update _utils.py

* Update pyproject.toml

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update synthetic.py

* Update chat_templates.py

* Seasame force float16 / float32

* Fix Seasame

* Update loader.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update loader.py

* is_multimodal

* Update loader.py

* Update loader.py

* Update loader.py

* Update loader.py

* Update vision.py

* Update vision.py

* Update vision.py

* UNSLOTH_DISABLE_STATIC_GENERATION

* Update vision.py

* Auto vision detection

* Sesame

* Whisper

* Update loader.py

* Update loader.py

* Update loader.py

* Update mapper.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update vision.py

* Update loader.py

* Update loader.py

* Update loader.py

* Update loader.py

* Update _utils.py

* Update rl.py

* versioning

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* Update rl.py

* logging

* Update pyproject.toml

* Update rl.py

* versioning

* Update rl.py

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* logits / temperature

* Update rl_replacements.py

* Update pyproject.toml

* Update rl_replacements.py

* Update rl_replacements.py

* Debugging only

* Update llama.py

* Update llama.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl_replacements.py

* Generic efficient GRPO

* Update rl_replacements.py

* Update rl_replacements.py

* Remove debugging

* Update rl_replacements.py

* Update rl_replacements.py

* Update vision.py

* Update llama.py

* Update rl_replacements.py

* versioning

* Update _utils.py

* Update vision.py

* Update mapper.py

* Update loader.py

* Update mapper.py

* Update vision.py

* Update loader.py

* Update vision.py

* Update loader.py

* Update _utils.py

* Update vision.py

* gradient checkpointing

* Gemma 3N fixes

* Update loader.py

* Versioning

* Gemma 3N fixes

* Update vision.py

* Update vision.py

* Update loader.py

* Update vision.py

* Fix setup.py

* setup.py

* Prints

* Update setup.py

* Update setup.py

* Update setup.py

* Update pyproject.toml

* Update pyproject.toml

* Update pyproject.toml

* Update pyproject.toml

* Update pyproject.toml

* Update pyproject.toml

* Update vision.py

* Update vision.py

* Update pyproject.toml

* Update vision.py

* Update _utils.py

* Update __init__.py

* Update __init__.py

* Small fixes

* Update vision.py

* Update vision.py

* versioning

* Update __init__.py

---------

Co-authored-by: jeromeku <jerome.ku@gmail.com>
Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
…-per-token-logps-argument-mismatch

Fix argument mismatch in GRPO _get_per_token_logps lambda function
@Sweaterdog
Copy link
Copy Markdown

Hey there! I was testing out this branch to use GRPO for text-based tasks on a model that supports vision (A Qwen2.5-VL 3B model that I had primed already). I keep getting this error. I will paste my notebook here to see if it is an issue with my code (Might be, I Frankensteined it) but it is an issue with HF Transformers, so I don't know.

https://github.com/Sweaterdog/curly-goggles

@GAD-cell
Copy link
Copy Markdown
Contributor Author

Hey there! I was testing out this branch to use GRPO for text-based tasks on a model that supports vision (A Qwen2.5-VL 3B model that I had primed already). I keep getting this error. I will paste my notebook here to see if it is an issue with my code (Might be, I Frankensteined it) but it is an issue with HF Transformers, so I don't know.

https://github.com/Sweaterdog/curly-goggles

Hey !
I think you forgot to paste the error haha.
However in your notebook I didn't see any installation for the dependencies (look at this).

@Sweaterdog
Copy link
Copy Markdown

Hey there! I was testing out this branch to use GRPO for text-based tasks on a model that supports vision (A Qwen2.5-VL 3B model that I had primed already). I keep getting this error. I will paste my notebook here to see if it is an issue with my code (Might be, I Frankensteined it) but it is an issue with HF Transformers, so I don't know.

https://github.com/Sweaterdog/curly-goggles

Hey !
I think you forgot to paste the error haha.
However in your notebook I didn't see any installation for the dependencies (look at this).

Sorry! I can paste the errors ASAP. I am using it all locally, and not running this on Google colab, hence why I don't have the other dependencies installed.

@Sweaterdog
Copy link
Copy Markdown

This was the error that I am getting:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/unsloth/models/vision.py:227, in unsloth_base_fast_generate(self, *args, **kwargs)
    226     with torch.inference_mode(), autocaster:
--> 227         output = self._old_generate(*args, **kwargs)
    228 except:

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    115 with ctx_factory():
--> 116     return func(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:2625, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, custom_generate, **kwargs)
   2624     # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2625     result = self._sample(
   2626         input_ids,
   2627         logits_processor=prepared_logits_processor,
   2628         stopping_criteria=prepared_stopping_criteria,
   2629         generation_config=generation_config,
   2630         synced_gpus=synced_gpus,
   2631         streamer=streamer,
   2632         **model_kwargs,
   2633     )
   2635 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2636     # 11. interleave input_ids with `num_beams` additional sequences per batch

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:3606, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3605 if is_prefill:
-> 3606     outputs = self(**model_inputs, return_dict=True)
   3607     is_prefill = False

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/unsloth_compiled_module_qwen2_5_vl.py:743, in Qwen2_5_VLForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
    723 def forward(
    724     self,
    725     input_ids: torch.LongTensor = None,
   (...)    741     **kwargs: Unpack[KwargsForCausalLM],
    742 ) -> Union[tuple, Qwen2_5_VLCausalLMOutputWithPast]:
--> 743     return Qwen2_5_VLForConditionalGeneration_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/utils/generic.py:943, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    942 try:
--> 943     output = func(self, *args, **kwargs)
    944     if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/unsloth_compiled_module_qwen2_5_vl.py:566, in Qwen2_5_VLForConditionalGeneration_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
    562 output_hidden_states = (
    563     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    564 )
--> 566 outputs = self.model(
    567     input_ids=input_ids,
    568     pixel_values=pixel_values,
    569     pixel_values_videos=pixel_values_videos,
    570     image_grid_thw=image_grid_thw,
    571     video_grid_thw=video_grid_thw,
    572     second_per_grid_ts=second_per_grid_ts,
    573     position_ids=position_ids,
    574     attention_mask=attention_mask,
    575     past_key_values=past_key_values,
    576     inputs_embeds=inputs_embeds,
    577     use_cache=use_cache,
    578     output_attentions=output_attentions,
    579     output_hidden_states=output_hidden_states,
    580     return_dict=True,
    581     cache_position=cache_position,
    582     **kwargs,
    583 )
    585 hidden_states = outputs[0]

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:1291, in Qwen2_5_VLModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
   1290 attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
-> 1291 attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
   1292 attention_mask_tensor = (1.0 - attention_mask_tensor).int()

TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[6], line 1
----> 1 trainer.train()

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/trainer.py:2206, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2204         hf_hub_utils.enable_progress_bars()
   2205 else:
-> 2206     return inner_training_loop(
   2207         args=args,
   2208         resume_from_checkpoint=resume_from_checkpoint,
   2209         trial=trial,
   2210         ignore_keys_for_eval=ignore_keys_for_eval,
   2211     )

File <string>:321, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File <string>:28, in _unsloth_training_step(self, model, inputs, num_items_in_batch)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/trl/extras/profiling.py:98, in profiling_decorator.<locals>.wrapper(self, *args, **kwargs)
     95 @functools.wraps(func)
     96 def wrapper(self, *args, **kwargs):
     97     with profiling_context(self, func.__name__):
---> 98         return func(self, *args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/UnslothGRPOTrainer.py:1613, in _UnslothGRPOTrainer._prepare_inputs(self, generation_batch)
   1610 generate_every = self.args.steps_per_generation * self.num_iterations
   1611 if self._step % generate_every == 0 or self._buffered_inputs is None:
   1612     # self._buffered_inputs=None can occur when resuming from a checkpoint
-> 1613     generation_batch = self._generate_and_score_completions(generation_batch)
   1614     if self.use_vision : generation_batch['pixel_values']=generation_batch['pixel_values'].view(generation_batch['prompt_ids'].size(0), -1, generation_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)
   1615     generation_batch = shuffle_tensor_dict(generation_batch)

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/UnslothGRPOTrainer.py:1804, in _UnslothGRPOTrainer._generate_and_score_completions(self, inputs)
   1798     with (
   1799         FSDP.summon_full_params(self.model_wrapped, recurse=False)
   1800         if self.is_fsdp_enabled
   1801         else nullcontext()
   1802     ):
   1803         if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config)
-> 1804         else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config)
   1806 # Compute prompt length and extract completion ids
   1807 prompt_length = prompt_ids.size(1)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/unsloth/models/rl.py:70, in PatchRL.<locals>.unsloth_unwrap_model_for_generation.<locals>.generate_with_clone(*args, **kwargs)
     69 def generate_with_clone(*args, **kwargs):
---> 70     out = original_generate(*args, **kwargs)
     71     if isinstance(out, torch.Tensor):
     72         return out.clone()

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/peft/peft_model.py:1968, in PeftModelForCausalLM.generate(self, *args, **kwargs)
   1966     with self._enable_peft_forward_hooks(*args, **kwargs):
   1967         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1968         outputs = self.base_model.generate(*args, **kwargs)
   1969 else:
   1970     outputs = self.base_model.generate(**kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/unsloth/models/vision.py:232, in unsloth_base_fast_generate(self, *args, **kwargs)
    230     kwargs.pop("prompt_lookup_num_tokens", None)
    231     with torch.inference_mode(), autocaster:
--> 232         output = self._old_generate(*args, **kwargs)
    233 finally:
    234     pass

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:2625, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, custom_generate, **kwargs)
   2617     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2618         input_ids=input_ids,
   2619         expand_size=generation_config.num_return_sequences,
   2620         is_encoder_decoder=self.config.is_encoder_decoder,
   2621         **model_kwargs,
   2622     )
   2624     # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2625     result = self._sample(
   2626         input_ids,
   2627         logits_processor=prepared_logits_processor,
   2628         stopping_criteria=prepared_stopping_criteria,
   2629         generation_config=generation_config,
   2630         synced_gpus=synced_gpus,
   2631         streamer=streamer,
   2632         **model_kwargs,
   2633     )
   2635 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2636     # 11. interleave input_ids with `num_beams` additional sequences per batch
   2637     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2638         input_ids=input_ids,
   2639         expand_size=generation_config.num_beams,
   2640         is_encoder_decoder=self.config.is_encoder_decoder,
   2641         **model_kwargs,
   2642     )

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:3606, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3603 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   3605 if is_prefill:
-> 3606     outputs = self(**model_inputs, return_dict=True)
   3607     is_prefill = False
   3608 else:

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/unsloth_compiled_module_qwen2_5_vl.py:743, in Qwen2_5_VLForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
    723 def forward(
    724     self,
    725     input_ids: torch.LongTensor = None,
   (...)    741     **kwargs: Unpack[KwargsForCausalLM],
    742 ) -> Union[tuple, Qwen2_5_VLCausalLMOutputWithPast]:
--> 743     return Qwen2_5_VLForConditionalGeneration_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/utils/generic.py:943, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    940     set_attribute_for_modules(self, "_is_top_level_module", False)
    942 try:
--> 943     output = func(self, *args, **kwargs)
    944     if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):
    945         output = output.to_tuple()

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/unsloth_compiled_module_qwen2_5_vl.py:566, in Qwen2_5_VLForConditionalGeneration_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
    561 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    562 output_hidden_states = (
    563     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    564 )
--> 566 outputs = self.model(
    567     input_ids=input_ids,
    568     pixel_values=pixel_values,
    569     pixel_values_videos=pixel_values_videos,
    570     image_grid_thw=image_grid_thw,
    571     video_grid_thw=video_grid_thw,
    572     second_per_grid_ts=second_per_grid_ts,
    573     position_ids=position_ids,
    574     attention_mask=attention_mask,
    575     past_key_values=past_key_values,
    576     inputs_embeds=inputs_embeds,
    577     use_cache=use_cache,
    578     output_attentions=output_attentions,
    579     output_hidden_states=output_hidden_states,
    580     return_dict=True,
    581     cache_position=cache_position,
    582     **kwargs,
    583 )
    585 hidden_states = outputs[0]
    586 logits = EMPTY_LOGITS

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:1291, in Qwen2_5_VLModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
   1289 if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
   1290     attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
-> 1291     attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
   1292     attention_mask_tensor = (1.0 - attention_mask_tensor).int()
   1294 # Calculate RoPE index once per generation in the pre-fill stage only.
   1295 # When compiling, we can't check tensor values thus we check only input length
   1296 # It is safe to assume that `length!=1` means we're in pre-fill because compiled
   1297 # models currently cannot do asssisted decoding

TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'

@GAD-cell
Copy link
Copy Markdown
Contributor Author

GAD-cell commented Jul 11, 2025

This was the error that I am getting:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/unsloth/models/vision.py:227, in unsloth_base_fast_generate(self, *args, **kwargs)
    226     with torch.inference_mode(), autocaster:
--> 227         output = self._old_generate(*args, **kwargs)
    228 except:

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    115 with ctx_factory():
--> 116     return func(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:2625, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, custom_generate, **kwargs)
   2624     # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2625     result = self._sample(
   2626         input_ids,
   2627         logits_processor=prepared_logits_processor,
   2628         stopping_criteria=prepared_stopping_criteria,
   2629         generation_config=generation_config,
   2630         synced_gpus=synced_gpus,
   2631         streamer=streamer,
   2632         **model_kwargs,
   2633     )
   2635 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2636     # 11. interleave input_ids with `num_beams` additional sequences per batch

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:3606, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3605 if is_prefill:
-> 3606     outputs = self(**model_inputs, return_dict=True)
   3607     is_prefill = False

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/unsloth_compiled_module_qwen2_5_vl.py:743, in Qwen2_5_VLForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
    723 def forward(
    724     self,
    725     input_ids: torch.LongTensor = None,
   (...)    741     **kwargs: Unpack[KwargsForCausalLM],
    742 ) -> Union[tuple, Qwen2_5_VLCausalLMOutputWithPast]:
--> 743     return Qwen2_5_VLForConditionalGeneration_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/utils/generic.py:943, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    942 try:
--> 943     output = func(self, *args, **kwargs)
    944     if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/unsloth_compiled_module_qwen2_5_vl.py:566, in Qwen2_5_VLForConditionalGeneration_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
    562 output_hidden_states = (
    563     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    564 )
--> 566 outputs = self.model(
    567     input_ids=input_ids,
    568     pixel_values=pixel_values,
    569     pixel_values_videos=pixel_values_videos,
    570     image_grid_thw=image_grid_thw,
    571     video_grid_thw=video_grid_thw,
    572     second_per_grid_ts=second_per_grid_ts,
    573     position_ids=position_ids,
    574     attention_mask=attention_mask,
    575     past_key_values=past_key_values,
    576     inputs_embeds=inputs_embeds,
    577     use_cache=use_cache,
    578     output_attentions=output_attentions,
    579     output_hidden_states=output_hidden_states,
    580     return_dict=True,
    581     cache_position=cache_position,
    582     **kwargs,
    583 )
    585 hidden_states = outputs[0]

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:1291, in Qwen2_5_VLModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
   1290 attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
-> 1291 attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
   1292 attention_mask_tensor = (1.0 - attention_mask_tensor).int()

TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[6], line 1
----> 1 trainer.train()

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/trainer.py:2206, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2204         hf_hub_utils.enable_progress_bars()
   2205 else:
-> 2206     return inner_training_loop(
   2207         args=args,
   2208         resume_from_checkpoint=resume_from_checkpoint,
   2209         trial=trial,
   2210         ignore_keys_for_eval=ignore_keys_for_eval,
   2211     )

File <string>:321, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File <string>:28, in _unsloth_training_step(self, model, inputs, num_items_in_batch)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/trl/extras/profiling.py:98, in profiling_decorator.<locals>.wrapper(self, *args, **kwargs)
     95 @functools.wraps(func)
     96 def wrapper(self, *args, **kwargs):
     97     with profiling_context(self, func.__name__):
---> 98         return func(self, *args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/UnslothGRPOTrainer.py:1613, in _UnslothGRPOTrainer._prepare_inputs(self, generation_batch)
   1610 generate_every = self.args.steps_per_generation * self.num_iterations
   1611 if self._step % generate_every == 0 or self._buffered_inputs is None:
   1612     # self._buffered_inputs=None can occur when resuming from a checkpoint
-> 1613     generation_batch = self._generate_and_score_completions(generation_batch)
   1614     if self.use_vision : generation_batch['pixel_values']=generation_batch['pixel_values'].view(generation_batch['prompt_ids'].size(0), -1, generation_batch['pixel_values'].size(1)) # (batch_size * n_patches, dim embedding)->(batch_size,n_patches,dim embeddding)
   1615     generation_batch = shuffle_tensor_dict(generation_batch)

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/UnslothGRPOTrainer.py:1804, in _UnslothGRPOTrainer._generate_and_score_completions(self, inputs)
   1798     with (
   1799         FSDP.summon_full_params(self.model_wrapped, recurse=False)
   1800         if self.is_fsdp_enabled
   1801         else nullcontext()
   1802     ):
   1803         if self.use_vision : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask,pixel_values = pixel_values,image_grid_thw=image_grid_thw, generation_config=self.generation_config)
-> 1804         else : prompt_completion_ids = unwrapped_model.generate(prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config)
   1806 # Compute prompt length and extract completion ids
   1807 prompt_length = prompt_ids.size(1)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/unsloth/models/rl.py:70, in PatchRL.<locals>.unsloth_unwrap_model_for_generation.<locals>.generate_with_clone(*args, **kwargs)
     69 def generate_with_clone(*args, **kwargs):
---> 70     out = original_generate(*args, **kwargs)
     71     if isinstance(out, torch.Tensor):
     72         return out.clone()

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/peft/peft_model.py:1968, in PeftModelForCausalLM.generate(self, *args, **kwargs)
   1966     with self._enable_peft_forward_hooks(*args, **kwargs):
   1967         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1968         outputs = self.base_model.generate(*args, **kwargs)
   1969 else:
   1970     outputs = self.base_model.generate(**kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/unsloth/models/vision.py:232, in unsloth_base_fast_generate(self, *args, **kwargs)
    230     kwargs.pop("prompt_lookup_num_tokens", None)
    231     with torch.inference_mode(), autocaster:
--> 232         output = self._old_generate(*args, **kwargs)
    233 finally:
    234     pass

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:2625, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, custom_generate, **kwargs)
   2617     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2618         input_ids=input_ids,
   2619         expand_size=generation_config.num_return_sequences,
   2620         is_encoder_decoder=self.config.is_encoder_decoder,
   2621         **model_kwargs,
   2622     )
   2624     # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2625     result = self._sample(
   2626         input_ids,
   2627         logits_processor=prepared_logits_processor,
   2628         stopping_criteria=prepared_stopping_criteria,
   2629         generation_config=generation_config,
   2630         synced_gpus=synced_gpus,
   2631         streamer=streamer,
   2632         **model_kwargs,
   2633     )
   2635 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2636     # 11. interleave input_ids with `num_beams` additional sequences per batch
   2637     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2638         input_ids=input_ids,
   2639         expand_size=generation_config.num_beams,
   2640         is_encoder_decoder=self.config.is_encoder_decoder,
   2641         **model_kwargs,
   2642     )

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:3606, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3603 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   3605 if is_prefill:
-> 3606     outputs = self(**model_inputs, return_dict=True)
   3607     is_prefill = False
   3608 else:

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/unsloth_compiled_module_qwen2_5_vl.py:743, in Qwen2_5_VLForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
    723 def forward(
    724     self,
    725     input_ids: torch.LongTensor = None,
   (...)    741     **kwargs: Unpack[KwargsForCausalLM],
    742 ) -> Union[tuple, Qwen2_5_VLCausalLMOutputWithPast]:
--> 743     return Qwen2_5_VLForConditionalGeneration_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/utils/generic.py:943, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    940     set_attribute_for_modules(self, "_is_top_level_module", False)
    942 try:
--> 943     output = func(self, *args, **kwargs)
    944     if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):
    945         output = output.to_tuple()

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/unsloth_compiled_module_qwen2_5_vl.py:566, in Qwen2_5_VLForConditionalGeneration_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
    561 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    562 output_hidden_states = (
    563     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    564 )
--> 566 outputs = self.model(
    567     input_ids=input_ids,
    568     pixel_values=pixel_values,
    569     pixel_values_videos=pixel_values_videos,
    570     image_grid_thw=image_grid_thw,
    571     video_grid_thw=video_grid_thw,
    572     second_per_grid_ts=second_per_grid_ts,
    573     position_ids=position_ids,
    574     attention_mask=attention_mask,
    575     past_key_values=past_key_values,
    576     inputs_embeds=inputs_embeds,
    577     use_cache=use_cache,
    578     output_attentions=output_attentions,
    579     output_hidden_states=output_hidden_states,
    580     return_dict=True,
    581     cache_position=cache_position,
    582     **kwargs,
    583 )
    585 hidden_states = outputs[0]
    586 logits = EMPTY_LOGITS

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py:1291, in Qwen2_5_VLModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw, rope_deltas, cache_position, second_per_grid_ts, **kwargs)
   1289 if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
   1290     attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
-> 1291     attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
   1292     attention_mask_tensor = (1.0 - attention_mask_tensor).int()
   1294 # Calculate RoPE index once per generation in the pre-fill stage only.
   1295 # When compiling, we can't check tensor values thus we check only input length
   1296 # It is safe to assume that `length!=1` means we're in pre-fill because compiled
   1297 # models currently cannot do asssisted decoding

TypeError: torch.finfo() requires a floating point input type. Use torch.iinfo to handle 'torch.finfo'

This is due to the transformers version.
You should use transformers 4.52.4.
I recommend you install all the dependencies from the link I provided on the last comment even if you are not on a colab session. And it should work. :)

@Sweaterdog
Copy link
Copy Markdown

Ah! Thank you so much. It is working now!

@Sweaterdog
Copy link
Copy Markdown

One thing I noticed. When I went to run PPO fine tuning for a different model I ended up getting this error if I used this version.

from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=6,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=1,      # Reduce further for stability
        gradient_accumulation_steps=1,      # Effective batch size = 4
        warmup_ratio=0.1,                   # Double the warmup
        num_train_epochs=1,                 
        learning_rate=6e-5,                 # HALF the current rate
        max_grad_norm=0.5,                  # Tighter gradient clipping
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=25,                   
        optim="adamw_8bit",
        weight_decay=0.005,                 # Reduce weight decay
        lr_scheduler_type="cosine",         
        seed=3407,
        output_dir="outputs",
        save_steps=10000                     # More frequent saves
    ),
)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 5
      2 from trl import SFTTrainer
      3 from unsloth import is_bfloat16_supported
----> 5 trainer = SFTTrainer(
      6     model=model,
      7     tokenizer=tokenizer,
      8     train_dataset=dataset,
      9     dataset_text_field="text",
     10     max_seq_length=max_seq_length,
     11     dataset_num_proc=6,
     12     packing=False,
     13     args=TrainingArguments(
     14         per_device_train_batch_size=1,      # Reduce further for stability
     15         gradient_accumulation_steps=1,      # Effective batch size = 4
     16         warmup_ratio=0.1,                   # Double the warmup
     17         num_train_epochs=1,                 
     18         learning_rate=6e-5,                 # HALF the current rate
     19         max_grad_norm=0.5,                  # Tighter gradient clipping
     20         fp16=not is_bfloat16_supported(),
     21         bf16=is_bfloat16_supported(),
     22         logging_steps=25,                   
     23         optim="adamw_8bit",
     24         weight_decay=0.005,                 # Reduce weight decay
     25         lr_scheduler_type="cosine",         
     26         seed=3407,
     27         output_dir="outputs",
     28         save_steps=10000                     # More frequent saves
     29     ),
     30 )

File ~/Desktop/Coding_Projects/Unsloth/.venv/lib/python3.12/site-packages/unsloth/trainer.py:209, in _backwards_compatible_trainer.<locals>.new_init(self, *args, **kwargs)
    207     kwargs["args"] = config
    208 pass
--> 209 original_init(self, *args, **kwargs)

File ~/Desktop/Coding_Projects/Unsloth/unsloth_compiled_cache/UnslothSFTTrainer.py:1005, in UnslothSFTTrainer.__init__(self, model, args, data_collator, train_dataset, eval_dataset, processing_class, compute_loss_func, compute_metrics, callbacks, optimizer_cls_and_kwargs, preprocess_logits_for_metrics, peft_config, formatting_func, **kwargs)
    987 def __init__(
    988     self,
    989     model,
   (...)   1002     **kwargs
   1003 ):
   1004     if args is None: args = UnslothSFTConfig()
-> 1005     self.use_vision = args.use_vision
   1006     use_bf16 = getattr(args, 'bf16', False)
   1007     if type(use_bf16) is not bool: use_bf16 = False

AttributeError: 'TrainingArguments' object has no attribute 'use_vision'

And when I add use_vision and set it to none, I get an error about an unexpected argument, which is use_vision. Mind you this is also all on a Language only model using FastLanguageModel

@Larry-Gan
Copy link
Copy Markdown

@GAD-cell Does this support G3emma 3n? It seems it's not compatible with the newest transformers version that's needed for 3n:
GAD-cell/vlm-grpo#13

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.