⚠️ Add warning guidelines and update codebase to follow best practices #2350
base: main
```diff
@@ -394,17 +394,9 @@ def __init__(
             ref_model_init_kwargs["torch_dtype"] = torch_dtype

         if isinstance(model, str):
-            warnings.warn(
-                "You passed a model_id to the BCOTrainer. This will automatically create an "
-                "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
-            )
             model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

         if isinstance(ref_model, str):
-            warnings.warn(
-                "You passed a ref model_id to the BCOTrainer. This will automatically create an "
-                "`AutoModelForCausalLM`"
-            )
             ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

         # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
```

> **Review comment on lines -404 to -407:** Correct → No warning
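The rationale behind the removal above can be sketched in plain Python: converting a model id string into a model object is supported, documented behavior, so per the PR's guideline (warnings must be actionable) no warning should fire. `load_model` and `FakeModel` below are hypothetical stand-ins for the trainer logic and `AutoModelForCausalLM`; this is an illustration of the guideline, not TRL's actual code.

```python
import warnings


class FakeModel:
    """Hypothetical stand-in for AutoModelForCausalLM."""

    def __init__(self, name, **kwargs):
        self.name = name


def load_model(model, model_init_kwargs=None):
    # Sketch of the post-PR behavior: a string is silently upgraded to a
    # model object. Passing a model id is expected, supported input --
    # there is nothing for the user to act on, hence no warning.
    if isinstance(model, str):
        # Before this PR, a warnings.warn(...) call sat here.
        model = FakeModel(model, **(model_init_kwargs or {}))
    return model


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    m = load_model("gpt2")

# The conversion happened and no warning was emitted.
```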
```diff
@@ -759,9 +759,12 @@ def compute_accuracy(eval_pred) -> Dict[str, float]:
     predictions, labels = eval_pred
     # Here, predictions is rewards_chosen and rewards_rejected.
     # We want to see how much of the time rewards_chosen > rewards_rejected.
-    if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0:
+    equal_predictions_count = np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum()
+    if equal_predictions_count > 0:
         warnings.warn(
-            f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading."
+            f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions for "
+            "both options are equal. As a consequence the accuracy can be misleading.",
+            UserWarning,
         )
     predictions = np.argmax(predictions, axis=1)

```

> **Review comment on lines -774 to +767:** Warnings must be actionable. This warning remains not actionable. Any idea to solve this?
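The updated hunk can be exercised as a self-contained sketch. The body mirrors the `+` lines of the diff; the `{"accuracy": ...}` return value and the sample inputs are assumptions added here so the function runs standalone, not part of the diff itself.

```python
import warnings

import numpy as np


def compute_accuracy(eval_pred):
    # Sketch of the updated metric: ties between rewards_chosen and
    # rewards_rejected now emit a single, explicit UserWarning, because
    # np.argmax silently resolves a tie in favor of column 0 ("chosen").
    predictions, labels = eval_pred
    equal_predictions_count = np.array(
        predictions[:, 0] == predictions[:, 1], dtype=float
    ).sum()
    if equal_predictions_count > 0:
        warnings.warn(
            f"There are {equal_predictions_count} out of {len(predictions[:, 0])} "
            "instances where the predictions for both options are equal. "
            "As a consequence the accuracy can be misleading.",
            UserWarning,
        )
    predictions = np.argmax(predictions, axis=1)
    # Assumed return shape for this sketch.
    accuracy = np.array(predictions == labels, dtype=float).mean().item()
    return {"accuracy": accuracy}


# One tied row out of three, so exactly one UserWarning is expected.
preds = np.array([[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]])
labels = np.array([0, 0, 1])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    metrics = compute_accuracy((preds, labels))
```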