Commit d3d5a7f

delete unused code
DesmonDay committed Dec 6, 2024
1 parent b92df93 commit d3d5a7f
Showing 1 changed file with 11 additions and 64 deletions.
75 changes: 11 additions & 64 deletions llm/run_embedding.py
@@ -16,7 +16,7 @@
 import sys
 
 import paddle
-from utils.argument import EmbeddingArgument, GenerateArgument
+from utils.argument import EmbeddingArgument
 
 from paddlenlp.data import DataCollatorForEmbedding
 from paddlenlp.datasets import EmbeddingIterableDataset, load_dataset
@@ -39,15 +39,14 @@
 
 
 def main():
-    parser = PdArgumentParser((GenerateArgument, ModelConfig, DataConfig, SFTConfig, EmbeddingArgument))
+    parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig, EmbeddingArgument))
     if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
-        gen_args, model_args, data_args, training_args, embedding_args = parser.parse_json_file_and_cmd_lines()
+        model_args, data_args, training_args, embedding_args = parser.parse_json_file_and_cmd_lines()
     else:
-        gen_args, model_args, data_args, training_args, embedding_args = parser.parse_args_into_dataclasses()
+        model_args, data_args, training_args, embedding_args = parser.parse_args_into_dataclasses()
 
     training_args.print_config(model_args, "Model")
     training_args.print_config(data_args, "Data")
-    training_args.print_config(gen_args, "Generation")
 
     # Setup GPU & distributed training
     paddle.set_device(training_args.device)
@@ -57,12 +56,8 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
     )
 
-    if (
-        training_args.pipeline_parallel_degree > 1
-        or training_args.tensor_parallel_degree > 1
-        or training_args.sharding_parallel_degree > 1
-    ):
-        raise ValueError("Now embedding training only supports in data parallel mode.")
+    if training_args.pipeline_parallel_degree > 1:
+        raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.")
 
     # Detecting last checkpoint.
     last_checkpoint = None
@@ -263,11 +258,6 @@ def main():
     else:
         padding = True
 
-    if training_args.pipeline_parallel_degree > 1:
-        raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.")
-    else:
-        metrics = compute_metrics
-
     data_collator_fn = DataCollatorForEmbedding(
         tokenizer=tokenizer,
         max_query_len=embedding_args.max_query_len,
@@ -284,7 +274,7 @@ def main():
         train_dataset=train_ds,
         eval_dataset=dev_ds,
         tokenizer=tokenizer,
-        compute_metrics=metrics,
+        compute_metrics=compute_metrics,
         data_collator=data_collator_fn,
     )
     trainable_parameters = [p for p in model.parameters() if not p.stop_gradient]
@@ -298,53 +288,10 @@ def main():
         elif last_checkpoint is not None:
             checkpoint = last_checkpoint
         train_result = trainer.train(resume_from_checkpoint=checkpoint)
-        if training_args.benchmark:
-            total_effective_tokens = (
-                sum([len(i["input_ids"]) for i in trainer.train_dataset]) * train_result.metrics["progress_or_epoch"]
-            )
-            effective_tokens_per_second = total_effective_tokens / train_result.metrics["train_runtime"]
-            logger.info(f"Effective_Tokens_per_second: {effective_tokens_per_second} ")
-            logger.info("Benchmark done.")
-        else:
-            if model_args.save_to_aistudio:
-                kwargs = {}
-                if model_args.aistudio_token is not None:
-                    kwargs["token"] = model_args.aistudio_token
-                # PEFT Model only save PEFT parameters, if pretrained model obtains from aistudio
-                if model_args.from_aistudio and (model_args.lora or model_args.prefix_tuning):
-                    kwargs["base_model"] = model_args.model_name_or_path
-                else:
-                    trainer.tokenizer.save_to_aistudio(
-                        repo_id=model_args.aistudio_repo_id,
-                        private=model_args.aistudio_repo_private,
-                        license=model_args.aistudio_repo_license,
-                        exist_ok=True,
-                        **kwargs,
-                    )
-                trainer.model.save_to_aistudio(
-                    repo_id=model_args.aistudio_repo_id,
-                    private=model_args.aistudio_repo_private,
-                    license=model_args.aistudio_repo_license,
-                    merge_tensor_parallel=training_args.tensor_parallel_degree > 1,
-                    exist_ok=True,
-                    **kwargs,
-                )
-
-            if not training_args.autotuner_benchmark:
-                trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
-                trainer.log_metrics("train", train_result.metrics)
-                trainer.save_metrics("train", train_result.metrics)
-                trainer.save_state()
-
-    # Evaluation test set
-    if training_args.do_predict:
-        test_ds = load_dataset(
-            "json",
-            data_files=os.path.join(data_args.dataset_name_or_path, "test.json"),
-            lazy=data_args.lazy,
-        )[0]
-        eval_result = trainer.predict(test_ds).metrics
-        trainer.log_metrics("test", eval_result)
+        trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
 
     # Evaluation dev set
     if training_args.do_eval:
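
For context, the argument-parsing path that remains after this commit boils down to the short sketch below. It is a minimal reconstruction, not code from the repository: the import paths for PdArgumentParser, ModelConfig, DataConfig, and SFTConfig are assumptions (only the EmbeddingArgument import and the parser calls are visible in the hunks above), and the JSON config name is hypothetical.

# Minimal sketch of the post-commit argument flow; assumed imports are marked.
import sys

from utils.argument import EmbeddingArgument  # same import as in the diff

from paddlenlp.trainer import PdArgumentParser  # assumed import path
from paddlenlp.trl import DataConfig, ModelConfig, SFTConfig  # assumed import path

# With GenerateArgument removed, each parse call now yields exactly four dataclasses.
parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig, EmbeddingArgument))
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
    # e.g. `python run_embedding.py embedding.json` (hypothetical config name)
    model_args, data_args, training_args, embedding_args = parser.parse_json_file_and_cmd_lines()
else:
    model_args, data_args, training_args, embedding_args = parser.parse_args_into_dataclasses()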