
Commit c4539fb
wip
tastelikefeet committed Dec 4, 2024
1 parent 738e9fd commit c4539fb
Showing 6 changed files with 17 additions and 25 deletions.
File renamed without changes.
File renamed without changes.
24 changes: 11 additions & 13 deletions docs/source/index.rst
@@ -10,26 +10,24 @@ Swift DOCUMENTATION
:caption: Get Started

GetStarted/SWIFT安装.md
GetStarted/界面训练推理.md
GetStarted/推送模型.md
GetStarted/使用tuners.md
GetStarted/ResTuning.md
GetStarted/SCEdit.md
GetStarted/在SWIFT内使用PEFT.md
GetStarted/快速开始.md
GetStarted/界面使用.md

.. toctree::
:maxdepth: 2
:caption: Instruction

Instruction/index.md
Instruction/LLM微调文档.md
Instruction/LLM推理文档.md
Instruction/LLM评测文档.md
Instruction/LLM量化与导出文档.md
Instruction/LLM实验文档.md
Instruction/预训练及微调.md
Instruction/人类对齐.md
Instruction/推理和部署.md
Instruction/评测.md
Instruction/导出.md
Instruction/命令行参数.md
Instruction/NPU支持.md
Instruction/使用tuners.md
Instruction/支持的模型和数据集.md
Instruction/自定义与拓展.md
Instruction/推送模型.md
Instruction/ReleaseNote3.0.md
Instruction/常见问题整理.md


8 changes: 4 additions & 4 deletions swift/llm/argument/base_args/model_args.py
@@ -102,12 +102,12 @@ def _init_model_info(self) -> torch.dtype:
         self.model_dir = self.model_info.model_dir
         self.model_type = self.model_info.model_type
         if self.rope_scaling is not None and isinstance(self.rope_scaling, str):
-            assert self.max_length is not None
-            max_model_len_no_scaling = self.model_info.max_model_len_no_scaling
+            assert self.max_length is not None, 'Use max_model_len together with rope_scaling'
             rope_scaling = self.model_info.rope_scaling or {}
+            max_model_len = self.model_info.max_model_len
             rope_scaling_factor = 1.0
-            if max_model_len_no_scaling:
-                rope_scaling_factor = max(float(math.ceil(self.max_length / max_model_len_no_scaling)), 1.0)
+            if max_model_len:
+                rope_scaling_factor = max(float(math.ceil(self.max_length / max_model_len)), 1.0)
             if rope_scaling:
                 rope_scaling_factor = max(rope_scaling.get('factor', -1), rope_scaling_factor)
             rope_scaling['type'] = self.rope_scaling
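
In effect, the scaling factor is now derived from the config-reported max_model_len rather than the separately tracked unscaled length. A minimal standalone sketch of the new computation; the helper name compute_rope_scaling_factor is hypothetical, since the real logic lives inline in _init_model_info above:

import math

def compute_rope_scaling_factor(max_length: int, max_model_len: int,
                                config_factor: float = -1.0) -> float:
    # Hypothetical standalone form of the inline logic shown in the diff.
    # Scale up only when the requested context exceeds the model's native
    # window; never scale below 1.0.
    factor = 1.0
    if max_model_len:
        factor = max(float(math.ceil(max_length / max_model_len)), 1.0)
    # A factor already present in the config's rope_scaling dict wins if larger.
    return max(config_factor, factor)

# e.g. a model with a 4096-token native window asked for max_length=16384:
print(compute_rope_scaling_factor(16384, 4096))  # 4.0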
3 changes: 1 addition & 2 deletions swift/llm/model/register.py
@@ -371,7 +371,6 @@ def _get_model_info(model_dir: str, model_type: Optional[str], quantization_conf
     quant_info = HfConfigFactory.get_quant_info(config_dict) or {}
     torch_dtype = HfConfigFactory.get_torch_dtype(config_dict, quant_info)
     max_model_len = HfConfigFactory.get_max_model_len(config_dict)
-    max_model_len_no_scaling = HfConfigFactory.get_max_model_len(config_dict, ignore_rope_scaling=True)
     rope_scaling = HfConfigFactory.get_config_attr(config_dict, 'rope_scaling')

if model_type is None:
@@ -382,7 +381,7 @@ elif len(model_types) == 1:
     elif len(model_types) == 1:
         model_type = model_types[0]
 
-    res = ModelInfo(model_type, model_dir, torch_dtype, max_model_len, max_model_len_no_scaling,
+    res = ModelInfo(model_type, model_dir, torch_dtype, max_model_len,
                     quant_info.get('quant_method'), quant_info.get('quant_bits'), rope_scaling)
     return res
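
With max_model_len_no_scaling gone, ModelInfo takes one fewer positional argument. A hypothetical instantiation mirroring the updated call above, with illustrative values only:

import torch

from swift.llm.model.utils import ModelInfo  # module path per the diff below

# Argument order follows the ModelInfo(...) call in _get_model_info;
# every value here is made up for illustration.
info = ModelInfo(
    'qwen2',                            # model_type
    '/path/to/model',                   # model_dir
    torch.bfloat16,                     # torch_dtype
    32768,                              # max_model_len
    None,                               # quant_method
    None,                               # quant_bits
    {'type': 'linear', 'factor': 2.0},  # rope_scaling
)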

7 changes: 1 addition & 6 deletions swift/llm/model/utils.py
@@ -56,7 +56,6 @@ class ModelInfo:
     model_dir: str
     torch_dtype: torch.dtype
     max_model_len: int
-    max_model_len_no_scaling: int
     quant_method: Literal['gptq', 'awq', 'bnb', 'aqlm', 'hqq', None]
     quant_bits: int
     rope_scaling: Dict[str, Any]
@@ -123,7 +122,7 @@ def set_config_attr(config: Union[PretrainedConfig, Dict[str, Any]], attr_name:
         setattr(config, attr_name, value)
 
     @staticmethod
-    def get_max_model_len(config: Union[PretrainedConfig, Dict[str, Any]], ignore_rope_scaling=False) -> Optional[int]:
+    def get_max_model_len(config: Union[PretrainedConfig, Dict[str, Any]]) -> Optional[int]:
         """Get the max length supported by the model"""
         INF = int(1e9)
         max_model_len = INF
@@ -145,10 +144,6 @@ def get_max_model_len(config: Union[PretrainedConfig, Dict[str, Any]], ignore_rope_scaling=False) -> Optional[int]:
             max_model_len = min(max_model_len, max_len_key)
         if max_model_len == INF:
             max_model_len = None
-
-        if (not ignore_rope_scaling and max_model_len and getattr(config, 'rope_scaling', None)
-                and config.rope_scaling.get('factor')):
-            max_model_len = max(int(max_model_len * config.rope_scaling.get('factor')), max_model_len)
         return max_model_len
 
     @staticmethod
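
With the rope-scaling branch removed, get_max_model_len now simply reports the smallest context-length hint found in the config, with no inflation by the scaling factor. A simplified standalone sketch of that scan; the candidate key names are assumptions, and the real implementation also handles PretrainedConfig objects and nested sub-configs:

from typing import Any, Dict, Optional

def get_max_model_len(config_dict: Dict[str, Any]) -> Optional[int]:
    """Get the max length supported by the model (simplified sketch)."""
    INF = int(1e9)
    max_model_len = INF
    # Assumed config attributes that can encode the native context window.
    for key in ('max_position_embeddings', 'seq_length', 'model_max_length'):
        max_len_key = config_dict.get(key)
        if isinstance(max_len_key, int):
            max_model_len = min(max_model_len, max_len_key)
    return None if max_model_len == INF else max_model_len

print(get_max_model_len({'max_position_embeddings': 4096}))  # 4096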
