Hi @yash-srivastava19 thanks for this PR, but this is not how we should fix that since ideally we should catch that either by checking that the received type is a So a more suitable fix should be the following:
model_init_kwargs["torch_dtype"] = (
    model_init_kwargs["torch_dtype"]
    if model_init_kwargs["torch_dtype"] in ["auto", None]
    or isinstance(model_init_kwargs["torch_dtype"], torch.dtype)
    else getattr(torch, model_init_kwargs["torch_dtype"])
)
Anyway, I'll let the authors chime in with their thoughts and ideas about a potential fix! Thanks anyway 🤗 |
younesbelkada
left a comment
Thanks a lot for this ! I second what @alvarobartt said above, we can change this fix to something like:
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index e739b2d..80e11ad 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -159,11 +159,13 @@ class SFTTrainer(Trainer):
raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.")
else:
model_init_kwargs = args.model_init_kwargs
- model_init_kwargs["torch_dtype"] = (
- model_init_kwargs["torch_dtype"]
- if model_init_kwargs["torch_dtype"] in ["auto", None]
- else getattr(torch, model_init_kwargs["torch_dtype"])
- )
+ torch_dtype = model_init_kwargs["torch_dtype"]
+
+ # Convert to `torch.dtype` if an str is passed
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
+ torch_dtype = getattr(torch, torch_dtype)
+
+ model_init_kwargs["torch_dtype"] = torch_dtype
if infinite is not None:
warnings.warn(
And it worked fine on my end! Would you be happy to apply these changes instead in this PR?
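For readers without the repository at hand, the conversion proposed in the diff can be sketched as a small standalone helper. This is only an illustration: `normalize_torch_dtype`, `StubDType` and `stub_torch` are hypothetical names, and the stub objects stand in for `torch.dtype` and the `torch` module so the snippet runs without PyTorch installed.

```python
from types import SimpleNamespace


class StubDType:
    """Hypothetical stand-in for `torch.dtype`, so the sketch runs without PyTorch."""

    def __init__(self, name):
        self.name = name


# Hypothetical stand-in for the `torch` module namespace (assumption, not the real module).
stub_torch = SimpleNamespace(
    dtype=StubDType,
    bfloat16=StubDType("bfloat16"),
    float16=StubDType("float16"),
)


def normalize_torch_dtype(value, namespace=stub_torch):
    """Return `value` unchanged unless it is a dtype name passed as a string."""
    # Same condition as in the diff: only non-"auto" strings are looked up.
    if isinstance(value, str) and value != "auto":
        return getattr(namespace, value)
    # None, "auto", and already-resolved dtype objects pass through untouched.
    return value


model_init_kwargs = {"torch_dtype": "bfloat16"}
model_init_kwargs["torch_dtype"] = normalize_torch_dtype(model_init_kwargs["torch_dtype"])
```

With the real library, passing `torch` as `namespace` should behave the same way. The point of the check is that a value that is already a dtype object never reaches `getattr`, which only accepts a string as the attribute name.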
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Yes, that is a much better approach. Agreed |
|
Was the JSON encoding error rectified as well, or does it persist even after the fix? |
|
Thanks ! that's another issue we can fix in a follow up PR ! |
|
Hi here @yash-srivastava19 friendly ping to check about the status of this PR 👍🏻 Is it something you are still happy / comfortable to work with? Or would you prefer us to take over instead? Just let us know, thanks 🤗 |
|
Hi here @yash-srivastava19 thanks for the effort, we'll be closing this PR in favour of #1807, and you've been included as a contributor there 🤗 Thanks a lot for the effort! |
#1751 mentioned that the TRL CLI is not correctly capturing the torch_dtype. I thought the issue was urgent, so I quickly patched in a hacky fix, which at least lets the SFT Trainer initialize.
Original Issue :
On running the following command :
The error was that trl sft does not identify it as a string when calling
getattr(torch, model_init_kwargs["torch_dtype"]). The fix keeps the pipeline from breaking at this stage. Although it is a hacky fix, I'm willing to work on it further :)
The next error comes from the transformers library, which isn't able to serialize the dtype object (screenshot attached):
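A failure of this general kind can be reproduced with the standard library alone. The sketch below is an assumption-laden illustration, not the transformers code path: `StubDType` is a hypothetical stand-in for `torch.dtype`, and converting the value to a string is just one generic workaround, not necessarily the fix the library would adopt.

```python
import json


class StubDType:
    """Hypothetical stand-in for `torch.dtype`; the json module has no encoder for it."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return f"torch.{self.name}"


# A config dict holding a dtype-like object instead of a plain string.
config = {"torch_dtype": StubDType("bfloat16")}

try:
    json.dumps(config)
    error_message = None
except TypeError as exc:
    # json.dumps raises TypeError for objects it does not know how to encode.
    error_message = str(exc)

# Generic workaround: fall back to str() for anything json cannot encode natively.
serialized = json.dumps(config, default=str)
```

Here `error_message` records why the plain `json.dumps` call fails, while `serialized` holds the JSON produced once the object is turned into a string.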