
310P appears not to support the dist.broadcast operation; how can multi-card be supported? #3118

Open
yezekun opened this issue Feb 7, 2025 · 2 comments

yezekun commented Feb 7, 2025

Motivation

On Ascend 310P, the following environment was set up: torch=2.3.1+cpu, torch-npu=2.3.1.post4.

The LMDeploy and dlinfer code use the DeepLink-org:support_310P and yao-fengchen:support_310P branches, respectively.
When running a multi-card inference task on Ascend 310P with tp=2:

from lmdeploy import pipeline
from lmdeploy import PytorchEngineConfig, GenerationConfig

if __name__ == "__main__":
    pipe = pipeline("/mnt/data/llm/Qwen1.5-7B-Chat/",
                    backend_config=PytorchEngineConfig(
                        tp=2,
                        device_type="ascend",
                        dtype='float16',
                        eager_mode=True,
                        cache_max_entry_count=0.5))
    # question = ["Shanghai is"]
    question = ["Shanghai is the largest city in China. Please introduce it."]
    response = pipe(question, gen_config=GenerationConfig(max_new_tokens=10))
    print(response)

The following problem occurs:

/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/utils/collect_env.py:59: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/latest owner does not match the current owner.
  warnings.warn(f"Warning: The {path} owner does not match the current owner.")
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/utils/collect_env.py:59: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/8.0.0.alpha001/x86_64-linux/ascend_toolkit_install.info owner does not match the current owner.
  warnings.warn(f"Warning: The {path} owner does not match the current owner.")
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/contrib/transfer_to_npu.py:292: ImportWarning:
    *************************************************************************************************************
    The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.npu and torch.nn.Module.npu now..
    The torch.cuda.DoubleTensor is replaced with torch.npu.FloatTensor cause the double type is not supported now..
    The backend in torch.distributed.init_process_group set to hccl now..
    The torch.cuda.* and torch.cuda.amp.* are replaced with torch.npu.* and torch.npu.amp.* now..
    The device parameters have been replaced with npu in the function below:
    torch.logspace, torch.randint, torch.hann_window, torch.rand, torch.full_like, torch.ones_like, torch.rand_like, torch.randperm, torch.arange, torch.frombuffer, torch.normal, torch._empty_per_channel_affine_quantized, torch.empty_strided, torch.empty_like, torch.scalar_tensor, torch.tril_indices, torch.bartlett_window, torch.ones, torch.sparse_coo_tensor, torch.randn, torch.kaiser_window, torch.tensor, torch.triu_indices, torch.as_tensor, torch.zeros, torch.randint_like, torch.full, torch.eye, torch._sparse_csr_tensor_unsafe, torch.empty, torch._sparse_coo_tensor_unsafe, torch.blackman_window, torch.zeros_like, torch.range, torch.sparse_csr_tensor, torch.randn_like, torch.from_file, torch._cudnn_init_dropout_state, torch._empty_affine_quantized, torch.linspace, torch.hamming_window, torch.empty_quantized, torch._pin_memory, torch.autocast, torch.load, torch.Generator, torch.set_default_device, torch.Tensor.new_empty, torch.Tensor.new_empty_strided, torch.Tensor.new_full, torch.Tensor.new_ones, torch.Tensor.new_tensor, torch.Tensor.new_zeros, torch.Tensor.to, torch.nn.Module.to, torch.nn.Module.to_empty
    *************************************************************************************************************

  warnings.warn(msg, ImportWarning)
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/contrib/transfer_to_npu.py:247: RuntimeWarning: torch.jit.script and torch.jit.script_method will be disabled by transfer_to_npu, which currently does not support them, if you need to enable them, please do not use transfer_to_npu.
  warnings.warn(msg, RuntimeWarning)
2025-02-07 20:14:14,095 - lmdeploy - WARNING - transformers.py:22 - LMDeploy requires transformers version: [4.33.0 ~ 4.46.1], but found version: 4.48.0
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/utils/collect_env.py:59: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/latest owner does not match the current owner.
  warnings.warn(f"Warning: The {path} owner does not match the current owner.")
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/utils/collect_env.py:59: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/8.0.0.alpha001/x86_64-linux/ascend_toolkit_install.info owner does not match the current owner.
  warnings.warn(f"Warning: The {path} owner does not match the current owner.")
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/contrib/transfer_to_npu.py:292: ImportWarning:
    *************************************************************************************************************
    The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.npu and torch.nn.Module.npu now..
    The torch.cuda.DoubleTensor is replaced with torch.npu.FloatTensor cause the double type is not supported now..
    The backend in torch.distributed.init_process_group set to hccl now..
    The torch.cuda.* and torch.cuda.amp.* are replaced with torch.npu.* and torch.npu.amp.* now..
    The device parameters have been replaced with npu in the function below:
    torch.logspace, torch.randint, torch.hann_window, torch.rand, torch.full_like, torch.ones_like, torch.rand_like, torch.randperm, torch.arange, torch.frombuffer, torch.normal, torch._empty_per_channel_affine_quantized, torch.empty_strided, torch.empty_like, torch.scalar_tensor, torch.tril_indices, torch.bartlett_window, torch.ones, torch.sparse_coo_tensor, torch.randn, torch.kaiser_window, torch.tensor, torch.triu_indices, torch.as_tensor, torch.zeros, torch.randint_like, torch.full, torch.eye, torch._sparse_csr_tensor_unsafe, torch.empty, torch._sparse_coo_tensor_unsafe, torch.blackman_window, torch.zeros_like, torch.range, torch.sparse_csr_tensor, torch.randn_like, torch.from_file, torch._cudnn_init_dropout_state, torch._empty_affine_quantized, torch.linspace, torch.hamming_window, torch.empty_quantized, torch._pin_memory, torch.autocast, torch.load, torch.Generator, torch.set_default_device, torch.Tensor.new_empty, torch.Tensor.new_empty_strided, torch.Tensor.new_full, torch.Tensor.new_ones, torch.Tensor.new_tensor, torch.Tensor.new_zeros, torch.Tensor.to, torch.nn.Module.to, torch.nn.Module.to_empty
    *************************************************************************************************************

  warnings.warn(msg, ImportWarning)
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/contrib/transfer_to_npu.py:247: RuntimeWarning: torch.jit.script and torch.jit.script_method will be disabled by transfer_to_npu, which currently does not support them, if you need to enable them, please do not use transfer_to_npu.
  warnings.warn(msg, RuntimeWarning)
Loading weights from safetensors:   0%|                                                                                             | 0/4 [00:00<?, ?it/s]Loading weights from safetensors: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.10it/s]
./home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/distributed/distributed_c10d.py:97: UserWarning: HCCL doesn't support gather at the moment. Implemented with allgather instead.
  warnings.warn("HCCL doesn't support gather at the moment. Implemented with allgather instead.")
2025-02-07 20:14:52,637 - lmdeploy - ERROR - model_agent.py:442 - Rank[1] failed.
Traceback (most recent call last):
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 439, in _start_tp_process
    func(rank, *args, **kwargs)
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 385, in _tp_model_loop
    patched_model, cache_engine, _ = _tp_build_model(rank,
                                     ^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 334, in _tp_build_model
    raise e
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 330, in _tp_build_model
    cache_config = _broadcast_config(cache_config)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 298, in _broadcast_config
    dist.gather_object(cache_config, gathered_configs)
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/distributed/distributed_c10d.py", line 177, in _gather_object
    _gather(
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/distributed/distributed_c10d.py", line 99, in _gather
    dist.broadcast_object_list(recv_size_list, dst, group)
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 75, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2661, in broadcast_object_list
    object_tensor = torch.empty(  # type: ignore[call-overload]
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/contrib/transfer_to_npu.py", line 153, in decorated
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
RuntimeError: Trying to create tensor with negative dimension -138608640: [-138608640]
[ERROR] 2025-02-07-20:14:52 (PID:400524, Device:1, RankID:-1) ERR01003 OPS invalid value
2025-02-07 20:15:01,109 - lmdeploy - ERROR - model_agent.py:467 - TP process 0 failed with exitcode 1.
(yzk-lmdeploy) [yzk@devserver-2efb lmdeploy_support_310P]$ [ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 34 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

After we replaced dist.gather_object with dist.all_gather_object, _broadcast_config synchronizes correctly, but the subsequent _broadcast_inputs still fails (the gather_object implementation also performs a broadcast internally). The specific modification is shown below:

    def _broadcast_config_old(cache_config):
        """broadcast cache config, use minimum cache."""
        if rank == 0:
            gathered_configs = [None] * world_size
            dist.gather_object(cache_config, gathered_configs)
            num_gpu_blocks_list = [config.num_gpu_blocks for config in gathered_configs]
            num_cpu_blocks_list = [config.num_cpu_blocks for config in gathered_configs]
            min_num_gpu_blocks = min(num_gpu_blocks_list)
            min_num_cpu_blocks = min(num_cpu_blocks_list)
            cache_config.num_cpu_blocks = min_num_cpu_blocks
            cache_config.num_gpu_blocks = min_num_gpu_blocks
            config_list = [cache_config]
        else:
            gathered_configs = None
            dist.gather_object(cache_config, gathered_configs)
            config_list = [None]
        dist.broadcast_object_list(config_list)
        return config_list[0]

    def _broadcast_config(cache_config):
        """broadcast cache config, use minimum cache."""
        print("===yzk debug config", cache_config)
        gathered_configs = [None] * world_size
        dist.all_gather_object(gathered_configs, cache_config)
        num_gpu_blocks_list = [
            config.num_gpu_blocks for config in gathered_configs
        ]
        num_cpu_blocks_list = [
            config.num_cpu_blocks for config in gathered_configs
        ]
        min_num_gpu_blocks = min(num_gpu_blocks_list)
        min_num_cpu_blocks = min(num_cpu_blocks_list)
        cache_config.num_cpu_blocks = min_num_cpu_blocks
        cache_config.num_gpu_blocks = min_num_gpu_blocks
        config_list = [cache_config]
        return config_list[0]

The result of running it is as follows:

2025-02-07 20:21:37,551 - lmdeploy - WARNING - async_engine.py:625 - GenerationConfig: GenerationConfig(n=1, max_new_tokens=50, do_sample=False, top_p=1.0, top_k=50, min_p=0.0, temperature=0.8, repetition_penalty=1.0, ignore_eos=False, random_seed=None, stop_words=None, bad_words=None, stop_token_ids=[151645], bad_token_ids=None, min_new_tokens=None, skip_special_tokens=True, spaces_between_special_tokens=True, logprobs=None, response_format=None, logits_processors=None, output_logits=None, output_last_hidden_state=None)
2025-02-07 20:21:37,551 - lmdeploy - WARNING - async_engine.py:626 - Since v0.6.0, lmdeploy add `do_sample` in GenerationConfig. It defaults to False, meaning greedy decoding. Please set `do_sample=True` if sampling  decoding is needed
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/utils/storage.py:38: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if self.device.type != 'cpu':
[rank1]:[W AddKernelNpu.cpp:82] Warning: The oprator of add is executed, Currently High Accuracy but Low Performance OP with 64-bit has been used, Please Do Some Cast at Python Functions with 32-bit for Better Performance! (function operator())
.2025-02-07 20:21:38,452 - lmdeploy - ERROR - model_agent.py:458 - Rank[1] failed.
Traceback (most recent call last):
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 455, in _start_tp_process
    func(rank, *args, **kwargs)
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 416, in _tp_model_loop
    model_forward(
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 151, in model_forward
    output = model(**input_dict)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/backends/graph_runner.py", line 24, in __call__
    return self.model(**kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/models/qwen2.py", line 324, in forward
    hidden_states = self.model(
                    ^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/models/qwen2.py", line 264, in forward
    hidden_states, residual = decoder_layer(
                              ^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/models/qwen2.py", line 188, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/models/qwen2.py", line 77, in forward
    query_states, key_states = self.apply_rotary_pos_emb(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/nn/rotary_embedding.py", line 112, in forward
    return self.impl.forward(query, key, cos, sin, inplace)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py", line 21, in forward
    return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py", line 21, in apply_rotary_pos_emb
    ext_ops.apply_rotary_pos_emb(query_states_reshaped,
  File "/home/yzk/dlinfer_support_310P/dlinfer/graph/custom_op.py", line 70, in patched_func
    return func_with_default(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/dlinfer_support_310P/dlinfer/ops/llm.py", line 84, in apply_rotary_pos_emb
    return vendor_ops_registry["apply_rotary_pos_emb"](
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/dlinfer_support_310P/dlinfer/vendor/ascend/torch_npu_ops.py", line 80, in apply_rotary_pos_emb
    return torch.ops.npu.npu_apply_rotary_pos_emb(query, key, cos, sin, "BSND")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/_ops.py", line 854, in __call__
    return self_._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: call aclnnApplyRotaryPosEmb failed, detail:E89999: Inner Error!
E89999: [PID: 408156] 2025-02-07-20:21:38.448.775 op[ApplyRotaryPosEmb], input is empty[FUNC:GetInputParams][FILE:apply_rotary_pos_emb_tiling.cc][LINE:163]
        TraceBack (most recent call last):
       op[ApplyRotaryPosEmb], GetInputParams failed[FUNC:CheckParams][FILE:apply_rotary_pos_emb_tiling.cc][LINE:170]
       op[ApplyRotaryPosEmb], CheckParams return failed.[FUNC:Tiling4ApplyRotaryPosEmb][FILE:apply_rotary_pos_emb_tiling.cc][LINE:447]
       ApplyRotaryPosEmb do tiling failed, ret is -1.
       Check NnopbaseExecutorDoTiling(executor) failed
       Check NnopbaseExecutorTilingAndUpdateBinInfo(executor) failed
       Check NnopbaseRunForWorkspace(*executor, workspaceSize) failed

[ERROR] 2025-02-07-20:21:38 (PID:408156, Device:1, RankID:-1) ERR01100 OPS call acl api failed
[W OpParamMaker.cpp:387] Warning: E90000: [PID: 408156] 2025-02-07-20:21:38.563.236 Compile operator failed, cause: invalid literal for int() with base 10: ''   File "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/tbe/common/repository_manager/interface.py", line 83, in cann_kb_search_get
    return public_interact_get(key)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/tbe/common/repository_manager/route.py", line 400, in public_interact_get
    status, msg_mgr = get_msg_obj()
                      ^^^^^^^^^^^^^
  File "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/tbe/common/repository_manager/route.py", line 322, in get_msg_obj
    return client_get_msg_obj(msg_file_path, file_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/tbe/common/repository_manager/utils/multiprocess_util.py", line 191, in client_get_msg_obj
    secrets_size = int(fp.read(MARK_DIGEST_SIZE).decode())
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

        TraceBack (most recent call last):
        Fail to get tune knowledge for op[GatherV224, GatherV2].[FUNC:UpdateNodeCompileParams][FILE:op_compiler.cc][LINE:668]
        Call OptimizeFusedGraph failed, ret:4294967295, engine_name:AIcoreEngine, graph_name:partition0_rank1_new_sub_graph3[FUNC:OptimizeSubGraph][FILE:graph_optimize.cc][LINE:119]
        subgraph 0 optimize failed[FUNC:OptimizeSubGraphWithMultiThreads][FILE:graph_manager.cc][LINE:814]
        build graph failed, graph id:23, ret:4294967295[FUNC:BuildModelWithGraphId][FILE:ge_generator.cc][LINE:1615]
        [Build][SingleOpModel]call ge interface generator.BuildSingleOpModel failed. ge result = 4294967295[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
        [Build][Op]Fail to build op model[FUNC:ReportInnerError][FILE:log_inner.cpp][LINE:145]
        build op model failed, result = 500002[FUNC:ReportInnerError][FILE:log_inner.cpp][LINE:145]
 (function ExecFunc)
2025-02-07 20:21:46,365 - lmdeploy - ERROR - model_agent.py:483 - TP process 0 failed with exitcode 1.
(yzk-lmdeploy) [yzk@devserver-2efb lmdeploy_support_310P]$ [ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 34 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

We therefore suspected that the 310P does not support broadcast and similar dist operations. Testing torch.distributed's all_reduce and broadcast with standalone Python scripts, we found that all_reduce runs correctly while broadcast does not (test code below). Finally, we asked the Ascend team, who confirmed that the 310P indeed does not support broadcast.

[image]

Does LMDeploy plan to support multi-card on the 310P in the near future? If the 310P does not support the dist.broadcast operation, how should the code be modified to avoid broadcast operations?
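
For reference, a broadcast-free object sync in the spirit of the _broadcast_config change above could in principle look like the minimal sketch below. This is only an illustration: the helper name is ours, it assumes dist.all_gather_object works reliably over HCCL on the 310P, and it has not been verified as a fix for _broadcast_inputs.

import torch.distributed as dist

def sync_object_via_all_gather(obj, src: int = 0):
    """Return rank `src`'s object on every rank, using all_gather_object
    instead of broadcast_object_list (which relies on broadcast)."""
    world_size = dist.get_world_size()
    gathered = [None] * world_size          # one slot per rank
    dist.all_gather_object(gathered, obj)   # every rank contributes its object
    return gathered[src]                    # keep only the source rank's copy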

Testing HCCL with PyTorch

Test all_reduce

import os
import torch
import torch_npu
import torch.distributed as dist

rank_id = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))  # get world_size and rank_id
init_method = 'tcp://127.0.0.1:23456'

torch.npu.set_device(rank_id)

dist.init_process_group(backend='hccl',
                        init_method=init_method,
                        world_size=world_size,
                        rank=rank_id)  # initialize the process group
tensor = torch.ones((8, ), dtype=torch.float16).npu()

print(f"rank_id: {rank_id}, world_size: {world_size}")

print(f"before allreduce: {tensor}")
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # all_reduce operation
print(f"after allreduce: {tensor}")

LOCAL_RANK=0 WORLD_SIZE=2 python /tmp/allreduce_dist_test.py
LOCAL_RANK=1 WORLD_SIZE=2 python /tmp/allreduce_dist_test.py  # in another terminal

Ranks 0 and 1 output the correct result:

before allreduce: tensor([1., 1., 1., 1., 1., 1., 1., 1.], device='npu:0', dtype=torch.float16)
after allreduce: tensor([2., 2., 2., 2., 2., 2., 2., 2.], device='npu:0', dtype=torch.float16)

Test broadcast

import os
import torch
import torch_npu
import torch.distributed as dist

rank_id = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))  # get world_size and rank_id
init_method = 'tcp://127.0.0.1:23456'

torch.npu.set_device(rank_id)

dist.init_process_group(backend='hccl',
                        init_method=init_method,
                        world_size=world_size,
                        rank=rank_id)  # initialize the process group
if rank_id == 0:
    tensor = torch.randn((2, 5), dtype=torch.float16).npu()
else:
    tensor = torch.zeros((2, 5), dtype=torch.float16).npu()

print(f"rank_id: {rank_id}, world_size: {world_size}")

print(f"before broadcast: {tensor}")
dist.broadcast(tensor, src=0)  # broadcast operation
print(f"after broadcast: {tensor}")

LOCAL_RANK=0 WORLD_SIZE=2 python /tmp/broadcast_dist_test.py
LOCAL_RANK=1 WORLD_SIZE=2 python /tmp/broadcast_dist_test.py  # in another terminal

Rank 1 outputs an incorrect result:

.before broadcast: tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], device='npu:1', dtype=torch.float16)
after broadcast: tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], device='npu:1', dtype=torch.float16)
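
If all_gather itself behaves correctly over HCCL on the 310P (only all_reduce was exercised in these scripts, though all_gather_object appeared to work in the _broadcast_config change above, so this is an assumption), a tensor broadcast could in principle be emulated with it. The following is an untested sketch with an illustrative helper name:

import torch
import torch_npu
import torch.distributed as dist

def broadcast_via_all_gather(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Overwrite `tensor` on every rank with rank `src`'s values, without dist.broadcast."""
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor)   # collect every rank's tensor
    tensor.copy_(gathered[src])         # keep only the source rank's data
    return tensor

In the broadcast test above, dist.broadcast(tensor, src=0) would be replaced by broadcast_via_all_gather(tensor, src=0).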

Related resources

No response

Additional context

No response

jinminxi104 (Collaborator) commented:

We have just gotten single-card inference working on the 310P; we will analyze the multi-card issue.


yezekun commented Feb 8, 2025

> We have just gotten single-card inference working on the 310P; we will analyze the multi-card issue.

Great, thank you very much. We look forward to LMDeploy being compatible with a wider range of platforms in the future.
