Skip to content

Commit

Permalink
Add device_map config to support chatglm2-6b (#734)
Browse files Browse the repository at this point in the history
chatglm-6b和chatglm2-6b的参数命名不一致,本次提交旨在解决chatglm2-6b device_map 创建的问题。在chatglm_auto_configure_device_map 函数中新增了chatglm2-6b device_map 创建的相关代码。 (Translation: The parameter naming of chatglm-6b and chatglm2-6b is inconsistent; this commit resolves the problem of building the device_map for chatglm2-6b by adding chatglm2-6b-specific device_map construction code to the chatglm_auto_configure_device_map function.)
  • Loading branch information
Jingsong-Yan authored Jun 30, 2023
1 parent 51ed739 commit 421ce3d
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions models/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,29 @@ def chatglm_auto_configure_device_map(self, num_gpus: int) -> Dict[str, int]:
# 在调用chat或者stream_chat时,input_ids会被放到model.device上
# 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
# 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
device_map = {f'{layer_prefix}.word_embeddings': 0,

encode = ""
if 'chatglm2' in self.model_name:
device_map = {
f"{layer_prefix}.embedding.word_embeddings": 0,
f"{layer_prefix}.rotary_pos_emb": 0,
f"{layer_prefix}.output_layer": 0,
f"{layer_prefix}.encoder.final_layernorm": 0,
f"base_model.model.output_layer": 0
}
encode = ".encoder"
else:
device_map = {f'{layer_prefix}.word_embeddings': 0,
f'{layer_prefix}.final_layernorm': 0, 'lm_head': 0,
f'base_model.model.lm_head': 0, }

used = 2
gpu_target = 0
for i in range(num_trans_layers):
if used >= per_gpu_layers:
gpu_target += 1
used = 0
assert gpu_target < num_gpus
device_map[f'{layer_prefix}.layers.{i}'] = gpu_target
device_map[f'{layer_prefix}{encode}.layers.{i}'] = gpu_target
used += 1

return device_map
Expand Down

0 comments on commit 421ce3d

Please sign in to comment.