⚠️ Add warning guidelines and update codebase to follow best practices #2350

Open
wants to merge 31 commits into
base: main
Changes from 3 commits
31 commits
350a66a
Add guidelines for working with warnings in the codebase
qgallouedec Nov 11, 2024
53ba260
Remove unnecessary warnings and improve code initialization
qgallouedec Nov 11, 2024
c74421e
Fix warnings and improve accuracy calculation
qgallouedec Nov 11, 2024
5f21517
Add rich library dependency for text formatting
qgallouedec Nov 12, 2024
dbcac07
Update LoRA weight loading warning message
qgallouedec Nov 12, 2024
c273e1f
Fix logging and import issues in AlignPropConfig
qgallouedec Nov 12, 2024
0a970e7
Fix warnings and improve code readability
qgallouedec Nov 12, 2024
ef398f0
Remove unused import statements
qgallouedec Nov 12, 2024
f471589
Refactor CPOTrainer class in cpo_trainer.py
qgallouedec Nov 12, 2024
753a1b8
Merge branch 'main' into warnings
qgallouedec Nov 12, 2024
9a0866f
Remove unnecessary warnings and raise ValueError for missing model
qgallouedec Nov 12, 2024
3ac3123
Fix warnings and improve code consistency
qgallouedec Nov 12, 2024
7307365
Update CONTRIBUTING.md to clarify the purpose of warnings
qgallouedec Nov 12, 2024
6897470
Fix string formatting in DataCollatorForCompletionOnlyLM class
qgallouedec Nov 12, 2024
49f5460
Merge branch 'main' into warnings
qgallouedec Nov 18, 2024
230a486
Merge branch 'main' into warnings
qgallouedec Nov 18, 2024
d68b69d
Merge branch 'main' into warnings
qgallouedec Nov 20, 2024
c5f8f13
Merge branch 'main' into warnings
qgallouedec Nov 22, 2024
dcdc6ca
Update SimPO loss parameters in CPOTrainer
qgallouedec Nov 24, 2024
7ecd6ec
Merge branch 'warnings' of https://github.com/huggingface/trl into wa…
qgallouedec Nov 24, 2024
efdea04
Merge branch 'main' into warnings
qgallouedec Nov 24, 2024
2a21c37
Merge branch 'main' into warnings
qgallouedec Nov 25, 2024
7d40d43
Fix warnings and remove unnecessary code in ConstantLengthDataset class
qgallouedec Nov 25, 2024
28533ed
Clarify warning guidelines
qgallouedec Nov 25, 2024
b786461
Rewrite the entire section
qgallouedec Nov 25, 2024
a1b3843
Fix capitalization in CONTRIBUTING.md
qgallouedec Nov 25, 2024
4ab7fa8
Fix formatting in CONTRIBUTING.md
qgallouedec Nov 25, 2024
fceb7dc
Merge branch 'main' into warnings
qgallouedec Nov 26, 2024
0e714c0
Merge branch 'main' into warnings
qgallouedec Nov 26, 2024
02f5b17
Merge branch 'main' into warnings
qgallouedec Nov 26, 2024
f08c609
Merge branch 'main' into warnings
qgallouedec Nov 26, 2024
15 changes: 15 additions & 0 deletions CONTRIBUTING.md
@@ -283,3 +283,18 @@ The deprecation and removal schedule is based on each feature's usage and impact
- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.

These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.

### Working with Warnings

When working with warnings in the codebase, please follow these principles:

1. **Warnings must be actionable**
Every warning raised should be actionable and provide clear guidance on how to address or resolve the underlying issue. For example, a deprecation warning should include an alternative method or function that can be used.

2. **Warnings should not indicate normal behavior**
Warnings should not be triggered for issues that do not affect functionality. They must not appear for the expected, intended operation of the software. Warnings should highlight potential problems, not reflect normal behavior.

3. **Use the appropriate warning type**
Use the warning category that matches the situation, e.g. `DeprecationWarning` for features that are being phased out, or `UserWarning` for behaviors that the user should address.

By following these guidelines, we ensure that warnings remain meaningful, actionable, and contribute to the long-term health of the project.
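
As a rough illustration of these principles, the sketch below shows a deprecation warning that names its replacement and passes an explicit category. The names `old_generate` and `new_generate` are hypothetical and are not part of TRL.

```python
import warnings


def new_generate(prompt: str) -> str:
    """Hypothetical replacement function (not a TRL API)."""
    return prompt.upper()


def old_generate(prompt: str) -> str:
    """Hypothetical deprecated function (not a TRL API)."""
    warnings.warn(
        # Actionable: the message tells the user exactly what to call instead.
        "`old_generate` is deprecated and will be removed in a future release. "
        "Use `new_generate` instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller's line
    )
    return new_generate(prompt)
```

The warning fires only on the deprecated path, so calling `new_generate` directly stays silent.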
3 changes: 2 additions & 1 deletion examples/scripts/reward_modeling.py
@@ -99,7 +99,8 @@
 if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS":
     warnings.warn(
         "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs."
-        " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT."
+        " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.",
+        UserWarning,
     )

 ##############
3 changes: 2 additions & 1 deletion trl/core.py
@@ -296,7 +296,8 @@ def randn_tensor(
     warnings.warn(
         f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
         f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
-        f" slightly speed up this function by passing a generator that was created on the {device} device."
+        f" slightly speed up this function by passing a generator that was created on the {device} device.",
+        UserWarning,
     )
 elif gen_device_type != device.type and gen_device_type == "cuda":
     raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
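
Note that `warnings.warn` already defaults to `UserWarning` when no category is given, so the two changes above make the category explicit rather than changing behavior. A minimal sketch of how a caller can then filter or escalate by category; `load_on_device` is a hypothetical helper, not a TRL function:

```python
import warnings


def load_on_device(device: str) -> str:
    """Hypothetical helper (not from TRL) that warns with an explicit category."""
    if device != "cpu":
        warnings.warn(
            f"The generator was created on 'cpu' but {device} was requested. "
            f"Create the generator on {device} to avoid the extra copy.",
            UserWarning,
        )
    return device


# Callers can turn this category into an error, for example in a test suite.
with warnings.catch_warnings():
    warnings.simplefilter("error", category=UserWarning)
    try:
        load_on_device("cuda")
    except UserWarning as exc:
        escalated = str(exc)
```

Outside the `catch_warnings` block the previous filters are restored, so the escalation stays local to the caller that asked for it.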
8 changes: 0 additions & 8 deletions trl/trainer/bco_trainer.py
@@ -394,17 +394,9 @@ def __init__(
 ref_model_init_kwargs["torch_dtype"] = torch_dtype

 if isinstance(model, str):
-    warnings.warn(
-        "You passed a model_id to the BCOTrainer. This will automatically create an "
-        "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
-    )
Comment on lines -397 to -400
Member Author

@qgallouedec qgallouedec Nov 11, 2024


Correct → No warning

     model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

 if isinstance(ref_model, str):
-    warnings.warn(
-        "You passed a ref model_id to the BCOTrainer. This will automatically create an "
-        "`AutoModelForCausalLM`"
-    )
Comment on lines -404 to -407
Member Author

@qgallouedec qgallouedec Nov 11, 2024


Correct → No warning

     ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

 # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
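
The removed messages described expected behavior. If such information is still useful, one option (an assumption here, not something this PR does) is to emit it through `logging` at `INFO` level instead of `warnings.warn`. In the sketch below, `resolve_model` is a hypothetical stand-in for the trainer's model-loading step, and the dictionary is a placeholder for the real loading call:

```python
import logging

logger = logging.getLogger(__name__)


def resolve_model(model):
    """Hypothetical stand-in: accept either a model id string or a model object."""
    if isinstance(model, str):
        # Normal, documented behavior: report it at INFO level, not as a warning.
        logger.info("Loading model from id %r", model)
        model = {"name": model}  # placeholder for the real loading call
    return model
```

Users who want the message can enable `INFO` logging; everyone else sees nothing, which keeps warnings reserved for potential problems.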
7 changes: 5 additions & 2 deletions trl/trainer/utils.py
@@ -759,9 +759,12 @@ def compute_accuracy(eval_pred) -> Dict[str, float]:
 predictions, labels = eval_pred
 # Here, predictions is rewards_chosen and rewards_rejected.
 # We want to see how much of the time rewards_chosen > rewards_rejected.
-if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0:
+equal_predictions_count = np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum()
+if equal_predictions_count > 0:
     warnings.warn(
-        f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading."
+        f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions for "
+        "both options are equal. As a consequence the accuracy can be misleading.",
+        UserWarning,
Comment on lines -774 to +767
Member Author

@qgallouedec qgallouedec Nov 11, 2024


Warnings must be actionable.
Warnings should not indicate normal behavior.
Use the appropriate warning type.

This warning is still not actionable. Any ideas on how to solve this?

     )
 predictions = np.argmax(predictions, axis=1)

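
One possible direction for the open question above, sketched under assumptions and not taken from this PR, is to exclude tied rows from the accuracy and say so in the message, so the user knows exactly how the reported number was computed. The function name `accuracy_ignoring_ties` is hypothetical, and the sketch uses plain Python sequences instead of NumPy arrays so that it runs on its own.

```python
import warnings
from typing import Dict, Sequence, Tuple


def accuracy_ignoring_ties(
    predictions: Sequence[Tuple[float, float]], labels: Sequence[int]
) -> Dict[str, float]:
    """Hypothetical variant of `compute_accuracy` that skips tied predictions."""
    # Keep only rows where the two scores differ; ties cannot be ranked.
    kept = [(pair, label) for pair, label in zip(predictions, labels) if pair[0] != pair[1]]
    equal_count = len(predictions) - len(kept)
    if equal_count > 0:
        warnings.warn(
            f"There are {equal_count} out of {len(predictions)} instances where the predictions for "
            "both options are equal. These instances are ignored in the accuracy computation.",
            UserWarning,
        )
    if not kept:
        # Every row was a tie, so no accuracy can be reported.
        return {"accuracy": float("nan")}
    # The index of the higher score plays the role of argmax over the two options.
    correct = sum(int((0 if pair[0] > pair[1] else 1) == label) for pair, label in kept)
    return {"accuracy": correct / len(kept)}
```

Whether describing the handling of ties fully satisfies the "actionable" principle is a judgment call; the sketch only removes the ambiguity about what the accuracy means.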