Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/experimental/test_trainers_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def test_bco(self):
# self.assertEqual(trainer.args.generate_during_eval, True)
assert trainer.args.is_encoder_decoder
assert trainer.args.precompute_ref_log_probs
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.model_init_kwargs == {"trust_remote_code": True, "device_map": "auto", "dtype": "auto"}
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True, "device_map": "auto", "dtype": "auto"}
assert trainer.args.dataset_num_proc == 4
assert trainer.args.prompt_sample_size == 512
assert trainer.args.min_density_ratio == 0.2
Expand Down
6 changes: 6 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@


if is_peft_available():
import peft
from peft import (
LoraConfig,
PeftModel,
Expand Down Expand Up @@ -537,6 +538,11 @@ def test_train_with_peft_config_prompt_tuning(self, peft_type):
tokenizer_name_or_path="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
)
elif peft_type == "prefix_tuning":
if parse_version(peft.__version__) <= Version("0.17.1"):
pytest.xfail(
"Prefix tuning with device_map='auto' is broken in peft 0.17.1 and below. See "
"https://github.com/huggingface/peft/issues/2821"
)
peft_config = PrefixTuningConfig(
task_type=TaskType.CAUSAL_LM,
num_virtual_tokens=4,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_trainers_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_cpo(self):
assert trainer.args.truncation_mode == "keep_start"
# self.assertEqual(trainer.args.generate_during_eval, True)
assert trainer.args.is_encoder_decoder
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.model_init_kwargs == {"trust_remote_code": True, "device_map": "auto", "dtype": "auto"}
assert trainer.args.dataset_num_proc == 4

def test_dpo(self):
Expand Down Expand Up @@ -189,8 +189,8 @@ def test_kto(self):
# self.assertEqual(trainer.args.generate_during_eval, True)
assert trainer.args.is_encoder_decoder
assert trainer.args.precompute_ref_log_probs
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.model_init_kwargs == {"trust_remote_code": True, "device_map": "auto", "dtype": "auto"}
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True, "device_map": "auto", "dtype": "auto"}
assert trainer.args.dataset_num_proc == 4

@pytest.mark.parametrize("mixtures_coef_list", [False, True])
Expand Down
6 changes: 4 additions & 2 deletions trl/experimental/bco/bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def __init__(
raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
else:
model_init_kwargs = args.model_init_kwargs
dtype = model_init_kwargs.get("dtype")
dtype = model_init_kwargs.get("dtype", "auto")
if dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(dtype, str) and dtype != "auto":
Expand All @@ -403,6 +403,7 @@ def __init__(
f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
)
model_init_kwargs["dtype"] = dtype
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")

if args.ref_model_init_kwargs is None:
ref_model_init_kwargs = {}
Expand All @@ -412,7 +413,7 @@ def __init__(
)
else:
ref_model_init_kwargs = args.ref_model_init_kwargs
dtype = ref_model_init_kwargs.get("dtype")
dtype = ref_model_init_kwargs.get("dtype", "auto")
if dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(dtype, str) and dtype != "auto":
Expand All @@ -422,6 +423,7 @@ def __init__(
f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
)
ref_model_init_kwargs["dtype"] = dtype
ref_model_init_kwargs["device_map"] = ref_model_init_kwargs.get("device_map", "auto")

if isinstance(model, str):
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(
raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
else:
model_init_kwargs = args.model_init_kwargs
dtype = model_init_kwargs.get("dtype")
dtype = model_init_kwargs.get("dtype", "auto")
if dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(dtype, str) and dtype != "auto":
Expand All @@ -170,6 +170,7 @@ def __init__(
f"Invalid `dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
)
model_init_kwargs["dtype"] = dtype
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")

if isinstance(model, str):
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def __init__(
if teacher_model_init_kwargs["dtype"] in ["auto", None]
else getattr(torch, teacher_model_init_kwargs["dtype"])
)
teacher_model_init_kwargs["device_map"] = teacher_model_init_kwargs.get("device_map", "auto")

if isinstance(teacher_model, str):
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
Expand Down
4 changes: 2 additions & 2 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __init__(
model_init_kwargs = args.model_init_kwargs or {}
if isinstance(model, str):
model_id = model
dtype = model_init_kwargs.get("dtype")
dtype = model_init_kwargs.get("dtype", "auto")
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
pass # dtype is already a torch.dtype or "auto" or None
elif isinstance(dtype, str): # it's a str, but not "auto"
Expand All @@ -272,7 +272,7 @@ def __init__(
"Invalid `dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
# Disable caching if gradient checkpointing is enabled (not supported)
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
config = AutoConfig.from_pretrained(model_id)
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **model_init_kwargs)
Expand Down
6 changes: 4 additions & 2 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def __init__(
raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
else:
model_init_kwargs = args.model_init_kwargs
dtype = model_init_kwargs.get("dtype")
dtype = model_init_kwargs.get("dtype", "auto")
if dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(dtype, str) and dtype != "auto":
Expand All @@ -392,6 +392,7 @@ def __init__(
f"Invalid `dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
)
model_init_kwargs["dtype"] = dtype
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")

if args.ref_model_init_kwargs is None:
ref_model_init_kwargs = {}
Expand All @@ -401,7 +402,7 @@ def __init__(
)
else:
ref_model_init_kwargs = args.ref_model_init_kwargs
dtype = ref_model_init_kwargs.get("dtype")
dtype = ref_model_init_kwargs.get("dtype", "auto")
if dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(dtype, str) and dtype != "auto":
Expand All @@ -411,6 +412,7 @@ def __init__(
f"Invalid `dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
)
ref_model_init_kwargs["dtype"] = dtype
ref_model_init_kwargs["device_map"] = ref_model_init_kwargs.get("device_map", "auto")

if isinstance(model, str):
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def __init__(
model_id = model

# Handle dtype in model_init_kwargs
dtype = model_init_kwargs.get("dtype")
dtype = model_init_kwargs.get("dtype", "auto")
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
pass
elif isinstance(dtype, str):
Expand All @@ -315,6 +315,7 @@ def __init__(
"Invalid `dtype` passed to `OnlineDPOConfig`. Expected either 'auto' or a string "
f"representing a `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")

model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
else:
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def __init__(
raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
else:
model_init_kwargs = args.model_init_kwargs
dtype = model_init_kwargs.get("dtype")
dtype = model_init_kwargs.get("dtype", "auto")
if dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(dtype, str) and dtype != "auto":
Expand All @@ -172,6 +172,7 @@ def __init__(
f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
)
model_init_kwargs["dtype"] = dtype
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")

if isinstance(model, str):
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def __init__(
model_init_kwargs = args.model_init_kwargs or {}
if isinstance(model, str):
model_id = model
dtype = model_init_kwargs.get("dtype")
dtype = model_init_kwargs.get("dtype", "auto")
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
pass # dtype is already a torch.dtype or "auto" or None
elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
Expand All @@ -292,6 +292,7 @@ def __init__(
"Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing "
f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
else:
Expand Down
4 changes: 2 additions & 2 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def __init__(
model_init_kwargs = args.model_init_kwargs or {}
if isinstance(model, str):
model_id = model
dtype = model_init_kwargs.get("dtype")
dtype = model_init_kwargs.get("dtype", "auto")
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
pass # dtype is already a torch.dtype or "auto" or None
elif isinstance(dtype, str): # it's a str, but not "auto"
Expand All @@ -267,7 +267,7 @@ def __init__(
"Invalid `dtype` passed to `RLOOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
# Disable caching if gradient checkpointing is enabled (not supported)
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
config = AutoConfig.from_pretrained(model_id)
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **model_init_kwargs)
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1986,7 +1986,7 @@ def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel:
[`~transformers.PreTrainedModel`]:
The instantiated model.
"""
dtype = kwargs.get("dtype")
dtype = kwargs.get("dtype", "auto")
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
pass # dtype is already a torch.dtype or "auto" or None
elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
Expand All @@ -1996,6 +1996,7 @@ def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel:
"Invalid `dtype` passed to the config. Expected either 'auto' or a string representing "
f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
kwargs["device_map"] = kwargs.get("device_map", "auto")
config = AutoConfig.from_pretrained(model_id)
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **kwargs)
Expand Down
Loading