Checklist
1. I have searched related issues but cannot get the expected help.
2. The bug has not been fixed in the latest version.
3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
Describe the bug
Context
When building from source in a container environment, I found 2 bugs and fixed them, but I am quite new to CUDA, so I want to discuss them with you.
I use the docker image nvcr.io/nvidia/pytorch:24.06-py3, and the bugs can be reproduced on 0.6.3 and 0.6.4.
Bug 1
Location
In lmdeploy/src/turbomind/kernels/gemm/moe_utils_v2.cu, failure log:
145.9 /workspace/lmdeploy/src/turbomind/kernels/gemm/moe_utils_v2.cu(609): error: namespace "std" has no member "cerr"
145.9     std::cerr << "/workspace/lmdeploy/src/turbomind/kernels/gemm/moe_utils_v2.cu" << "(" << 609 << "): unsupported moe config: expert_num=" << experts
145.9     ^
145.9
145.9 /workspace/lmdeploy/src/turbomind/kernels/gemm/moe_utils_v2.cu(960): error: namespace "std" has no member "cerr"
145.9     std::cerr << "/workspace/lmdeploy/src/turbomind/kernels/gemm/moe_utils_v2.cu" << "(" << 960 << "): unsupported moe config: expert_num=" << expert_num
145.9     ^
145.9
145.9 2 errors detected in the compilation of "/workspace/lmdeploy/src/turbomind/kernels/gemm/moe_utils_v2.cu".
Possible Fix
Simply add #include <iostream> to that .cu file.
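For reference, a minimal sketch of the change (only the added include matters; the comment is mine):

```cpp
// src/turbomind/kernels/gemm/moe_utils_v2.cu
#include <iostream>  // added: provides std::cerr, used by the "unsupported moe config" error paths at lines 609 and 960
```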
Bug 2
Location
After fixing Bug 1, another error shows up at lmdeploy/src/turbomind/kernels/gemm/test/test_utils.cu:83:36:
/workspace/lmdeploy/src/turbomind/kernels/gemm/test/test_utils.cu:83:36: required from ‘std::vector<float> turbomind::FastCompare(const T*, const T*, int, int, cudaStream_t, float, float) [with T = __half; cudaStream_t = CUstream_st*]’
/workspace/lmdeploy/src/turbomind/kernels/gemm/test/test_utils.cu:116:141: required from here
/usr/local/cuda-12.5/targets/x86_64-linux/include/cuda/std/detail/libcxx/include/__functional/invoke.h:484:16: error: static assertion failed: Attempt to use an extended __device__ lambda in a context that requires querying its return type in host code. Use a named function object, a __host__ __device__ lambda, or cuda::proclaim_return_type instead.
484 | static_assert(!__nv_is_extended_device_lambda_closure_type(_Fp),
| ~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/usr/local/cuda-12.5/targets/x86_64-linux/include/cuda/std/detail/libcxx/include/__functional/invoke.h:484:16: note: ‘!(bool)__nv_extended_device_lambda_trait_helper<__nv_dl_wrapper_t<__nv_dl_trailing_return_tag<std::vector<float> (*)(const __half*, const __half*, int, int, CUstream_st*, float, float), turbomind::FastCompare<__half>, cuda::std::__4::tuple<float, float, float, float, float, float, long int>, 1>, float, float> >::value’ evaluates to false
Possible Fix
I saw the author added a comment here, but the code does not do what the comment says, which is weird... so my fix would be to change

[=] __device__(auto tup) {

to a form the compiler accepts; the nvcc diagnostic above suggests a __host__ __device__ lambda or cuda::proclaim_return_type.
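To illustrate the proclaim_return_type option, here is a small self-contained sketch of the technique. It is not the actual FastCompare code from test_utils.cu; the sum-of-squares reduction and all names in it are placeholders I made up:

```cpp
// Sketch only: shows the technique the diagnostic points at -- proclaiming the
// return type so host code (here, thrust::transform_reduce) can query it.
// Build with: nvcc --extended-lambda sketch.cu
#include <cuda/functional>  // cuda::proclaim_return_type (libcu++, CUDA 12.x)
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/transform_reduce.h>

// Placeholder reduction, not the FastCompare logic from test_utils.cu.
float sum_of_squares(const thrust::device_vector<float>& x)
{
    // A plain `[=] __device__ (float v) { ... }` here triggers the static_assert
    // quoted above; wrapping it in cuda::proclaim_return_type (or making it a
    // __host__ __device__ lambda) avoids it.
    return thrust::transform_reduce(
        x.begin(), x.end(),
        cuda::proclaim_return_type<float>([=] __device__(float v) { return v * v; }),
        0.f,
        thrust::plus<float>{});
}

int main()
{
    thrust::device_vector<float> x(1024, 1.f);
    return sum_of_squares(x) == 1024.f ? 0 : 1;
}
```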
Conclusion
LMDeploy can pass the build without errors after applying the fixes above. Since I am not very familiar with CUDA or this project, I want to know whether the fixes above would cause any problems elsewhere. I am glad to make my contribution if everything is fine, thanks!
Reproduction
See above.
Environment
Note: this comes from the pre-built lmdeploy in my environment and may be irrelevant to this issue.
sys.platform: linux
Python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2: NVIDIA GeForce RTX 4090
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.5, V12.5.40
GCC: x86_64-linux-gnu-gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
PyTorch: 2.4.0+cu121
PyTorch compiling details: PyTorch built with:
- GCC 9.3
- C++ Version: 201703
- Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v3.4.2 (Git Hash 1137e04ec0b5251ca2b4400a4fd3c667ce843d67)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX512
- CUDA Runtime 12.1
- NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
- CuDNN 90.1 (built against CUDA 12.4)
- Magma 2.6.1
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=9.1.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.4.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,
TorchVision: 0.19.0+cu121
LMDeploy: 0.6.3+4e5cc16
transformers: 4.46.3
gradio: 5.7.1
fastapi: 0.115.5
pydantic: 2.10.2
triton: 3.0.0
NVIDIA Topology:
GPU0 GPU1 GPU2 NIC0 NIC1 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NODE SYS SYS SYS 0-35,72-107 0 N/A
GPU1 NODE X SYS SYS SYS 0-35,72-107 0 N/A
GPU2 SYS SYS X NODE NODE 36-71,108-143 1 N/A
NIC0 SYS SYS NODE X PIX
NIC1 SYS SYS NODE PIX X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
Error traceback
No response