Skip to content

Commit 8539ef4

Browse files
committed
- accidentally introduced 'transforms' namespace
- can't use default Target("tensorrt") arg
1 parent f173fbc commit 8539ef4

File tree

6 files changed

+39
-31
lines changed

6 files changed

+39
-31
lines changed

python/tvm/relay/op/contrib/tensorrt.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def get_tensorrt_use_fp16() -> bool:
111111
def partition_for_tensorrt(
112112
mod: tvm.IRModule,
113113
params: Optional[Dict[str, tvm.nd.NDArray]] = None,
114-
target: tvm.target.Target = tvm.target.Target("tensorrt"),
114+
# CAUTION: Can't use default Target("tensorrt") here since the target kind is only available
115+
# if is_tensorrt_compiler_enabled() == True.
116+
target: Optional[tvm.target.Target] = None,
115117
) -> tvm.IRModule:
116118
"""Partition all functions in mod to greedily offload supported operators to TensorRT.
117119
@@ -130,8 +132,13 @@ def partition_for_tensorrt(
130132
The partitioned module.
131133
132134
"""
135+
assert is_tensorrt_compiler_enabled(), "Can only partition for TensorRT if it is enabled"
133136
if params:
134137
mod["main"] = bind_params_by_name(mod["main"], params)
138+
if target is None:
139+
# Use a default target. The get_tensorrt_target() function will similarly create an
140+
# equivalent default target when compilation continues after partitioning.
141+
target = tvm.target.Target("tensorrt")
135142

136143
seq = tvm.transform.Sequential(
137144
[

src/relay/backend/contrib/codegen_c/codegen.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,8 @@ class CodegenCModule {
360360
};
361361

362362
/*! \brief The actual translation pass. */
363-
transform::Pass CCompilerImpl() {
364-
auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
363+
tvm::transform::Pass CCompilerImpl() {
364+
auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) {
365365
VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod);
366366
Target target = GetCCompilerTarget();
367367

@@ -388,10 +388,10 @@ transform::Pass CCompilerImpl() {
388388
return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl", {});
389389
}
390390

391-
transform::Pass CCompilerPass() {
391+
tvm::transform::Pass CCompilerPass() {
392392
return transform::Sequential(
393-
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(),
394-
transforms::MarkCompilerFunctionsAsExtern("ccompiler")});
393+
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(),
394+
transform::MarkCompilerFunctionsAsExtern("ccompiler")});
395395
}
396396

397397
} // namespace contrib

src/relay/backend/contrib/cutlass/codegen.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -902,8 +902,8 @@ class CutlassModuleCodegen {
902902
* \brief A small shim to redirect to the 'relay.ext.cutlass.compile_for_cutlass' Python
903903
* function which does the main CUTLASS training, c-code generation and compilation steps.
904904
*/
905-
transform::Pass CompileForCutlassImpl() {
906-
auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) {
905+
tvm::transform::Pass CompileForCutlassImpl() {
906+
auto pass_func = [=](IRModule mod, const tvm::transform::PassContext& pass_ctx) {
907907
VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod);
908908
const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass");
909909
ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function";
@@ -926,10 +926,10 @@ runtime::Module CreateCSourceModule(const IRModule& mod) {
926926

927927
TVM_REGISTER_GLOBAL("relay.ext.cutlass.create_c_source_module").set_body_typed(CreateCSourceModule);
928928

929-
transform::Pass CompileForCutlass() {
929+
tvm::transform::Pass CompileForCutlass() {
930930
return transform::Sequential(
931-
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
932-
CompileForCutlassImpl(), transforms::MarkCompilerFunctionsAsExtern("cutlass")});
931+
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("cutlass"),
932+
CompileForCutlassImpl(), transform::MarkCompilerFunctionsAsExtern("cutlass")});
933933
}
934934

935935
} // namespace cutlass

src/relay/backend/contrib/tensorrt/codegen.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,8 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
348348
* function will require a linear scan of imported runtime modules to find the matching
349349
* TensorRTRuntimeModule implementing it.
350350
*/
351-
transform::Pass CompileForTensorRTImpl() {
352-
auto pass_func = [](IRModule mod, const transform::PassContext& pass_ctx) {
351+
tvm::transform::Pass CompileForTensorRTImpl() {
352+
auto pass_func = [](IRModule mod, const tvm::transform::PassContext& pass_ctx) {
353353
VLOG(1) << "CompileForTensorRT input:" << std::endl << PrettyPrint(mod);
354354
Target target = GetTensorRTTarget();
355355

@@ -400,10 +400,10 @@ transform::Pass CompileForTensorRTImpl() {
400400
return tvm::transform::CreateModulePass(pass_func, 0, "CompileForTensorRT", {});
401401
}
402402

403-
transform::Pass CompileForTensorRT() {
403+
tvm::transform::Pass CompileForTensorRT() {
404404
return transform::Sequential(
405-
{transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("tensorrt"),
406-
CompileForTensorRTImpl(), transforms::MarkCompilerFunctionsAsExtern("tensorrt")});
405+
{transform::OutlineCompilerFunctionsWithExistingGlobalSymbols("tensorrt"),
406+
CompileForTensorRTImpl(), transform::MarkCompilerFunctionsAsExtern("tensorrt")});
407407
}
408408

409409
} // namespace tensorrt

src/relay/transforms/compiler_function_utils.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,13 @@
2424

2525
#include "./compiler_function_utils.h"
2626

27-
#include "../op/call/call.h"
2827
#include "tvm/relay/analysis.h"
2928
#include "tvm/relay/expr_functor.h"
3029
#include "tvm/relay/transform.h"
3130

3231
namespace tvm {
3332
namespace relay {
34-
namespace transforms {
33+
namespace transform {
3534
namespace {
3635

3736
/*!
@@ -211,8 +210,8 @@ GlobalVar ExistingGlobalSymbolCache::GetGlobalSymbol(const Function& function) {
211210
return global_var;
212211
}
213212

214-
transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
215-
std::string compiler_filter) {
213+
tvm::transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
214+
std::string compiler_filter) {
216215
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
217216
[cache = std::move(cache), compiler_filter = std::move(compiler_filter)](
218217
IRModule mod, transform::PassContext ctx) {
@@ -235,12 +234,13 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cach
235234
}
236235

237236
// Any Java programmers in the house?
238-
transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter) {
237+
tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(
238+
std::string compiler_filter) {
239239
return OutlineCompilerFunctions(std::make_shared<ExistingGlobalSymbolCache>(),
240240
std::move(compiler_filter));
241241
}
242242

243-
transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
243+
tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
244244
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
245245
[compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) {
246246
VLOG(1) << "MarkCompilerFunctionsAsExtern input:" << std::endl << PrettyPrint(mod);
@@ -262,7 +262,7 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
262262
return tvm::transform::CreateModulePass(pass_func, 0, "MarkCompilerFunctionsAsExtern", {});
263263
}
264264

265-
transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars) {
265+
tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars) {
266266
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
267267
[global_vars = std::move(global_vars)](IRModule mod, transform::PassContext ctx) {
268268
VLOG(1) << "InlineCompilerFunctionsBoundTo with global_vars: " << PrettyPrint(global_vars);
@@ -295,6 +295,6 @@ TVM_REGISTER_GLOBAL("relay._transform.MarkCompilerFunctionsAsExtern")
295295
TVM_REGISTER_GLOBAL("relay._transform.InlineCompilerFunctionsBoundTo")
296296
.set_body_typed(InlineCompilerFunctionsBoundTo);
297297

298-
} // namespace transforms
298+
} // namespace transform
299299
} // namespace relay
300300
} // namespace tvm

src/relay/transforms/compiler_function_utils.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666

6767
namespace tvm {
6868
namespace relay {
69-
namespace transforms {
69+
namespace transform {
7070

7171
/*!
7272
* \brief Abstract class representing a cache of unique global vars keyed by functions. This can
@@ -105,8 +105,8 @@ class ExistingGlobalSymbolCache : public GlobalSymbolCache {
105105
* If \p compiler_filter is non-empty only functions with that as their attribute value are
106106
* outlined.
107107
*/
108-
transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
109-
std::string compiler_filter = "");
108+
tvm::transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cache,
109+
std::string compiler_filter = "");
110110

111111
/*!
112112
* \brief A pass to outline all let-bound and literal functions in direct call positions which have
@@ -119,7 +119,8 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr<GlobalSymbolCache> cach
119119
* This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism
120120
* to prepare the IRModule before custom lowering.
121121
*/
122-
transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string compiler_filter = "");
122+
tvm::transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(
123+
std::string compiler_filter = "");
123124

124125
/*!
125126
* \brief A pass to mark all global functions which have a "Compiler" attribute matching
@@ -132,7 +133,7 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
132133
* This pass may be useful for external codegen using the "RelayToTIR" custom pass mechanism to
133134
* cleanup the IRModule after custom lowering.
134135
*/
135-
transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = "");
136+
tvm::transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = "");
136137

137138
/*!
138139
* \brief A pass to inline all global "Compiler" functions which are bound to a global var
@@ -142,9 +143,9 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter = "");
142143
* This pass may be useful for external codegen which needs to undo partitioning based on
143144
* properties of the entire partition.
144145
*/
145-
transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars);
146+
tvm::transform::Pass InlineCompilerFunctionsBoundTo(Array<GlobalVar> global_vars);
146147

147-
} // namespace transforms
148+
} // namespace transform
148149
} // namespace relay
149150
} // namespace tvm
150151

0 commit comments

Comments
 (0)