Refactor KTO [1/N]: Modernize model initialization #4783
albertvillanova merged 1 commit into huggingface:main
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice! If it's easier for you, I think it's fine to have a big refactoring PR like in #3906
      # Reference model initialization
      if isinstance(ref_model, str):
    -     ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
    +     ref_model_init_kwargs = args.ref_model_init_kwargs or {}
    +     # Distributed training requires device_map=None
    +     if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
    +         ref_model_init_kwargs["device_map"] = None
    +     ref_model = create_model_from_path(ref_model, **ref_model_init_kwargs)
    + else:
    +     if ref_model is not None and args.ref_model_init_kwargs is not None:
    +         logger.warning(
    +             "You passed `ref_model_init_kwargs` to the KTOConfig, but your ref_model is already instantiated. "
    +             "The `ref_model_init_kwargs` will be ignored."
    +         )
In the refactored GRPO/RLOO/DPO trainers, the ref model is loaded after super().__init__(...), but we can still align later
Thank you for catching this: that's indeed better architecture.
I agree we can align this later, as I was planning to do in phase 3 of the refactoring plan (Reference Model Handling), which is specifically dedicated to ref_model improvements.
Thanks for your suggestion, but I would prefer to keep the PRs small for review quality and risk management. Each PR is independently valuable and can be reviewed in 15-30 minutes. While a big refactoring PR sounds efficient, I think it creates high risk, poor review quality, slower iteration, and harder debugging. Indeed, I am already finding it difficult to resolve conflicts each time I merge the main branch into this other PR: #4700. IMO, small PRs are better for quality, speed, and maintainability. Happy to discuss if you have concerns about the granularity! 😅
Refactor KTO [1/N]: Modernize model initialization.
This PR modernizes KTOTrainer's model initialization to align with SFTTrainer's clean and maintainable patterns. It replaces manual model loading with the `create_model_from_path()` helper function.

Part of:
Problem
Before (KTO):
- `model_init_kwargs` and `ref_model_init_kwargs` (43 lines)
- `getattr(torch, dtype)`
- `AutoModelForCausalLM.from_pretrained`

After (Aligned with SFT):
- `or {}` pattern
- `create_model_from_path` helper
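
For reviewers less familiar with the "After" pattern, here is a minimal, dependency-free sketch of the kwargs handling that the diff in this PR applies to the reference model. It is only an illustration: `resolve_ref_model_init_kwargs` and its parameters are hypothetical names, not TRL APIs, and the defensive `dict(...)` copy is an addition of this sketch (the diff assigns into the dict directly). No model is actually loaded here.

```python
import logging

logger = logging.getLogger(__name__)

# Distributed types for which the diff forces device_map=None.
DISTRIBUTED_TYPES_WITHOUT_DEVICE_MAP = {"MULTI_GPU", "DEEPSPEED"}


def resolve_ref_model_init_kwargs(ref_model, init_kwargs, distributed_type):
    """Return the kwargs a path-based ref model would be loaded with.

    Hypothetical helper for illustration only; returns None when
    `ref_model` is not a string, since nothing would be loaded then.
    """
    if isinstance(ref_model, str):
        # `or {}` maps None to an empty dict, so no separate None check is needed.
        # The copy (an assumption of this sketch) avoids mutating the caller's dict.
        kwargs = dict(init_kwargs or {})
        # Distributed training requires device_map=None.
        if distributed_type in DISTRIBUTED_TYPES_WITHOUT_DEVICE_MAP:
            kwargs["device_map"] = None
        return kwargs
    if ref_model is not None and init_kwargs is not None:
        # Mirrors the warning branch in the diff: kwargs are ignored
        # when the reference model is already instantiated.
        logger.warning(
            "ref_model is already instantiated; ref_model_init_kwargs will be ignored."
        )
    return None


if __name__ == "__main__":
    print(resolve_ref_model_init_kwargs("some/model-id", None, "MULTI_GPU"))
    # prints {'device_map': None}
```

In the actual trainer, the resulting dict would be unpacked into `create_model_from_path(ref_model, **ref_model_init_kwargs)`, as shown in the diff.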