On Ascend 310P we set up the following environment: torch==2.3.1+cpu and torch-npu==2.3.1.post4.
For LMDeploy and dlinfer we use the DeepLink-org:support_310P and yao-fengchen:support_310P branches, respectively. We run a multi-card inference task on the Ascend 310P with tp=2, using the script below:
from lmdeploy import pipeline
from lmdeploy import PytorchEngineConfig, GenerationConfig

if __name__ == "__main__":
    pipe = pipeline("/mnt/data/llm/Qwen1.5-7B-Chat/",
                    backend_config=PytorchEngineConfig(
                        tp=2,
                        device_type="ascend",
                        dtype='float16',
                        eager_mode=True,
                        cache_max_entry_count=0.5))
    # question = ["Shanghai is"]
    question = ["Shanghai is the largest city in China. Please introduce it."]
    response = pipe(question, gen_config=GenerationConfig(max_new_tokens=10))
    print(response)
The run fails with the following errors:
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/utils/collect_env.py:59: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/latest owner does not match the current owner. warnings.warn(f"Warning: The {path} owner does not match the current owner.") /home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/utils/collect_env.py:59: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/8.0.0.alpha001/x86_64-linux/ascend_toolkit_install.info owner does not match the current owner. warnings.warn(f"Warning: The {path} owner does not match the current owner.") /home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/contrib/transfer_to_npu.py:292: ImportWarning: ************************************************************************************************************* The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.npu and torch.nn.Module.npu now.. The torch.cuda.DoubleTensor is replaced with torch.npu.FloatTensor cause the double type is not supported now.. The backend in torch.distributed.init_process_group set to hccl now.. The torch.cuda.* and torch.cuda.amp.* are replaced with torch.npu.* and torch.npu.amp.* now.. The device parameters have been replaced with npu in the function below: torch.logspace, torch.randint, torch.hann_window, torch.rand, torch.full_like, torch.ones_like, torch.rand_like, torch.randperm, torch.arange, torch.frombuffer, torch.normal, torch._empty_per_channel_affine_quantized, torch.empty_strided, torch.empty_like, torch.scalar_tensor, torch.tril_indices, torch.bartlett_window, torch.ones, torch.sparse_coo_tensor, torch.randn, torch.kaiser_window, torch.tensor, torch.triu_indices, torch.as_tensor, torch.zeros, torch.randint_like, torch.full, torch.eye, torch._sparse_csr_tensor_unsafe, torch.empty, torch._sparse_coo_tensor_unsafe, torch.blackman_window, torch.zeros_like, torch.range, torch.sparse_csr_tensor, torch.randn_like, torch.from_file, torch._cudnn_init_dropout_state, torch._empty_affine_quantized, torch.linspace, torch.hamming_window, torch.empty_quantized, torch._pin_memory, torch.autocast, torch.load, torch.Generator, torch.set_default_device, torch.Tensor.new_empty, torch.Tensor.new_empty_strided, torch.Tensor.new_full, torch.Tensor.new_ones, torch.Tensor.new_tensor, torch.Tensor.new_zeros, torch.Tensor.to, torch.nn.Module.to, torch.nn.Module.to_empty ************************************************************************************************************* warnings.warn(msg, ImportWarning) /home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/contrib/transfer_to_npu.py:247: RuntimeWarning: torch.jit.script and torch.jit.script_method will be disabled by transfer_to_npu, which currently does not support them, if you need to enable them, please do not use transfer_to_npu. warnings.warn(msg, RuntimeWarning) 2025-02-07 20:14:14,095 - lmdeploy - WARNING - transformers.py:22 - LMDeploy requires transformers version: [4.33.0 ~ 4.46.1], but found version: 4.48.0 /home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/utils/collect_env.py:59: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/latest owner does not match the current owner. 
warnings.warn(f"Warning: The {path} owner does not match the current owner.") /home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/utils/collect_env.py:59: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/8.0.0.alpha001/x86_64-linux/ascend_toolkit_install.info owner does not match the current owner. warnings.warn(f"Warning: The {path} owner does not match the current owner.") /home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/contrib/transfer_to_npu.py:292: ImportWarning: ************************************************************************************************************* The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.npu and torch.nn.Module.npu now.. The torch.cuda.DoubleTensor is replaced with torch.npu.FloatTensor cause the double type is not supported now.. The backend in torch.distributed.init_process_group set to hccl now.. The torch.cuda.* and torch.cuda.amp.* are replaced with torch.npu.* and torch.npu.amp.* now.. The device parameters have been replaced with npu in the function below: torch.logspace, torch.randint, torch.hann_window, torch.rand, torch.full_like, torch.ones_like, torch.rand_like, torch.randperm, torch.arange, torch.frombuffer, torch.normal, torch._empty_per_channel_affine_quantized, torch.empty_strided, torch.empty_like, torch.scalar_tensor, torch.tril_indices, torch.bartlett_window, torch.ones, torch.sparse_coo_tensor, torch.randn, torch.kaiser_window, torch.tensor, torch.triu_indices, torch.as_tensor, torch.zeros, torch.randint_like, torch.full, torch.eye, torch._sparse_csr_tensor_unsafe, torch.empty, torch._sparse_coo_tensor_unsafe, torch.blackman_window, torch.zeros_like, torch.range, torch.sparse_csr_tensor, torch.randn_like, torch.from_file, torch._cudnn_init_dropout_state, torch._empty_affine_quantized, torch.linspace, torch.hamming_window, torch.empty_quantized, torch._pin_memory, torch.autocast, torch.load, torch.Generator, torch.set_default_device, torch.Tensor.new_empty, torch.Tensor.new_empty_strided, torch.Tensor.new_full, torch.Tensor.new_ones, torch.Tensor.new_tensor, torch.Tensor.new_zeros, torch.Tensor.to, torch.nn.Module.to, torch.nn.Module.to_empty ************************************************************************************************************* warnings.warn(msg, ImportWarning) /home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/contrib/transfer_to_npu.py:247: RuntimeWarning: torch.jit.script and torch.jit.script_method will be disabled by transfer_to_npu, which currently does not support them, if you need to enable them, please do not use transfer_to_npu. warnings.warn(msg, RuntimeWarning) Loading weights from safetensors: 0%| | 0/4 [00:00<?, ?it/s]Loading weights from safetensors: 100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00, 1.10it/s] ./home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/distributed/distributed_c10d.py:97: UserWarning: HCCL doesn't support gather at the moment. Implemented with allgather instead. warnings.warn("HCCL doesn't support gather at the moment. Implemented with allgather instead.") 2025-02-07 20:14:52,637 - lmdeploy - ERROR - model_agent.py:442 - Rank[1] failed. 
Traceback (most recent call last):
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 439, in _start_tp_process
    func(rank, *args, **kwargs)
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 385, in _tp_model_loop
    patched_model, cache_engine, _ = _tp_build_model(rank,
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 334, in _tp_build_model
    raise e
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 330, in _tp_build_model
    cache_config = _broadcast_config(cache_config)
  File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 298, in _broadcast_config
    dist.gather_object(cache_config, gathered_configs)
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/distributed/distributed_c10d.py", line 177, in _gather_object
    _gather(
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/distributed/distributed_c10d.py", line 99, in _gather
    dist.broadcast_object_list(recv_size_list, dst, group)
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 75, in wrapper
    return func(*args, **kwargs)
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2661, in broadcast_object_list
    object_tensor = torch.empty(  # type: ignore[call-overload]
  File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/contrib/transfer_to_npu.py", line 153, in decorated
    return fn(*args, **kwargs)
RuntimeError: Trying to create tensor with negative dimension -138608640: [-138608640]
[ERROR] 2025-02-07-20:14:52 (PID:400524, Device:1, RankID:-1) ERR01003 OPS invalid value
2025-02-07 20:15:01,109 - lmdeploy - ERROR - model_agent.py:467 - TP process 0 failed with exitcode 1.
(yzk-lmdeploy) [yzk@devserver-2efb lmdeploy_support_310P]$ [ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
[ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared!
/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 34 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
After replacing dist.gather_object with dist.all_gather_object, _broadcast_config synchronizes correctly, but the subsequent _broadcast_inputs still fails (the torch_npu gather_object implementation also performs a broadcast internally, as the traceback above shows). The modification is shown below, followed by a sketch of how the same idea might extend to _broadcast_inputs:
def _broadcast_config_old(cache_config):
    """broadcast cache config, use minimum cache."""
    # original implementation: gather_object + broadcast_object_list;
    # the broadcast hidden inside gather_object fails on the 310P
    if rank == 0:
        gathered_configs = [None] * world_size
        dist.gather_object(cache_config, gathered_configs)
        num_gpu_blocks_list = [config.num_gpu_blocks for config in gathered_configs]
        num_cpu_blocks_list = [config.num_cpu_blocks for config in gathered_configs]
        min_num_gpu_blocks = min(num_gpu_blocks_list)
        min_num_cpu_blocks = min(num_cpu_blocks_list)
        cache_config.num_cpu_blocks = min_num_cpu_blocks
        cache_config.num_gpu_blocks = min_num_gpu_blocks
        config_list = [cache_config]
    else:
        gathered_configs = None
        dist.gather_object(cache_config, gathered_configs)
        config_list = [None]
    dist.broadcast_object_list(config_list)
    return config_list[0]


def _broadcast_config(cache_config):
    """broadcast cache config, use minimum cache."""
    # modified implementation: all_gather_object on every rank,
    # so no gather or broadcast is needed
    print("===yzk debug config", cache_config)
    gathered_configs = [None] * world_size
    dist.all_gather_object(gathered_configs, cache_config)
    num_gpu_blocks_list = [config.num_gpu_blocks for config in gathered_configs]
    num_cpu_blocks_list = [config.num_cpu_blocks for config in gathered_configs]
    min_num_gpu_blocks = min(num_gpu_blocks_list)
    min_num_cpu_blocks = min(num_cpu_blocks_list)
    cache_config.num_cpu_blocks = min_num_cpu_blocks
    cache_config.num_gpu_blocks = min_num_gpu_blocks
    config_list = [cache_config]
    return config_list[0]
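We have not tried modifying _broadcast_inputs yet, but in principle the same allgather-based idea could be applied there: replace the broadcast of the input objects with an all_gather_object and keep only rank 0's copy. A rough sketch of such a helper (the helper name and the way it would be wired into _broadcast_inputs are our own assumptions, not existing LMDeploy code):

import torch.distributed as dist

def broadcast_object_list_via_allgather(object_list, src=0):
    # Hypothetical replacement for dist.broadcast_object_list that only
    # relies on all_gather_object, which we verified works on the 310P.
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    # every rank contributes its own copy; afterwards all ranks keep the source's copy
    dist.all_gather_object(gathered, list(object_list))
    object_list[:] = gathered[src]
    return object_list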
With this change, the run produces the following output:
2025-02-07 20:21:37,551 - lmdeploy - WARNING - async_engine.py:625 - GenerationConfig: GenerationConfig(n=1, max_new_tokens=50, do_sample=False, top_p=1.0, top_k=50, min_p=0.0, temperature=0.8, repetition_penalty=1.0, ignore_eos=False, random_seed=None, stop_words=None, bad_words=None, stop_token_ids=[151645], bad_token_ids=None, min_new_tokens=None, skip_special_tokens=True, spaces_between_special_tokens=True, logprobs=None, response_format=None, logits_processors=None, output_logits=None, output_last_hidden_state=None) 2025-02-07 20:21:37,551 - lmdeploy - WARNING - async_engine.py:626 - Since v0.6.0, lmdeploy add `do_sample` in GenerationConfig. It defaults to False, meaning greedy decoding. Please set `do_sample=True` if sampling decoding is needed /home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch_npu/utils/storage.py:38: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage() if self.device.type != 'cpu': [rank1]:[W AddKernelNpu.cpp:82] Warning: The oprator of add is executed, Currently High Accuracy but Low Performance OP with 64-bit has been used, Please Do Some Cast at Python Functions with 32-bit for Better Performance! (function operator()) .2025-02-07 20:21:38,452 - lmdeploy - ERROR - model_agent.py:458 - Rank[1] failed. Traceback (most recent call last): File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 455, in _start_tp_process func(rank, *args, **kwargs) File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 416, in _tp_model_loop model_forward( File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/engine/model_agent.py", line 151, in model_forward output = model(**input_dict) ^^^^^^^^^^^^^^^^^^^ File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/backends/graph_runner.py", line 24, in __call__ return self.model(**kwargs) ^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/models/qwen2.py", line 324, in forward hidden_states = self.model( ^^^^^^^^^^^ File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/models/qwen2.py", line 264, in forward hidden_states, residual = decoder_layer( ^^^^^^^^^^^^^^ File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/models/qwen2.py", line 188, in forward hidden_states = self.self_attn( ^^^^^^^^^^^^^^^ File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/models/qwen2.py", line 77, in forward query_states, key_states = self.apply_rotary_pos_emb( ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/nn/rotary_embedding.py", line 112, in forward return self.impl.forward(query, key, cos, sin, inplace) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/backends/dlinfer/apply_rotary_emb.py", line 21, in forward return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/lmdeploy_support_310P/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py", line 21, in apply_rotary_pos_emb ext_ops.apply_rotary_pos_emb(query_states_reshaped, File "/home/yzk/dlinfer_support_310P/dlinfer/graph/custom_op.py", line 70, in patched_func return func_with_default(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/dlinfer_support_310P/dlinfer/ops/llm.py", line 84, in apply_rotary_pos_emb return vendor_ops_registry["apply_rotary_pos_emb"]( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/dlinfer_support_310P/dlinfer/vendor/ascend/torch_npu_ops.py", line 80, in apply_rotary_pos_emb return torch.ops.npu.npu_apply_rotary_pos_emb(query, key, cos, sin, "BSND") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/site-packages/torch/_ops.py", line 854, in __call__ return self_._op(*args, **(kwargs or {})) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: call aclnnApplyRotaryPosEmb failed, detail:E89999: Inner Error! E89999: [PID: 408156] 2025-02-07-20:21:38.448.775 op[ApplyRotaryPosEmb], input is empty[FUNC:GetInputParams][FILE:apply_rotary_pos_emb_tiling.cc][LINE:163] TraceBack (most recent call last): op[ApplyRotaryPosEmb], GetInputParams failed[FUNC:CheckParams][FILE:apply_rotary_pos_emb_tiling.cc][LINE:170] op[ApplyRotaryPosEmb], CheckParams return failed.[FUNC:Tiling4ApplyRotaryPosEmb][FILE:apply_rotary_pos_emb_tiling.cc][LINE:447] ApplyRotaryPosEmb do tiling failed, ret is -1. 
Check NnopbaseExecutorDoTiling(executor) failed Check NnopbaseExecutorTilingAndUpdateBinInfo(executor) failed Check NnopbaseRunForWorkspace(*executor, workspaceSize) failed [ERROR] 2025-02-07-20:21:38 (PID:408156, Device:1, RankID:-1) ERR01100 OPS call acl api failed [W OpParamMaker.cpp:387] Warning: E90000: [PID: 408156] 2025-02-07-20:21:38.563.236 Compile operator failed, cause: invalid literal for int() with base 10: '' File "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/tbe/common/repository_manager/interface.py", line 83, in cann_kb_search_get return public_interact_get(key) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/tbe/common/repository_manager/route.py", line 400, in public_interact_get status, msg_mgr = get_msg_obj() ^^^^^^^^^^^^^ File "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/tbe/common/repository_manager/route.py", line 322, in get_msg_obj return client_get_msg_obj(msg_file_path, file_name) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/tbe/common/repository_manager/utils/multiprocess_util.py", line 191, in client_get_msg_obj secrets_size = int(fp.read(MARK_DIGEST_SIZE).decode()) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TraceBack (most recent call last): Fail to get tune knowledge for op[GatherV224, GatherV2].[FUNC:UpdateNodeCompileParams][FILE:op_compiler.cc][LINE:668] Call OptimizeFusedGraph failed, ret:4294967295, engine_name:AIcoreEngine, graph_name:partition0_rank1_new_sub_graph3[FUNC:OptimizeSubGraph][FILE:graph_optimize.cc][LINE:119] subgraph 0 optimize failed[FUNC:OptimizeSubGraphWithMultiThreads][FILE:graph_manager.cc][LINE:814] build graph failed, graph id:23, ret:4294967295[FUNC:BuildModelWithGraphId][FILE:ge_generator.cc][LINE:1615] [Build][SingleOpModel]call ge interface generator.BuildSingleOpModel failed. ge result = 4294967295[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161] [Build][Op]Fail to build op model[FUNC:ReportInnerError][FILE:log_inner.cpp][LINE:145] build op model failed, result = 500002[FUNC:ReportInnerError][FILE:log_inner.cpp][LINE:145] (function ExecFunc) 2025-02-07 20:21:46,365 - lmdeploy - ERROR - model_agent.py:483 - TP process 0 failed with exitcode 1. (yzk-lmdeploy) [yzk@devserver-2efb lmdeploy_support_310P]$ [ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared! [ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared! [ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared! [ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared! [ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared! [ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared! [ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared! [ERROR] TBE Subprocess[task_distribute] raise error[], main process disappeared! /home/yzk/.conda/envs/yzk-lmdeploy/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 34 leaked semaphore objects to clean up at shutdown warnings.warn('resource_tracker: There appear to be %d '
We therefore suspected that the 310P may not support broadcast and similar torch.distributed operations. Using standalone Python scripts to test torch.distributed's all_reduce and broadcast (test code below), we found that all_reduce executes correctly while broadcast does not. Finally, we asked the Ascend team, who confirmed that the 310P indeed does not support broadcast.
Does LMDeploy have any near-term plans to support multi-card inference on the 310P? And if the 310P does not support dist.broadcast, how should the code be modified to avoid broadcast operations?
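As one possible direction for the last question (our own sketch, not an existing LMDeploy or torch_npu API): since all_reduce does work on the 310P (see the tests below), a tensor broadcast can be emulated by zeroing the tensor on every rank except the source and then doing a SUM all_reduce. The helper name is hypothetical:

import torch
import torch.distributed as dist

def broadcast_via_allreduce(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    # Emulate dist.broadcast(tensor, src): ranks other than `src` zero their
    # copy, then a SUM all_reduce leaves the source's data on every rank.
    if dist.get_rank() != src:
        tensor.zero_()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor

This costs an extra reduction compared with a native broadcast, but it only relies on collectives that the 310P appears to support.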
# /tmp/allreduce_dist_test.py -- all_reduce test
import os
import torch
import torch_npu
import torch.distributed as dist

rank_id = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))  # get world_size and rank_id
init_method = 'tcp://127.0.0.1:23456'
torch.npu.set_device(rank_id)
dist.init_process_group(backend='hccl', init_method=init_method,
                        world_size=world_size, rank=rank_id)  # initialize the process group

tensor = torch.ones((8, ), dtype=torch.float16).npu()
print(f"rank_id: {rank_id}, world_size: {world_size}")
print(f"before allreduce: {tensor}")
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # all_reduce
print(f"after allreduce: {tensor}")
LOCAL_RANK=0 WORLD_SIZE=2 python /tmp/allreduce_dist_test.py
LOCAL_RANK=1 WORLD_SIZE=2 python /tmp/allreduce_dist_test.py  # in another terminal
Both rank 0 and rank 1 print the correct result:
before allreduce: tensor([1., 1., 1., 1., 1., 1., 1., 1.], device='npu:0', dtype=torch.float16)
after allreduce: tensor([2., 2., 2., 2., 2., 2., 2., 2.], device='npu:0', dtype=torch.float16)
# /tmp/broadcast_dist_test.py -- broadcast test
import os
import torch
import torch_npu
import torch.distributed as dist

rank_id = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))  # get world_size and rank_id
init_method = 'tcp://127.0.0.1:23456'
torch.npu.set_device(rank_id)
dist.init_process_group(backend='hccl', init_method=init_method,
                        world_size=world_size, rank=rank_id)  # initialize the process group

if rank_id == 0:
    tensor = torch.randn((2, 5), dtype=torch.float16).npu()
else:
    tensor = torch.zeros((2, 5), dtype=torch.float16).npu()
print(f"rank_id: {rank_id}, world_size: {world_size}")
print(f"before broadcast: {tensor}")
dist.broadcast(tensor, src=0)  # broadcast from rank 0
print(f"after broadcast: {tensor}")
LOCAL_RANK=0 WORLD_SIZE=2 python /tmp/broadcast_dist_test.py
LOCAL_RANK=1 WORLD_SIZE=2 python /tmp/broadcast_dist_test.py  # in another terminal
Rank 1 prints an incorrect result (it still holds all zeros after the broadcast):
before broadcast: tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], device='npu:1', dtype=torch.float16)
after broadcast: tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], device='npu:1', dtype=torch.float16)
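For comparison, here is a variant of the broadcast test that uses the all_reduce-based emulation sketched earlier (our own sketch; the file name is hypothetical). Since the non-source rank already holds zeros in this test, the emulation reduces to a plain SUM all_reduce, which the test above shows works on the 310P, so rank 1 should end up with rank 0's data:

# hypothetical /tmp/emulated_broadcast_test.py -- broadcast emulated with all_reduce
import os
import torch
import torch_npu
import torch.distributed as dist

rank_id = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
init_method = 'tcp://127.0.0.1:23456'
torch.npu.set_device(rank_id)
dist.init_process_group(backend='hccl', init_method=init_method,
                        world_size=world_size, rank=rank_id)

if rank_id == 0:
    tensor = torch.randn((2, 5), dtype=torch.float16).npu()
else:
    tensor = torch.zeros((2, 5), dtype=torch.float16).npu()
print(f"before emulated broadcast: {tensor}")
# non-source ranks hold zeros, so a SUM all_reduce leaves rank 0's data on
# every rank -- the effect of a broadcast from src=0
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print(f"after emulated broadcast: {tensor}")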
We have just got single-card inference running on the 310P; we will analyze the multi-card issue.
Great, thank you very much. We look forward to LMDeploy being compatible with a wider range of platforms in the future.