
⚠️ Add warning guidelines and update codebase to follow best practices #2350


Merged
merged 32 commits into main from warnings on Nov 29, 2024
Changes from 7 commits
Commits (32)
350a66a  Add guidelines for working with warnings in the codebase (qgallouedec, Nov 11, 2024)
53ba260  Remove unnecessary warnings and improve code initialization (qgallouedec, Nov 11, 2024)
c74421e  Fix warnings and improve accuracy calculation (qgallouedec, Nov 11, 2024)
5f21517  Add rich library dependency for text formatting (qgallouedec, Nov 12, 2024)
dbcac07  Update LoRA weight loading warning message (qgallouedec, Nov 12, 2024)
c273e1f  Fix logging and import issues in AlignPropConfig (qgallouedec, Nov 12, 2024)
0a970e7  Fix warnings and improve code readability (qgallouedec, Nov 12, 2024)
ef398f0  Remove unused import statements (qgallouedec, Nov 12, 2024)
f471589  Refactor CPOTrainer class in cpo_trainer.py (qgallouedec, Nov 12, 2024)
753a1b8  Merge branch 'main' into warnings (qgallouedec, Nov 12, 2024)
9a0866f  Remove unnecessary warnings and raise ValueError for missing model (qgallouedec, Nov 12, 2024)
3ac3123  Fix warnings and improve code consistency (qgallouedec, Nov 12, 2024)
7307365  Update CONTRIBUTING.md to clarify the purpose of warnings (qgallouedec, Nov 12, 2024)
6897470  Fix string formatting in DataCollatorForCompletionOnlyLM class (qgallouedec, Nov 12, 2024)
49f5460  Merge branch 'main' into warnings (qgallouedec, Nov 18, 2024)
230a486  Merge branch 'main' into warnings (qgallouedec, Nov 18, 2024)
d68b69d  Merge branch 'main' into warnings (qgallouedec, Nov 20, 2024)
c5f8f13  Merge branch 'main' into warnings (qgallouedec, Nov 22, 2024)
dcdc6ca  Update SimPO loss parameters in CPOTrainer (qgallouedec, Nov 24, 2024)
7ecd6ec  Merge branch 'warnings' of https://github.com/huggingface/trl into wa… (qgallouedec, Nov 24, 2024)
efdea04  Merge branch 'main' into warnings (qgallouedec, Nov 24, 2024)
2a21c37  Merge branch 'main' into warnings (qgallouedec, Nov 25, 2024)
7d40d43  Fix warnings and remove unnecessary code in ConstantLengthDataset class (qgallouedec, Nov 25, 2024)
28533ed  Clarify warning guidelines (qgallouedec, Nov 25, 2024)
b786461  Rewrite the entire section (qgallouedec, Nov 25, 2024)
a1b3843  Fix capitalization in CONTRIBUTING.md (qgallouedec, Nov 25, 2024)
4ab7fa8  Fix formatting in CONTRIBUTING.md (qgallouedec, Nov 25, 2024)
fceb7dc  Merge branch 'main' into warnings (qgallouedec, Nov 26, 2024)
0e714c0  Merge branch 'main' into warnings (qgallouedec, Nov 26, 2024)
02f5b17  Merge branch 'main' into warnings (qgallouedec, Nov 26, 2024)
f08c609  Merge branch 'main' into warnings (qgallouedec, Nov 26, 2024)
0cf434a  Merge branch 'main' into warnings (qgallouedec, Nov 29, 2024)
15 changes: 15 additions & 0 deletions CONTRIBUTING.md
@@ -283,3 +283,18 @@ The deprecation and removal schedule is based on each feature's usage and impact
- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.

These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.

### Working with Warnings

When working with warnings in the codebase, please follow these principles:

1. **Warnings must be actionable**
Every warning raised should be actionable and provide clear guidance on how to address or resolve the underlying issue. For example, a deprecation warning should include an alternative method or function that can be used.

2. **Warnings should not indicate normal behavior**
Warnings should not be triggered for issues that do not affect functionality. They must not appear for the expected, intended operation of the software. Warnings should highlight potential problems, not reflect normal behavior.

3. **Use the appropriate warning type**
Use the appropriate warning types (e.g., `DeprecationWarning`, `UserWarning`) for features that are being phased out or for behaviors that should be addressed in future versions.

By following these guidelines, we ensure that warnings remain meaningful, actionable, and contribute to the long-term health of the project.
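
To make these principles concrete, here is a minimal sketch of a warning that satisfies all three; it is illustrative only, and `old_fn`/`new_fn` are hypothetical names rather than functions from the codebase:

```python
import warnings


def new_fn(x):
    return 2 * x


def old_fn(x):
    # Actionable: the message names a concrete replacement (principle 1).
    # Appropriate type: DeprecationWarning marks a phase-out (principle 3).
    warnings.warn(
        "`old_fn` is deprecated and will be removed in a future release. "
        "Use `new_fn` instead.",
        DeprecationWarning,
    )
    return new_fn(x)
```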
3 changes: 2 additions & 1 deletion examples/scripts/reward_modeling.py
@@ -99,7 +99,8 @@
if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS":
warnings.warn(
"You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs."
" Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT."
" Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.",
UserWarning,
)

##############
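
An explicit category, as added here, is what lets downstream users target this warning precisely; a quick sketch of how a caller might filter or escalate it with the standard library:

```python
import warnings

# Silence only this class of warnings instead of suppressing everything:
warnings.filterwarnings("ignore", category=UserWarning)

# Or escalate it to an error in CI to catch misconfigurations early:
warnings.filterwarnings("error", category=UserWarning)
```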
3 changes: 2 additions & 1 deletion trl/core.py
@@ -296,7 +296,8 @@ def randn_tensor(
warnings.warn(
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
f" slightly speed up this function by passing a generator that was created on the {device} device."
f" slightly speed up this function by passing a generator that was created on the {device} device.",
UserWarning,
)
elif gen_device_type != device.type and gen_device_type == "cuda":
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
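
On the caller side, the fix this message points to is to build the generator directly on the target device; a minimal sketch, assuming a CUDA device is available:

```python
import torch

device = torch.device("cuda")

# Creating the generator on the target device avoids the warning
# (and the extra cpu-to-device copy) described above.
generator = torch.Generator(device=device)
generator.manual_seed(0)
sample = torch.randn((4, 8), generator=generator, device=device)
```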
19 changes: 12 additions & 7 deletions trl/environment/base_environment.py
@@ -13,7 +13,6 @@
# limitations under the License.

import re
import warnings
from typing import Optional

import torch
@@ -145,8 +144,10 @@ def show_text(self, show_legend=False):
Print the text history.
"""
if not is_rich_available():
warnings.warn("install rich to display text")
return
raise ImportError(
"The `rich` library is required to display text with formatting. "
"Install it using `pip install rich`."
)

text = Text(self.text)
text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0])
@@ -167,8 +168,10 @@ def show_tokens(self, tokenizer, show_legend=False):
Print the history tokens.
"""
if not is_rich_available():
warnings.warn("install rich to display tokens")
return
raise ImportError(
"The `rich` library is required to display tokens with formatting. "
"Install it using `pip install rich`."
)

text = Text()
prompt_end = self.token_spans[0][1]
@@ -192,8 +195,10 @@ def show_colour_legend(self):
Print the colour legend.
"""
if not is_rich_available():
warnings.warn("install rich to display colour legend")
return
raise ImportError(
"The `rich` library is required to display colour legends with formatting. "
"Install it using `pip install rich`."
)
text = Text("\n\n(Colour Legend: ")
text.append("Prompt", style=self.prompt_color)
text.append("|")
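
For reference, availability helpers such as `is_rich_available` are commonly implemented with `importlib`; this is a sketch of the general pattern, and the actual trl helper may differ:

```python
import importlib.util


def is_rich_available() -> bool:
    # True if the `rich` package can be imported in this environment.
    return importlib.util.find_spec("rich") is not None
```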
5 changes: 3 additions & 2 deletions trl/models/modeling_sd_base.py
@@ -808,8 +808,9 @@ def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str
except OSError:
if use_lora:
warnings.warn(
"If you are aware that the pretrained model has no lora weights to it, ignore this message. "
"Otherwise please check the if `pytorch_lora_weights.safetensors` exists in the model folder."
"Trying to load LoRA weights but no LoRA weights found. Set `use_lora=False` or check that "
"`pytorch_lora_weights.safetensors` exists in the model folder.",
UserWarning,
)

self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config)
8 changes: 0 additions & 8 deletions trl/trainer/alignprop_config.py
@@ -139,14 +139,6 @@ def to_dict(self):
return flatten_dict(output_dict)

def __post_init__(self):
if self.log_with not in ["wandb", "tensorboard"]:
warnings.warn(
"Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."
)

if self.log_with == "wandb" and not is_torchvision_available():
warnings.warn("Wandb image logging requires torchvision to be installed")

if self.train_use_8bit_adam and not is_bitsandbytes_available():
raise ImportError(
"You need to install bitsandbytes to use 8bit Adam. "
29 changes: 7 additions & 22 deletions trl/trainer/bco_trainer.py
@@ -394,17 +394,9 @@ def __init__(
ref_model_init_kwargs["torch_dtype"] = torch_dtype

if isinstance(model, str):
warnings.warn(
"You passed a model_id to the BCOTrainer. This will automatically create an "
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

if isinstance(ref_model, str):
warnings.warn(
"You passed a ref model_id to the BCOTrainer. This will automatically create an "
"`AutoModelForCausalLM`"
)
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
@@ -573,8 +565,11 @@ def make_inputs_require_grad(module, input, output):
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
warnings.warn(
"You set `output_router_logits` to True in the model config, but `router_aux_loss_coef` is set to 0.0,"
" meaning the auxiliary loss will not be used."
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
"loss.",
UserWarning,
)

# Underlying Distribution Matching argument
@@ -705,7 +700,6 @@ def make_inputs_require_grad(module, input, output):
self.running = RunningMoments(accelerator=self.accelerator)

if self.embedding_func is None:
warnings.warn("You did not pass `embedding_func` underlying distribution matching feature is deactivated.")
return

chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
@@ -875,16 +869,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
return
# when loading optimizer and scheduler from checkpoint, also load the running delta object.
running_file = os.path.join(checkpoint, RUNNING_NAME)
if not os.path.isfile(running_file):
warnings.warn(f"Missing file {running_file}. Will use a new running delta value for BCO loss calculation")
else:
if os.path.isfile(running_file):
self.running = RunningMoments.load_from_json(self.accelerator, running_file)

if self.match_underlying_distribution:
clf_file = os.path.join(checkpoint, CLF_NAME)
if not os.path.isfile(running_file):
warnings.warn(f"Missing file {clf_file}. Will use a new UDM classifier for BCO loss calculation")
else:
if os.path.isfile(clf_file):
self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))

@contextmanager
@@ -1350,11 +1340,6 @@ def prediction_step(
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
):
if not self.use_dpo_data_collator:
warnings.warn(
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
if ignore_keys is None:
if hasattr(model, "config"):
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
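
The changes in this file follow one rule of thumb: raise for unusable input, warn (with a suggested fix) for suspicious but valid configuration, and stay silent for normal behavior. A minimal sketch of that decision pattern, with a hypothetical helper that is not part of the PR:

```python
import warnings


def validate_trainer_inputs(model, router_aux_loss_coef: float) -> None:
    # Unusable input: fail fast with an exception rather than a warning.
    if model is None:
        raise ValueError("You must pass a `model` to the trainer.")
    # Valid but likely unintended: warn, and say how to fix it.
    if router_aux_loss_coef == 0.0:
        warnings.warn(
            "`router_aux_loss_coef` is 0.0, so the auxiliary loss will not "
            "be used. Set it to a value greater than 0.0 to enable it.",
            UserWarning,
        )
    # Normal behavior: no warning at all.
```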
7 changes: 5 additions & 2 deletions trl/trainer/utils.py
@@ -759,9 +759,12 @@ def compute_accuracy(eval_pred) -> Dict[str, float]:
predictions, labels = eval_pred
# Here, predictions is rewards_chosen and rewards_rejected.
# We want to see how much of the time rewards_chosen > rewards_rejected.
if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0:
equal_predictions_count = np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum()
if equal_predictions_count > 0:
warnings.warn(
f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading."
f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions for "
"both options are equal. As a consequence the accuracy can be misleading.",
UserWarning,
Comment on lines -774 to +767

@qgallouedec (Member, Author) commented on Nov 11, 2024:

> Warnings must be actionable.
> Warnings should not indicate normal behavior.
> Use the appropriate warning type.

This warning remains not actionable. Any idea to solve this?

Reply from a Member:

What about promoting this to logger.info?

)
predictions = np.argmax(predictions, axis=1)
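
A sketch of what the reviewer's `logger.info` suggestion could look like; this is a hypothetical rewrite, not part of the PR:

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)


def report_equal_predictions(predictions: np.ndarray) -> None:
    # Informational rather than actionable, so logging fits better
    # than warnings.warn here.
    equal_predictions_count = int((predictions[:, 0] == predictions[:, 1]).sum())
    if equal_predictions_count > 0:
        logger.info(
            f"There are {equal_predictions_count} out of {len(predictions)} "
            "instances where the predictions for both options are equal; "
            "the reported accuracy may be misleading."
        )
```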
