Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

import pytest
import torch
from datasets import Dataset, features
from parameterized import parameterized
Expand Down Expand Up @@ -840,6 +842,68 @@ def test_dpo_lora_force_use_ref(self):
# train the model
trainer.train()

def test_dpo_trainer_torch_dtype(self):
    """Regression test for https://github.com/huggingface/trl/issues/1751.

    A string ``torch_dtype`` in ``model_init_kwargs`` / ``ref_model_init_kwargs``
    must be resolved to the corresponding ``torch.dtype`` on both the policy and
    the reference model; an unparseable value must raise ``ValueError``.
    """
    dataset = self._init_dummy_dataset()

    # Happy path: "float16" strings are converted to torch.float16 for both models.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config = DPOConfig(
            output_dir=tmp_dir,
            per_device_train_batch_size=2,
            max_steps=1,
            model_init_kwargs={"torch_dtype": "float16"},
            ref_model_init_kwargs={"torch_dtype": "float16"},
        )

        trainer = DPOTrainer(
            model=self.model_id,
            ref_model=self.model_id,
            tokenizer=self.tokenizer,
            args=config,
            train_dataset=dataset,
        )
        assert trainer.model.config.torch_dtype == torch.float16
        assert trainer.ref_model.config.torch_dtype == torch.float16

    # Invalid torch_dtype on the policy model: construction must fail loudly.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config = DPOConfig(
            output_dir=tmp_dir,
            per_device_train_batch_size=2,
            max_steps=1,
            model_init_kwargs={"torch_dtype": -1},
        )

        with pytest.raises(
            ValueError,
            match="Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
        ):
            _ = DPOTrainer(
                model=self.model_id,
                tokenizer=self.tokenizer,
                args=config,
                train_dataset=dataset,
            )

    # Invalid torch_dtype on the reference model: same error is expected.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config = DPOConfig(
            output_dir=tmp_dir,
            per_device_train_batch_size=2,
            max_steps=1,
            ref_model_init_kwargs={"torch_dtype": -1},
        )

        with pytest.raises(
            ValueError,
            match="Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
        ):
            _ = DPOTrainer(
                model=self.model_id,
                ref_model=self.model_id,
                tokenizer=self.tokenizer,
                args=config,
                train_dataset=dataset,
            )

def test_dpo_loss_alpha_div_f(self):
model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,3 +1273,45 @@ def __call__(self, examples):
assert trainer.state.log_history[0]["eval_loss"] is not None

assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")

def test_sft_trainer_torch_dtype(self):
    """Regression test for https://github.com/huggingface/trl/issues/1751.

    A ``torch_dtype`` entry in ``model_init_kwargs`` must be honored when it is
    a real ``torch.dtype``, and rejected with ``ValueError`` otherwise.
    """
    # Happy path: a torch.dtype value flows through to the loaded model config.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config = SFTConfig(
            output_dir=tmp_dir,
            eval_strategy="steps",
            max_steps=4,
            eval_steps=2,
            save_steps=2,
            per_device_train_batch_size=2,
            model_init_kwargs={"torch_dtype": torch.float16},
        )
        trainer = SFTTrainer(
            model=self.model_id,
            args=config,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
        )
        assert trainer.model.config.torch_dtype == torch.float16

    # Invalid torch_dtype: trainer construction must fail loudly.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config = SFTConfig(
            output_dir=tmp_dir,
            eval_strategy="steps",
            max_steps=4,
            eval_steps=2,
            save_steps=2,
            per_device_train_batch_size=2,
            model_init_kwargs={"torch_dtype": -1},
        )
        with pytest.raises(
            ValueError,
            match="Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
        ):
            _ = SFTTrainer(
                model=self.model_id,
                args=config,
                train_dataset=self.train_dataset,
                eval_dataset=self.eval_dataset,
            )
35 changes: 25 additions & 10 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,19 @@ def __init__(
)
else:
model_init_kwargs = args.model_init_kwargs
model_init_kwargs["torch_dtype"] = (
model_init_kwargs["torch_dtype"]
if model_init_kwargs["torch_dtype"] in ["auto", None]
else getattr(torch, model_init_kwargs["torch_dtype"])
)

torch_dtype = model_init_kwargs["torch_dtype"]
if torch_dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(torch_dtype, str) and torch_dtype != "auto":
torch_dtype = getattr(torch, torch_dtype)

if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
)

model_init_kwargs["torch_dtype"] = torch_dtype

if ref_model_init_kwargs is not None:
warnings.warn(
Expand All @@ -201,11 +209,18 @@ def __init__(
)
else:
ref_model_init_kwargs = args.ref_model_init_kwargs
ref_model_init_kwargs["torch_dtype"] = (
ref_model_init_kwargs["torch_dtype"]
if ref_model_init_kwargs["torch_dtype"] in ["auto", None]
else getattr(torch, ref_model_init_kwargs["torch_dtype"])
)
torch_dtype = ref_model_init_kwargs["torch_dtype"]
if torch_dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(torch_dtype, str) and torch_dtype != "auto":
torch_dtype = getattr(torch, torch_dtype)

if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
)

ref_model_init_kwargs["torch_dtype"] = torch_dtype

if isinstance(model, str):
warnings.warn(
Expand Down
18 changes: 13 additions & 5 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,19 @@ def __init__(
raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.")
else:
model_init_kwargs = args.model_init_kwargs
model_init_kwargs["torch_dtype"] = (
model_init_kwargs["torch_dtype"]
if model_init_kwargs["torch_dtype"] in ["auto", None]
else getattr(torch, model_init_kwargs["torch_dtype"])
)

torch_dtype = model_init_kwargs["torch_dtype"]
if torch_dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(torch_dtype, str) and torch_dtype != "auto":
torch_dtype = getattr(torch, torch_dtype)

if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
)

model_init_kwargs["torch_dtype"] = torch_dtype

if infinite is not None:
warnings.warn(
Expand Down