Skip to content

Commit 2a13c68

Browse files
junrushaoLeshengJin
andcommitted
Init
Co-Authored-by: Lesheng Jin <[email protected]>
1 parent 5d0ef94 commit 2a13c68

File tree

33 files changed

+2687
-11
lines changed

33 files changed

+2687
-11
lines changed

CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ include(cmake/utils/Utils.cmake)
66
include(cmake/utils/Summary.cmake)
77
include(cmake/utils/Linker.cmake)
88
include(cmake/utils/FindCUDA.cmake)
9+
include(cmake/utils/FindNCCL.cmake)
910
include(cmake/utils/FindOpenCL.cmake)
1011
include(cmake/utils/FindVulkan.cmake)
1112
include(cmake/utils/FindLLVM.cmake)
@@ -25,6 +26,7 @@ endif()
2526
# and add set(OPTION VALUE) to override these build options.
2627
# Alernatively, use cmake -DOPTION=VALUE through command-line.
2728
tvm_option(USE_CUDA "Build with CUDA" OFF)
29+
tvm_option(USE_NCCL "Build with NCCL" OFF)
2830
tvm_option(USE_OPENCL "Build with OpenCL" OFF)
2931
tvm_option(USE_OPENCL_ENABLE_HOST_PTR "Enable OpenCL memory object access to host" OFF)
3032
tvm_option(USE_OPENCL_GTEST "Path to OpenCL specific gtest version for runtime cpp tests." /path/to/opencl/gtest)
@@ -350,6 +352,7 @@ list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")
350352
tvm_file_glob(GLOB RUNTIME_SRCS
351353
src/runtime/*.cc
352354
src/runtime/vm/*.cc
355+
src/runtime/disco/*.cc
353356
src/runtime/minrpc/*.cc
354357
src/runtime/relax_vm/*.cc
355358
)
@@ -434,6 +437,13 @@ if(USE_PROFILER)
434437
list(APPEND RUNTIME_SRCS ${RUNTIME_VM_PROFILER_SRCS})
435438
endif(USE_PROFILER)
436439

440+
if(USE_CUDA AND USE_NCCL)
441+
message(STATUS "Build with NCCL...")
442+
find_nccl(${USE_NCCL})
443+
tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/nccl/*.cc)
444+
list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC})
445+
endif()
446+
437447
if(USE_AOT_EXECUTOR)
438448
message(STATUS "Build with AOT Executor support...")
439449
file(GLOB RUNTIME_AOT_EXECUTOR_SRCS src/runtime/aot_executor/*.cc)
@@ -850,3 +860,8 @@ if(USE_CUDA AND USE_CUTLASS)
850860
target_link_libraries(tvm PRIVATE -Wl,--no-as-needed flash_attn)
851861
target_link_libraries(tvm_runtime PRIVATE -Wl,--no-as-needed flash_attn)
852862
endif()
863+
864+
if(USE_CUDA AND USE_NCCL)
865+
target_link_libraries(tvm_runtime PRIVATE nccl)
866+
target_link_libraries(tvm PRIVATE nccl)
867+
endif()

cmake/config.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@
4848
# - /path/to/cuda: use specific path to cuda toolkit
4949
set(USE_CUDA OFF)
5050

51+
# Whether to enable NCCL support:
52+
# - ON: enable NCCL with cmake's auto search
53+
# - OFF: disable NCCL
54+
# - /path/to/nccl: use specific path to nccl
55+
set(USE_NCCL OFF)
56+
5157
# Whether enable ROCM runtime
5258
#
5359
# Possible values:

cmake/modules/LibInfo.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ function(add_lib_info src_file)
6464
TVM_INFO_USE_CPP_RTVM="${USE_CPP_RTVM}"
6565
TVM_INFO_USE_CUBLAS="${USE_CUBLAS}"
6666
TVM_INFO_USE_CUDA="${USE_CUDA}"
67+
TVM_INFO_USE_NCCL="${USE_NCCL}"
6768
TVM_INFO_USE_CUDNN="${USE_CUDNN}"
6869
TVM_INFO_USE_CUSTOM_LOGGING="${USE_CUSTOM_LOGGING}"
6970
TVM_INFO_USE_CUTLASS="${USE_CUTLASS}"

cmake/utils/FindNCCL.cmake

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# Variables used by this module, they can change the default behaviour and need
19+
# to be set before calling find_package:
20+
#
21+
# NCCL_ROOT - When set, this path is inspected instead of standard library
22+
# locations as the root of the NCCL installation.
23+
# The environment variable NCCL_ROOT overrides this variable.
24+
#
25+
# This module defines
26+
# Nccl_FOUND, whether nccl has been found
27+
# NCCL_INCLUDE_DIR, directory containing header
28+
# NCCL_LIBRARY, directory containing nccl library
29+
# This module assumes that the user has already called find_package(CUDA)
30+
31+
macro(find_nccl use_nccl)
32+
if(${use_nccl} MATCHES ${IS_FALSE_PATTERN})
33+
return()
34+
endif()
35+
if(${use_nccl} MATCHES ${IS_TRUE_PATTERN})
36+
find_path(NCCL_INCLUDE_DIR NAMES nccl.h)
37+
find_library(NCCL_LIBRARY NAMES nccl)
38+
else()
39+
find_path(NCCL_INCLUDE_DIR NAMES nccl.h HINTS ${use_nccl} ${use_nccl}/include)
40+
find_library(NCCL_LIBRARY NAMES nccl HINTS ${use_nccl} ${use_nccl}/lib)
41+
endif()
42+
include(FindPackageHandleStandardArgs)
43+
find_package_handle_standard_args(Nccl DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY)
44+
if (Nccl_FOUND)
45+
message(STATUS "Found NCCL_LIBRARY: ${NCCL_LIBRARY}")
46+
message(STATUS "Found NCCL_INCLUDE_DIR: ${NCCL_INCLUDE_DIR}")
47+
add_library(nccl SHARED IMPORTED)
48+
set_target_properties(nccl
49+
PROPERTIES
50+
INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}"
51+
IMPORTED_LOCATION "${NCCL_LIBRARY}")
52+
else()
53+
message(STATUS "NCCL not found")
54+
endif()
55+
mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY)
56+
endmacro(find_nccl)

include/tvm/relax/attrs/ccl.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relax/attrs/ccl.h
22+
* \brief Attributes for ccl operators.
23+
*/
24+
#ifndef TVM_RELAX_ATTRS_CCL_H_
25+
#define TVM_RELAX_ATTRS_CCL_H_
26+
27+
#include <tvm/relax/expr.h>
28+
29+
namespace tvm {
30+
namespace relax {
31+
32+
/*! \brief Attributes used in allreduce operators */
33+
struct AllReduceAttrs : public tvm::AttrsNode<AllReduceAttrs> {
34+
String op_type;
35+
36+
TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") {
37+
TVM_ATTR_FIELD(op_type).describe(
38+
"The type of reduction operation to be applied to the input data. Now only sum is "
39+
"supported.");
40+
}
41+
}; // struct AllReduceAttrs
42+
43+
} // namespace relax
44+
} // namespace tvm
45+
46+
#endif // TVM_RELAX_ATTRS_CCL_H_

include/tvm/runtime/packed_func.h

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,13 +1284,27 @@ namespace parameter_pack {
12841284

12851285
template <typename... EnumArgs>
12861286
struct EnumeratedParamPack {
1287-
struct Invoke {
1288-
template <template <size_t i, typename TArgument> class Functor, typename... ExtraParams>
1289-
static void F(ExtraParams&&... extra_params) {
1287+
struct InvokeWithoutArg {
1288+
template <template <size_t i, typename TArgument> class Functor, typename ExtraParams>
1289+
static void F(ExtraParams&& extra_params) {
12901290
using TExpander = int[];
12911291
(void)TExpander{
12921292
0,
1293-
(Functor<EnumArgs::i, typename EnumArgs::T>::F(extra_params...), 0)...,
1293+
(Functor<EnumArgs::i, typename EnumArgs::T>::F(std::forward<ExtraParams>(extra_params)),
1294+
0)...,
1295+
};
1296+
}
1297+
};
1298+
struct InvokeWithArg {
1299+
template <template <size_t i, typename TArgument> class Functor, typename ExtraParams,
1300+
typename... Params>
1301+
static void F(ExtraParams&& extra_params, Params&&... params) {
1302+
using TExpander = int[];
1303+
(void)TExpander{
1304+
0,
1305+
(Functor<EnumArgs::i, typename EnumArgs::T>::F(std::forward<ExtraParams>(extra_params),
1306+
std::forward<Params>(params)),
1307+
0)...,
12941308
};
12951309
}
12961310
};
@@ -1310,22 +1324,27 @@ struct EnumerateImpl {
13101324

13111325
template <std::size_t... id>
13121326
struct Zipper<std::integer_sequence<std::size_t, id...>> {
1313-
using T = EnumeratedParamPack<Item<id, Args>...>;
1327+
using WithoutArg = typename EnumeratedParamPack<Item<id, Args>...>::InvokeWithoutArg;
1328+
using WithArg = typename EnumeratedParamPack<Item<id, Args>...>::InvokeWithArg;
13141329
};
13151330

13161331
public:
1317-
using T = typename Zipper<std::index_sequence_for<Args...>>::T;
1332+
using WithoutArg = typename Zipper<std::index_sequence_for<Args...>>::WithoutArg;
1333+
using WithArg = typename Zipper<std::index_sequence_for<Args...>>::WithArg;
13181334
};
13191335

13201336
template <typename... Args>
1321-
using Enumerate = typename EnumerateImpl<Args...>::T;
1337+
using EnumerateWithoutArg = typename EnumerateImpl<Args...>::WithoutArg;
1338+
1339+
template <typename... Args>
1340+
using EnumerateWithArg = typename EnumerateImpl<Args...>::WithArg;
13221341

13231342
template <typename... Args>
13241343
struct ParamPack {
1325-
template <template <size_t i, typename TArgument> class Functor, typename... ExtraParams>
1326-
static void InvokeWithoutArg(ExtraParams&&... extra_params) {
1327-
Enumerate<Args...>::Invoke::template F<Functor, ExtraParams...>(
1328-
std::forward<ExtraParams>(extra_params)...);
1344+
template <template <size_t i, typename TArgument> class Functor, typename ExtraParams>
1345+
static void InvokeWithoutArg(ExtraParams&& extra_params) {
1346+
EnumerateWithoutArg<Args...>::template F<Functor, ExtraParams>(
1347+
std::forward<ExtraParams>(extra_params));
13291348
}
13301349
};
13311350

@@ -1622,6 +1641,20 @@ inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
16221641
return rv;
16231642
}
16241643

1644+
template <int i, typename T>
1645+
struct TVMArgsSetterApply {
1646+
static TVM_ALWAYS_INLINE void F(TVMArgsSetter* setter, T&& value) {
1647+
(*setter)(i, std::forward<T>(value));
1648+
}
1649+
};
1650+
1651+
template <typename... Args>
1652+
void TVM_ALWAYS_INLINE PackArgs(TVMValue* values, int* type_codes, Args&&... args) {
1653+
TVMArgsSetter setter(values, type_codes);
1654+
detail::parameter_pack::EnumerateWithArg<Args...>::template F<TVMArgsSetterApply>(
1655+
&setter, std::forward<Args>(args)...);
1656+
}
1657+
16251658
namespace detail {
16261659
template <typename R, int nleft, int index, typename F>
16271660
struct unpack_call_dispatcher {

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from . import image
3939
from . import memory
4040
from . import nn
41+
from . import ccl
4142

4243
# Register operator gradient functions
4344
from . import _op_gradient
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=wildcard-import
18+
"""CCL related operators."""
19+
from .ccl import *
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Operators serving for Collective Communications Library (CCL) operators"""
18+
import tvm._ffi
19+
20+
tvm._ffi._init_api("relax.op.ccl", __name__)

python/tvm/relax/op/ccl/ccl.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Relax Collective Communications Library (CCL) operators"""
18+
from . import _ffi_api
19+
from ...expr import Expr
20+
21+
22+
def allreduce(x, op_type: str = "sum"):
23+
"""Allreduce operator
24+
25+
Parameters
26+
----------
27+
x : relax.Expr
28+
The input tensor.
29+
op_type: str
30+
The type of reduction operation to be applied to the input data.
31+
Now "sum", "prod", "min", "max" and "avg" are supported.
32+
33+
Returns
34+
-------
35+
result : relax.Expr
36+
The result of allreduce.
37+
"""
38+
supported_op_types = ["sum", "prod", "min", "max", "avg"]
39+
assert (
40+
op_type in supported_op_types
41+
), f"Allreduce only supports limited reduction operations, including {supported_op_types}, but got {op_type}."
42+
return _ffi_api.allreduce(x, op_type)

0 commit comments

Comments
 (0)