Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,8 @@ def on_log(self, args, state, control, logs=None, **kwargs):
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, state.global_step)
elif isinstance(v, str):
self.tb_writer.add_text(k, v, state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
Expand Down
20 changes: 18 additions & 2 deletions src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,11 +589,21 @@ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: Tr
class ProgressCallback(TrainerCallback):
"""
A [`TrainerCallback`] that displays the progress of training or evaluation.
You can modify `max_str_len` to control how long strings are truncated when logging.
"""

def __init__(self):
def __init__(self, max_str_len: int = 100):
"""
Initialize the callback with optional max_str_len parameter to control string truncation length.

Args:
max_str_len (`int`):
Maximum length of strings to display in logs.
Longer strings will be truncated with a message.
"""
self.training_bar = None
self.prediction_bar = None
self.max_str_len = max_str_len

def on_train_begin(self, args, state, control, **kwargs):
if state.is_world_process_zero:
Expand Down Expand Up @@ -631,7 +641,13 @@ def on_log(self, args, state, control, logs=None, **kwargs):
# but avoid doing any value pickling.
shallow_logs = {}
for k, v in logs.items():
shallow_logs[k] = v
if isinstance(v, str) and len(v) > self.max_str_len:
shallow_logs[k] = (
f"[String too long to display, length: {len(v)} > {self.max_str_len}. "
"Consider increasing `max_str_len` if needed.]"
)
else:
shallow_logs[k] = v
_ = shallow_logs.pop("total_flos", None)
# round numbers so that it looks better in console
if "epoch" in shallow_logs:
Expand Down