
Commit

update trainer
DesmonDay committed Dec 5, 2024
1 parent a71783b commit 759d832
Showing 3 changed files with 11 additions and 7 deletions.
7 changes: 6 additions & 1 deletion llm/run_embedding.py
@@ -155,6 +155,11 @@ def main():
     if tokenizer.chat_template is not None:
         data_args.eval_with_do_generation = False

+    if training_args.do_eval:
+        logger.warning("Warning: 'do_eval' is set to True, but will be set to False for Embedding training currently.")
+        training_args.do_eval = False
+        training_args.evaluation_strategy = "no"
+
     if data_args.dataset_name_or_path is None:
         raise ValueError(f"Please specific dataset name or path (got {data_args.dataset_name_or_path})")
     elif os.path.exists(os.path.join(data_args.dataset_name_or_path, "train.json")) or os.path.exists(
@@ -259,7 +264,7 @@ def main():
         padding = True

     if training_args.pipeline_parallel_degree > 1:
-        metrics = None
+        raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.")
     else:
         metrics = compute_metrics

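Taken together, the two run_embedding.py hunks make unsupported configurations fail loudly instead of being silently ignored. The sketch below recaps the resulting behavior in one place; configure_embedding_run is a hypothetical helper written only for this illustration, not a function from the script:

def configure_embedding_run(training_args, data_args, logger, compute_metrics):
    """Illustrative recap of the post-commit checks in main(); not a verbatim excerpt."""
    if training_args.do_eval:
        # Evaluation is not implemented for embedding training, so it is forced off.
        logger.warning("Warning: 'do_eval' is set to True, but will be set to False for Embedding training currently.")
        training_args.do_eval = False
        training_args.evaluation_strategy = "no"

    if training_args.pipeline_parallel_degree > 1:
        # Before this commit, this branch merely set metrics = None; now it fails fast.
        raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.")

    # Metrics are only used when pipeline parallelism is off.
    return compute_metrics
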
6 changes: 3 additions & 3 deletions paddlenlp/transformers/qwen2/modeling.py
@@ -1696,13 +1696,13 @@ def forward(
         q_reps = nn.functional.normalize(q_reps, axis=-1)
         p_reps = nn.functional.normalize(p_reps, axis=-1)

-        if return_encode:
-            return q_reps, p_reps
-
         if self.embedding_negatives_cross_device:
             q_reps = self._dist_gather_tensor(q_reps)
             p_reps = self._dist_gather_tensor(p_reps)

+        if return_encode:
+            return q_reps, p_reps
+
         loss = self.in_batch_negative_loss(q_reps, p_reps)
         return loss

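The modeling.py change moves the return_encode early-exit below the cross-device gather, so callers that request raw encodings now receive representations collected from all ranks, the same tensors that feed in_batch_negative_loss. For context, a cross-device gather helper of this kind commonly looks like the sketch below; this is an assumption about the general pattern, not the verbatim _dist_gather_tensor from this file:

import paddle
import paddle.distributed as dist


def dist_gather_tensor(t):
    """Gather a tensor from every rank and concatenate along the batch axis (illustrative)."""
    if t is None:
        return None
    gathered = []
    dist.all_gather(gathered, t)            # fills `gathered` with one tensor per rank
    gathered[dist.get_rank()] = t           # keep the local tensor so gradients flow on this rank
    return paddle.concat(gathered, axis=0)  # larger pool of in-batch negatives
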
5 changes: 2 additions & 3 deletions paddlenlp/trl/embedding_trainer.py
@@ -28,12 +28,11 @@


 class EmbeddingTrainer(Trainer):
-    def __init__(self, model_args, use_gradient_cache=False, **kwargs):
+    def __init__(self, model_args, **kwargs):
         super().__init__(**kwargs)

         self.model_args = model_args
         self.embedding_negatives_cross_device = model_args.embedding_negatives_cross_device
-        self.use_gradient_cache = use_gradient_cache
         self.accum_data = []
         self.accum_freq = 0
         self.accum_q_features = []
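
After this signature change, call sites construct the trainer without a use_gradient_cache flag; whether the cached-gradient path runs is decided solely by gradient_accumulation_steps (see the training_step hunk below). A hedged construction sketch, with the usual Trainer keyword arguments assumed from the surrounding script:

# Hypothetical call site; the exact keyword arguments depend on run_embedding.py.
trainer = EmbeddingTrainer(
    model_args=model_args,
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train()
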
@@ -168,7 +167,7 @@ def training_step(
         if self.args.pipeline_parallel_degree > 1:
             raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.")

-        if self.args.gradient_accumulation_steps == 1 or not self.use_gradient_cache:
+        if self.args.gradient_accumulation_steps == 1:
             return super().training_step(model, inputs)
         else:
             self.forward_no_grad(model, inputs)
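With the flag removed, gradient caching becomes the only accumulation path: gradient_accumulation_steps == 1 falls through to the stock Trainer.training_step, while anything larger goes through forward_no_grad and the accumulation buffers set up in __init__. The sketch below illustrates the general gradient-cache technique that this path implements; it is a simplified stand-in under those assumptions, not EmbeddingTrainer's actual code:

import paddle


def grad_cache_step(encoder, micro_batches, loss_fn):
    """Minimal sketch of gradient caching for contrastive training (illustrative only)."""
    # 1) Encode every micro-batch without building a graph, caching the embeddings.
    with paddle.no_grad():
        cached = [encoder(b) for b in micro_batches]

    # 2) Re-enable gradients on the cached embeddings and compute the contrastive
    #    loss over the full accumulated batch to obtain d(loss)/d(embedding).
    reps = [c.detach() for c in cached]
    for r in reps:
        r.stop_gradient = False
    loss = loss_fn(paddle.concat(reps, axis=0))
    loss.backward()
    rep_grads = [r.grad for r in reps]

    # 3) Replay each micro-batch with gradients enabled and push the cached
    #    embedding gradients through the encoder via a surrogate objective.
    for batch, grad in zip(micro_batches, rep_grads):
        out = encoder(batch)
        surrogate = paddle.sum(out * grad)
        surrogate.backward()

    return loss
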
