Merged
42 changes: 0 additions & 42 deletions examples/run_finetune.py
@@ -37,26 +37,8 @@
AutoModelForCausalLM,
AutoModelForCausalLMPipe,
AutoTokenizer,
-    DeepseekV2ForCausalLM,
-    DeepseekV2ForCausalLMPipe,
-    DeepseekV3ForCausalLM,
-    DeepseekV3ForCausalLMPipe,
-    Ernie4_5_MoeForCausalLM,
-    Ernie4_5_MoeForCausalLMPipe,
-    Ernie4_5ForCausalLM,
-    Ernie4_5ForCausalLMPipe,
-    Llama3Tokenizer,
-    LlamaForCausalLM,
-    LlamaForCausalLMPipe,
-    LlamaTokenizer,
-    Qwen2ForCausalLM,
-    Qwen2ForCausalLMPipe,
-    Qwen2MoeForCausalLM,
-    Qwen2MoeForCausalLMPipe,
-    Qwen3ForCausalLM,
-    Qwen3ForCausalLMPipe,
-    Qwen3MoeForCausalLM,
-    Qwen3MoeForCausalLMPipe,
)
from paddleformers.transformers.configuration_utils import LlmMetaConfig
from paddleformers.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer
@@ -66,27 +48,6 @@
# Fine-tune Environment Variables to support sharding stage1 overlap optimization.
os.environ["USE_CASUAL_MASK"] = "False"

-flash_mask_support_list = [
-    DeepseekV2ForCausalLM,
-    DeepseekV2ForCausalLMPipe,
-    DeepseekV3ForCausalLM,
-    DeepseekV3ForCausalLMPipe,
-    Ernie4_5ForCausalLM,
-    Ernie4_5ForCausalLMPipe,
-    Ernie4_5_MoeForCausalLM,
-    Ernie4_5_MoeForCausalLMPipe,
-    LlamaForCausalLM,
-    LlamaForCausalLMPipe,
-    Qwen2ForCausalLM,
-    Qwen2ForCausalLMPipe,
-    Qwen2MoeForCausalLM,
-    Qwen2MoeForCausalLMPipe,
-    Qwen3ForCausalLM,
-    Qwen3ForCausalLMPipe,
-    Qwen3MoeForCausalLM,
-    Qwen3MoeForCausalLMPipe,
-]


def main():
parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig))
@@ -192,9 +153,6 @@ def main():
else:
model = model_class.from_config(model_config, dtype=dtype)

@lugimzzz (Collaborator) commented on Sep 18, 2025:

Remove the flash_mask_support_list content and the related imports.
-    if model_args.attn_impl == "flashmask" and not any(isinstance(model, cls) for cls in flash_mask_support_list):
-        raise NotImplementedError(f"{model.__class__} not support flash mask.")

if training_args.do_train and model_args.neftune:
# Inspired by https://github.com/neelsjain/NEFTune
if hasattr(model, "get_input_embeddings"):
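For reference, below is a minimal runnable sketch of the hardcoded capability check this PR deletes, condensed from the removed lines. The `from paddleformers.transformers import (...)` path is assumed from the script's surrounding imports, the allowlist is truncated to the Llama and Qwen2 entries (the deleted list also covered the Deepseek, Ernie, Qwen2-MoE, and Qwen3 families plus their `*Pipe` variants), and `check_flash_mask` is a hypothetical helper, not a function from the repo:

```python
# Minimal sketch of the allowlist check removed by this PR (condensed from
# the deleted lines). The import path is assumed from run_finetune.py's
# other imports; check_flash_mask is a hypothetical helper name.
from paddleformers.transformers import (
    LlamaForCausalLM,
    LlamaForCausalLMPipe,
    Qwen2ForCausalLM,
    Qwen2ForCausalLMPipe,
)

# The deleted list also included the Deepseek-V2/V3, Ernie4.5, Qwen2-MoE,
# Qwen3, and Qwen3-MoE classes, each paired with its *Pipe variant.
flash_mask_support_list = [
    LlamaForCausalLM,
    LlamaForCausalLMPipe,
    Qwen2ForCausalLM,
    Qwen2ForCausalLMPipe,
]


def check_flash_mask(model, attn_impl):
    # Fail fast if "flashmask" attention is requested for a model class
    # that is not on the allowlist.
    if attn_impl == "flashmask" and not any(
        isinstance(model, cls) for cls in flash_mask_support_list
    ):
        raise NotImplementedError(f"{model.__class__} does not support flash mask.")
```

With the check gone, an unsupported model/`attn_impl` combination is no longer rejected up front in the example script; presumably the model implementations themselves surface the error instead.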