Skip to content

Commit 751fb1d

Browse files
authored
🏛️ Improve DPO configuration documentation structure (#2561)
* better structure dpo config * fix tests * fix regex * add contributing guidelines
1 parent edabe0a commit 751fb1d

File tree

4 files changed

+335
-221
lines changed

4 files changed

+335
-221
lines changed

Diff for: CONTRIBUTING.md

+100-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ pip install -e .[dev]
3838

3939
## Fixing outstanding issues
4040

41-
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#create-a-pull-request) and open a Pull Request!
41+
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request!
4242

4343
## Submitting a bug-related issue or feature request
4444

@@ -257,6 +257,105 @@ That's how `make test` is implemented (without the `pip install` line)!
257257
You can specify a smaller set of tests to test only the feature
258258
you're working on.
259259
260+
### Writing documentation
261+
262+
High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project.
263+
264+
To illustrate what good documentation looks like, here’s an example of a well-documented function:
265+
266+
````python
267+
def replicate_str(string: str, n: int, sep: str = " ") -> str:
268+
r"""
269+
Replicate a string `n` times with a separator.
270+
271+
Args:
272+
string (`str`):
273+
String to replicate.
274+
n (`int`):
275+
Number of times to replicate the string.
276+
sep (`str`, *optional*, defaults to `" "`):
277+
Separator to use between each replication.
278+
279+
Returns:
280+
`str`: The replicated string.
281+
282+
Examples:
283+
```python
284+
>>> replicate_str("hello", 3)
285+
"hello hello hello"
286+
>>> replicate_str("hello", 3, sep=", ")
287+
"hello, hello, hello"
288+
```
289+
"""
290+
return sep.join([string] * n)
291+
````
292+
293+
* **Line Wrapping:** Applied a consistent line wrap at column 120 to improve readability.
294+
* **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate")
295+
* **Type Annotations:**
296+
* Always include type definitions, indicating if a parameter is optional and specifying the default value.
297+
* Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
298+
E.g., for arguments that can't be `None` and aren't required:
299+
300+
```python
301+
foo (`int`, *optional*, defaults to `4`):
302+
```
303+
304+
For arguments that can be `None` and are required:
305+
306+
```python
307+
foo (`Optional[int]`):
308+
```
309+
310+
for arguments that can be `None` and aren't required:
311+
312+
```python
313+
foo (`Optional[int]`, *optional*, defaults to `None`):
314+
```
315+
316+
* **String Defaults:**
317+
* Ensured that default string values are wrapped in double quotes:
318+
319+
```python
320+
defaults to `"foo"`
321+
```
322+
323+
* **Dictionary Typing:**
324+
* Replaced generic `dict` type hints with more explicit `dict[str, Any]` to clarify expected key-value pairs.
325+
* **Default Value Formatting:**
326+
* Consistently surrounded default values with backticks for improved formatting:
327+
328+
```python
329+
defaults to `4`
330+
```
331+
332+
* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability.
333+
334+
```python
335+
def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]:
336+
r"""
337+
Calculates basic statistics for a given dataset.
338+
339+
Args:
340+
> Data inputs
341+
342+
data (`list[float]`):
343+
A list of numerical values to analyze.
344+
345+
> Configuration parameters
346+
347+
precision (`int`, *optional*, defaults to `2`):
348+
Number of decimal places to round the results.
349+
include_variance (`bool`, *optional*, defaults to `False`):
350+
Whether to include the variance of the dataset in the results.
351+
352+
Returns:
353+
`dict[str, float]`:
354+
A dictionary containing calculated statistics such as mean, median, and optionally variance.
355+
"""
356+
...
357+
```
358+
260359
### Deprecation and Backward Compatibility
261360

262361
Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.

Diff for: tests/test_dpo_trainer.py

+12-19
Original file line numberDiff line numberDiff line change
@@ -457,9 +457,10 @@ def test_dpo_trainer_padding_token_is_none(self):
457457

458458
with self.assertRaisesRegex(
459459
ValueError,
460-
expected_regex=r"Can't find `pad_token_id` in the `processing_class`. "
461-
r"Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\) "
462-
r"before instantiating the trainer.",
460+
expected_regex=r"`padding_value` is not specified in `DPOConfig`, and `pad_token_id` is missing in "
461+
r"the `processing_class`. Please either set the `padding_value` argument in `DPOConfig`, or set "
462+
r"`tokenizer.pad_token` \(e.g., `tokenizer.pad_token = tokenizer.eos_token`\) before instantiating "
463+
r"the trainer.",
463464
):
464465
trainer = DPOTrainer(
465466
model=self.model,
@@ -490,24 +491,16 @@ def test_dpo_trainer_w_dataset_num_proc(self):
490491
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
491492

492493
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
493-
tokenizer.pad_token = None
494494

495-
with self.assertRaisesRegex(
496-
ValueError,
497-
expected_regex=r"Can't find `pad_token_id` in the `processing_class`. "
498-
r"Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\) "
499-
r"before instantiating the trainer.",
500-
):
501-
trainer = DPOTrainer(
502-
model=self.model,
503-
ref_model=None,
504-
args=training_args,
505-
processing_class=tokenizer,
506-
train_dataset=dummy_dataset["train"],
507-
eval_dataset=dummy_dataset["test"],
508-
)
495+
trainer = DPOTrainer(
496+
model=self.model,
497+
args=training_args,
498+
processing_class=tokenizer,
499+
train_dataset=dummy_dataset["train"],
500+
eval_dataset=dummy_dataset["test"],
501+
)
509502

510-
trainer.train()
503+
trainer.train()
511504

512505
def test_tr_dpo_trainer(self):
513506
with tempfile.TemporaryDirectory() as tmp_dir:

0 commit comments

Comments
 (0)