
Conversation

@bukejiyu (Collaborator) commented Oct 22, 2025

Motivation

The original PT weight loading logic caused a several-fold regression in H2D copy performance. This PR changes the PT loading logic to improve model loading performance.

Modifications

Overview of changes

For all models except ViT / Resampler, the PT weight loading logic is adjusted as follows (see the sketch below):
Old logic: load weight -> transpose -> param.copy_(weight)
New logic: create the parameter aligned with the checkpoint layout -> param.copy_(weight) -> after_loading_fn performs the transpose
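
For illustration, a minimal Python sketch of the two flows; the names load_param_old, load_param_new and after_loading_fn are placeholders for this sketch, not the actual FastDeploy APIs:

import paddle

def load_param_old(param: paddle.Tensor, loaded_weight: paddle.Tensor) -> None:
    # old flow: transpose the host tensor first, then copy to device;
    # copying from the transposed (non-contiguous) source is what slows down H2D
    loaded_weight = loaded_weight.transpose([1, 0])
    param.copy_(loaded_weight, False)

def load_param_new(param: paddle.Tensor, loaded_weight: paddle.Tensor, after_loading_fn) -> None:
    # new flow: the parameter was created with the checkpoint's layout,
    # so this is a plain contiguous H2D copy
    param.copy_(loaded_weight, False)
    # the transpose happens once, after loading, on the device-side parameter
    after_loading_fn(param)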

Dependent Paddle framework PRs

  1. PR: H2D copy optimization
  2. PR: fix for MMAP not being released correctly
  3. Fix for contiguous CPU copy

Changes

Modified how PT models from HF are loaded.
Refactored so far:
1. bf16
2. weight-only
3. DeepGEMM FP8 online quantization
4. Triton backend: fp8 / Wfp8Afp8MoEMethod / Triton weight-only

Usage or Command

Accuracy Tests

CI/CE

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot (bot) commented Oct 22, 2025

Thanks for your contribution!

@bukejiyu bukejiyu changed the title from "[loader] Refactor PT model loading" to "[Loader] Refactor PT model loading" on Oct 22, 2025
@bukejiyu bukejiyu force-pushed the v1_loader_speed_up branch 2 times, most recently from a526f60 to 233ca08 on October 28, 2025 at 14:05
@bukejiyu bukejiyu force-pushed the v1_loader_speed_up branch 3 times, most recently from 6db5f59 to 5b3e605 on November 5, 2025 at 09:04
@YuanRisheng YuanRisheng requested a review from Copilot November 5, 2025 11:50
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR introduces comprehensive support for loading PyTorch model weights in FastDeploy, with a focus on handling weight transposition between PyTorch and Paddle formats and optimizing the model loading pipeline. Key improvements include migrating to safetensors 0.7.0rc0 for direct GPU tensor loading, refactoring weight processing logic, and introducing a new post-loading processing phase.

  • Migrated safetensors dependency to version 0.7.0rc0 with direct framework integration
  • Implemented process_final_after_loading for post-loading weight transformations
  • Refactored weight transpose logic into centralized process_weight_transpose and h2d_copy functions
  • Updated quantization methods to handle PyTorch vs Paddle weight format differences

Reviewed Changes

Copilot reviewed 31 out of 32 changed files in this pull request and generated 8 comments.

Summary per file:
  requirements.txt — Added safetensors 0.7.0rc0 for improved tensor loading
  fastdeploy/model_executor/utils.py — Added weight transpose utilities, h2d_copy, and a multi-config context manager
  fastdeploy/model_executor/load_weight_utils.py — Updated the safetensors loader to use the Paddle framework and modified the cache logic
  fastdeploy/model_executor/layers/quantization/*.py — Refactored quantization methods to handle format-specific weight shapes and transpose logic
  fastdeploy/model_executor/layers/moe/*.py — Updated MoE layers with format-aware weight handling and transpose operations
  fastdeploy/model_executor/layers/linear.py — Added transpose processing to linear layers for PyTorch format compatibility
  fastdeploy/model_executor/layers/lm_head.py — Implemented weight transpose in lm_head for tied embeddings
  fastdeploy/model_executor/models/*.py — Updated all model load_weights methods to call process_final_after_loading
  fastdeploy/engine/*.py — Set the OMP_NUM_THREADS environment variable to 3

weight_loader(param, loaded_weight)
model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)
process_final_after_loading(self, self.fd_config)
Copilot AI commented on Nov 5, 2025:

The call to process_final_after_loading is placed inside the loop, causing it to be executed for every weight loaded. This should be moved outside the loop (after line 215) to run only once after all weights are loaded, matching the pattern used in other model files like qwen3.py and qwen2.py.
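
A minimal sketch of the suggested placement (the loop body is abridged from the quoted snippet; the iterator and dict names are illustrative):

import re

for model_param_name, loaded_weight in weights_iterator:        # illustrative names
    param = params_dict[model_param_name]
    weight_loader(param, loaded_weight)
    model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name)
    process_weights_after_loading_fn(model_sublayer_name, param)

# one final pass over all layers, only after every weight has been loaded
process_final_after_loading(self, self.fd_config)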


Comment on lines 313 to 314
process_final_after_loading(self, self.fd_config)

Copilot AI commented on Nov 5, 2025:

The call to process_final_after_loading is placed inside the loop, causing it to be executed for every weight loaded. This should be moved outside the loop (after line 312) to run only once after all weights are loaded, matching the pattern used in other model files like qwen3.py and qwen2.py.


Comment on lines 431 to 432
process_final_after_loading(self, self.fd_config)

Copilot AI commented on Nov 5, 2025:

The call to process_final_after_loading is placed inside the loop, causing it to be executed for every weight loaded. This should be moved outside the loop (after line 430) to run only once after all weights are loaded, matching the pattern used in other model files like qwen3.py and qwen2.py.


weight_cache_dir = None
enable_cache = False
if envs.FD_ENABLE_MODEL_LOAD_CACHE:
if envs.FD_ENABLE_MODEL_LOAD_CACHE and fd_config.quant_config is not None:
Copilot AI commented on Nov 5, 2025:

Adding the condition fd_config.quant_config is not None will disable caching for non-quantized models. If this is an intentional behavior change, it should be documented; if caching should work for all models, the condition should be removed or adjusted.
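
For reference, a small sketch of how the guard reads after the change (abridged; the cache-setup body is elided):

weight_cache_dir = None
enable_cache = False
# caching is now entered only when a quant_config is present,
# i.e. it no longer applies to non-quantized models
if envs.FD_ENABLE_MODEL_LOAD_CACHE and fd_config.quant_config is not None:
    ...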

self.quant_method: Optional[QuantMethodBase] = UnquantizedLinearMethod()

self.bias = None
if self.with_bias:
Copilot AI commented on Nov 5, 2025:

The bias dtype was changed from self._dtype to self.weight_dtype. Ensure that self.weight_dtype is always defined when with_bias=True, as this could cause AttributeError if weight_dtype is not set during initialization.

Suggested change
if self.with_bias:
    # Ensure self.weight_dtype is set before using it
    if not hasattr(self, "weight_dtype") or self.weight_dtype is None:
        self.weight_dtype = self._dtype

"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
"NCCL_ALGO": "Ring",
"FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
"OMP_NUM_THREADS": 3,
Copilot AI commented on Nov 5, 2025:

[nitpick] The hardcoded value of 3 for OMP_NUM_THREADS may not be optimal for all deployment scenarios. Consider making this configurable or documenting why this specific value was chosen.

Suggested change
"OMP_NUM_THREADS": 3,
"OMP_NUM_THREADS": int(os.getenv("OMP_NUM_THREADS", 3)),

"FLAGS_use_append_attn": 1,
"NCCL_ALGO": "Ring",
"FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
"OMP_NUM_THREADS": 3,
Copilot AI commented on Nov 5, 2025:

[nitpick] The hardcoded value of 3 for OMP_NUM_THREADS may not be optimal for all deployment scenarios. Consider making this configurable or documenting why this specific value was chosen.

Suggested change
"OMP_NUM_THREADS": 3,
"OMP_NUM_THREADS": int(os.getenv("OMP_NUM_THREADS", 3)),

else:
# v0 loader or torch model format
weight_shape = layer.weight_shape
weight_scale_inv_shape = weight_scale_inv_shape
Copilot AI commented on Nov 5, 2025:

This assignment assigns a variable to itself.

Suggested change (remove the self-assignment):
weight_scale_inv_shape = weight_scale_inv_shape

Comment on lines -850 to -865
if self.nranks > 0:
if self.with_bias:
# col parallel
_set_var_distributed(self.bias, split_axis=0)
Collaborator:

Why was this removed here?

Collaborator Author:

The bias of the row-parallel (row-split) layer does not need to be sharded.
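
For context, a small sketch of why only the column-parallel bias is distributed (shapes are illustrative, not taken from the code):

in_features, out_features, nranks = 1024, 4096, 4

# column parallel: the output dim is sharded, so the bias is sharded the same way
col_weight_shape = (in_features, out_features // nranks)
col_bias_shape = (out_features // nranks,)

# row parallel: the input dim is sharded and the output stays whole,
# so every rank keeps the full bias and _set_var_distributed is not needed
row_weight_shape = (in_features // nranks, out_features)
row_bias_shape = (out_features,)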

Comment on lines 348 to 350
_process_quantize()
else:
_process_quantize()
Collaborator:

It looks like these three lines can be simplified: drop the else and call _process_quantize() once outside the if.
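
A minimal sketch of the suggested simplification (the surrounding condition is a placeholder):

# before
if some_condition:
    ...                  # branch-specific work
    _process_quantize()
else:
    _process_quantize()

# after: the call runs on both paths, so it can move below the if
if some_condition:
    ...                  # branch-specific work
_process_quantize()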

Collaborator Author:

done

@YuanRisheng (Collaborator) commented:

What is the reason for this change to the PT weight-loading logic, and which step accounts for the performance improvement?

@bukejiyu bukejiyu force-pushed the v1_loader_speed_up branch 3 times, most recently from 1be7aef to 3b3ecf6 on November 6, 2025 at 12:07