From 6248b5db43505fbcfb13cc289d11877d5d2649e8 Mon Sep 17 00:00:00 2001 From: Nguyen Duy Loc <77536430+locnd182644@users.noreply.github.com> Date: Sat, 13 Dec 2025 14:29:23 +0700 Subject: [PATCH 01/11] [Relax][Torch] Fixed issues related to sum op when without dim and keep dim (#18583) ## Issue 1: Without Dim ### Summary: In _sum function (BaseFXGraphImporter), after retrieve_args, args[1] = [] and still pass into relax.op.sum so the result is incorrect. ### Steps to Reproduce - Module ``` class SumWithoutDim(nn.Module): def forward(self, x): return torch.sum(x) ``` ``` class Module: def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): with R.dataflow(): lv: R.Tensor((2, 3), dtype="float32") = R.sum(x, axis=[], keepdims=False) gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,) R.output(gv) return gv ``` - Result: Input: tensor([[1., 1., 1.], [1., 1., 1.]]) Torch output: tensor(6.) Torch output shape: torch.Size([]) TVM output: [[1. 1. 1.] [1. 1. 1.]] TVM output shape: (2, 3) ### Expected ``` class Module: def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")): with R.dataflow(): lv: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False) gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) R.output(gv) return gv ``` - Result: TVM output: 6.0; TVM output shape: () ## Issue 2: Keep Dim ### Summary: In _sum function (BaseFXGraphImporter), previously keepdim value get only from node.kwargs and no pass into relax.op.sum. Now keepdim get more from args[2] and pass into. 
### Steps to Reproduce - Module ``` class SumKeepDim(nn.Module): def forward(self, x): return torch.sum(x, dim=1, keepdim=True) ``` ``` class Module: def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2,), dtype="float32")): with R.dataflow(): lv: R.Tensor((2,), dtype="float32") = R.sum(x, axis=[1], keepdims=False) gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv,) R.output(gv) return gv ``` - Result: Input: tensor([[1., 1., 1.], [1., 1., 1.]]) Torch output: tensor([[3.], [3.]]) Torch output shape: torch.Size([2, 1]) TVM VM output: [3. 3.] TVM VM output shape: (2,) ### Expected ``` class Module: def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 1), dtype="float32")): with R.dataflow(): lv: R.Tensor((2, 1), dtype="float32") = R.sum(x, axis=[1], keepdims=True) gv: R.Tuple(R.Tensor((2, 1), dtype="float32")) = (lv,) R.output(gv) return gv ``` - Result: TVM output: [[3.] [3.]] ;TVM output shape: (2, 1) --- .../torch/base_fx_graph_translator.py | 10 ++-- .../test_frontend_from_exported_program.py | 48 ++++++++++++++++--- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 47eb66621008..f7d54a6216a7 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1628,10 +1628,12 @@ def _std(self, node: fx.Node) -> relax.Var: def _sum(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.sum(args[0], args[1])) + x = args[0] + dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) + if isinstance(dim, (list, tuple)) and len(dim) == 0: + dim = None + keepdim = args[2] if len(node.args) 
> 2 else node.kwargs.get("keepdim", False) + return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim)) def _var(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 01e16e7564ac..4a84b50cc9d9 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4945,6 +4945,14 @@ class Sum(Module): def forward(self, x): return torch.sum(x, (2, 1)) + class SumKeepDim(Module): + def forward(self, x): + return torch.sum(x, (2, 1), keepdim=True) + + class SumWithoutDim(Module): + def forward(self, x): + return torch.sum(x) + @tvm.script.ir_module class expected1: @R.function @@ -4958,8 +4966,36 @@ def main( R.output(gv) return gv + @tvm.script.ir_module + class expected2: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 1, 1, 4), dtype="float32") = R.sum( + inp_0, axis=[2, 1], keepdims=True + ) + gv: R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected3: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.sum(inp_0, axis=None, keepdims=False) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) verify_model(Sum(), example_args, {}, expected1) + verify_model(SumKeepDim(), example_args, {}, expected2) + verify_model(SumWithoutDim(), example_args, {}, expected3) def test_argmax_argmin(): @@ -7840,7 +7876,7 @@ def forward(self, x): @tvm.script.ir_module class Expected1: @R.function - def main(x: R.Tensor((4, 3), dtype="float32")) -> 
R.Tuple(R.Tensor((4,), dtype="float32")): + def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")): with R.dataflow(): lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32") lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1) @@ -7863,11 +7899,11 @@ def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype=" lv12: R.Tensor((4,), dtype="bool") = R.not_equal( R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") ) - lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False) - lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32") - lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False) - lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14) - gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,) + lv13: R.Tensor((), dtype="bool") = R.sum(lv12, axis=None, keepdims=False) + lv14: R.Tensor((), dtype="float32") = R.astype(lv13, dtype="float32") + lv15: R.Tensor((), dtype="float32") = R.sum(lv11, axis=None, keepdims=False) + lv16: R.Tensor((), dtype="float32") = R.divide(lv15, lv14) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv16,) R.output(gv) return gv From f2930d5bb14eac4c3984c413da5a069eab98fd20 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 16 Dec 2025 13:10:54 +0900 Subject: [PATCH 02/11] [LLVM][Codegen] Avoid segfault when `arith::GetVScaleValues` returns empty vector (#18586) As per title. --- src/target/llvm/codegen_aarch64.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc index adac65914469..872e4f4cd110 100644 --- a/src/target/llvm/codegen_aarch64.cc +++ b/src/target/llvm/codegen_aarch64.cc @@ -59,9 +59,11 @@ void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) { // Add vscale_range() function attribute when appropriate. 
if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) { auto kVScaleValues = arith::GetVScaleValues(Target::Current()); - unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end()); - func->addFnAttr( - llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(), 1, max_val)); + if (!kVScaleValues.empty()) { + unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end()); + func->addFnAttr( + llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(), 1, max_val)); + } } #endif CodeGenCPU::SetTargetAttributes(func); From d375f7483a6b46ee9ee77fdcdf5ae192f77f15c5 Mon Sep 17 00:00:00 2001 From: Dayuxiaoshui <158081477+Dayuxiaoshui@users.noreply.github.com> Date: Wed, 17 Dec 2025 10:33:17 +0800 Subject: [PATCH 03/11] =?UTF-8?q?Fix=20ACOS=20precision=20issue=20for=20bo?= =?UTF-8?q?undary=20values=20(x=3D=C2=B11.0)=20(#18582)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The ACOS operator was producing incorrect results for boundary values due to poor precision of ASIN's Taylor series expansion near x=±1.0. Root cause: - ASIN used a 6-term Taylor series that converges slowly near boundaries - ACOS was implemented as acos(x) = π/2 - asin(x), inheriting ASIN errors - At x=1.0, ASIN error of 0.354874 (22.6%) caused ACOS to output 0.354874 instead of 0.0 Solution: - Modified ASIN to use system library function (asinf) for |x| >= 0.9 - Modified ACOS to use system library function (acosf) for |x| >= 0.9 - For |x| < 0.9, continue using Taylor series (accurate in this range) This ensures high precision for boundary values while maintaining the existing behavior for values in the middle range. 
Fixes #18580 --- src/target/llvm/intrin_rule_llvm.cc | 34 ++++++++++++--- tests/python/tir-base/test_tir_intrin.py | 53 ++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 4ce7ce9f2291..a8a3d911ca8e 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -167,9 +167,15 @@ TVM_REGISTER_OP("tir.sinh") TVM_REGISTER_OP("tir.asin") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; + using namespace intrin; const tir::CallNode* call = e.as(); ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; + + PrimExpr threshold = make_const(x.dtype(), 0.5); + PrimExpr abs_x = tvm::abs(x); + PrimExpr use_lib = abs_x >= threshold; + PrimExpr x2 = x * x; PrimExpr term1 = x; PrimExpr term3 = term1 * x2 / make_const(x.dtype(), 6); @@ -178,25 +184,43 @@ TVM_REGISTER_OP("tir.asin") PrimExpr term9 = term7 * x2 * make_const(x.dtype(), 1225) / make_const(x.dtype(), 3456); PrimExpr term11 = term9 * x2 * make_const(x.dtype(), 3969) / make_const(x.dtype(), 28160); PrimExpr series = term1 + term3 + term5 + term7 + term9 + term11; - /* --- domain limit check --- */ + + PrimExpr lib_result = + ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); + PrimExpr lower = make_const(x.dtype(), -1.0); PrimExpr upper = make_const(x.dtype(), 1.0); PrimExpr out_range = tir::Or(x upper); - // Use a quiet NaN constant PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits::quiet_NaN()); - // select: if out of [-1,1] → NaN, else → series - return tir::Select(out_range, nan_const, series); + + return tir::Select(out_range, nan_const, tir::Select(use_lib, lib_result, series)); }); TVM_REGISTER_OP("tir.acos") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; + using namespace intrin; const tir::CallNode* call = e.as(); ICHECK(call != nullptr) << 
"Invalid call node in acos legalization"; const PrimExpr& x = call->args[0]; + + PrimExpr threshold = make_const(x.dtype(), 0.5); + PrimExpr abs_x = tvm::abs(x); + PrimExpr use_lib = abs_x >= threshold; + PrimExpr half_pi = make_const(x.dtype(), M_PI / 2); PrimExpr asin_x = asin(x); - return half_pi - asin_x; + PrimExpr formula_result = half_pi - asin_x; + + PrimExpr lib_result = + ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); + + PrimExpr lower = make_const(x.dtype(), -1.0); + PrimExpr upper = make_const(x.dtype(), 1.0); + PrimExpr out_range = tir::Or(x upper); + PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits::quiet_NaN()); + + return tir::Select(out_range, nan_const, tir::Select(use_lib, lib_result, formula_result)); }); TVM_REGISTER_OP("tir.atan") diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tir-base/test_tir_intrin.py index 8dabdbb344f3..1e8c88e08e65 100644 --- a/tests/python/tir-base/test_tir_intrin.py +++ b/tests/python/tir-base/test_tir_intrin.py @@ -135,6 +135,58 @@ def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5): run_test(*func, atol, rtol) +def test_asin_acos_boundary_values(): + """Test asin and acos with boundary values and threshold switching.""" + test_funcs = [ + (tvm.tir.asin, lambda x: np.arcsin(x)), + (tvm.tir.acos, lambda x: np.arccos(x)), + ] + + def run_test(tvm_intrin, np_func): + m = te.var("m") + A = te.placeholder((m,), name="A") + B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name="B") + + mod = te.create_prim_func([A, B]) + sch = tir.Schedule(mod) + func = tvm.compile(sch.mod, target="llvm") + + dev = tvm.cpu(0) + + # Test boundary values: ±1.0 (should use system library) + boundary_values = np.array([1.0, -1.0], dtype=np.float32) + a1 = tvm.runtime.tensor(boundary_values, dev) + b1 = tvm.runtime.tensor(np.empty_like(boundary_values), dev) + func(a1, b1) + tvm.testing.assert_allclose(b1.numpy(), np_func(boundary_values), atol=1e-5, rtol=1e-5) + + # 
Test values at threshold: ±0.5 (should use system library) + threshold_values = np.array([0.5, -0.5], dtype=np.float32) + a2 = tvm.runtime.tensor(threshold_values, dev) + b2 = tvm.runtime.tensor(np.empty_like(threshold_values), dev) + func(a2, b2) + tvm.testing.assert_allclose(b2.numpy(), np_func(threshold_values), atol=1e-4, rtol=1e-4) + + # Test values just below threshold: ±0.49 (should use Taylor series) + below_threshold_values = np.array([0.49, -0.49, 0.3, -0.3, 0.0], dtype=np.float32) + a3 = tvm.runtime.tensor(below_threshold_values, dev) + b3 = tvm.runtime.tensor(np.empty_like(below_threshold_values), dev) + func(a3, b3) + tvm.testing.assert_allclose( + b3.numpy(), np_func(below_threshold_values), atol=1e-3, rtol=1e-3 + ) + + # Test out-of-domain values: should return NaN + out_of_domain = np.array([1.1, -1.1, 2.0, -2.0], dtype=np.float32) + a4 = tvm.runtime.tensor(out_of_domain, dev) + b4 = tvm.runtime.tensor(np.empty_like(out_of_domain), dev) + func(a4, b4) + assert np.all(np.isnan(b4.numpy())), "Out-of-domain inputs should return NaN" + + for func in test_funcs: + run_test(*func) + + def test_binary_intrin(): test_funcs = [ (tvm.tir.atan2, lambda x1, x2: np.arctan2(x1, x2)), @@ -315,6 +367,7 @@ def test_fma(): test_nearbyint() test_unary_intrin() test_round_intrinsics_on_int() + test_asin_acos_boundary_values() test_binary_intrin() test_ldexp() test_clz() From 5989ef57a6c9e845b2ab4851195de3e0f7997304 Mon Sep 17 00:00:00 2001 From: ping-ee <82318236+ping-ee@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:32:11 +0800 Subject: [PATCH 04/11] [BugFix][OpenCL] Guard QCOM perf hint behind USE_OPENCL_EXTN_QCOM to avoid undefined symbol on non-QCOM runtimes (#18589) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR is a re-open of #18581 The previous PR was created while Jenkins CI was experiencing a disk space issue and the CI job did not trigger. 
## PR Description Recent OpenCL-Headers update (https://github.com/KhronosGroup/OpenCL-Headers/pull/277 ) added QCOM perf-hint definitions (`CL_CONTEXT_PERF_HINT_QCOM`, `clSetPerfHintQCOM`) to `cl_ext.h`. These macros are now defined even on platforms whose OpenCL runtimes (e.g., PoCL, ICD loaders) do not implement the QCOM extension. TVM previously enabled the perf-hint code path solely based on the presence of `CL_CONTEXT_PERF_HINT_QCOM`, causing link errors such as: ``` undefined symbol: clSetPerfHintQCOM ``` This PR guards the QCOM perf-hint logic behind `USE_OPENCL_EXTN_QCOM`, matching the behavior of other QCOM-specific OpenCL paths (e.g., `SetNativePtr`). ## Effects Prevents accidental linking against unsupported QCOM symbols on non-QCOM runtimes. Keeps QCOM builds fully functional when `USE_OPENCL_EXTN_QCOM` is explicitly enabled. Aligns TVM’s extension handling across OpenCL code paths. --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/runtime/opencl/opencl_device_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 8b6fba24988e..63e0890f70c2 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -472,7 +472,7 @@ void OpenCLWorkspace::SetNativePtr(const tvm::runtime::Tensor& narr, void* host_ } void OpenCLWorkspace::SetPerfHint(Device dev, cl_uint perf_hint) { -#ifdef CL_CONTEXT_PERF_HINT_QCOM +#if defined(USE_OPENCL_EXTN_QCOM) && defined(CL_CONTEXT_PERF_HINT_QCOM) cl_device_id device_id = GetCLDeviceID(dev.device_id); auto platform = device_info[device_id].platform_id; OPENCL_CALL(clSetPerfHintQCOM(this->contexts[platform], perf_hint)); From 44d973b0aa939307a36aef1011de30833837c664 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Thu, 18 Dec 2025 11:56:20 +0800 Subject: 
[PATCH 05/11] [Relax] Add layout inference support for repeat operator (#18579) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## How - Implemented InferLayoutRepeat function that: - Preserves layout when axis is specified (with axis transformation) - Returns 1D layout when axis is not specified (flatten mode) - Transforms the axis parameter based on layout changes (e.g., NCHW axis=1 → NHWC axis=3) --- src/relax/op/tensor/manipulate.cc | 60 ++++++++++++- .../relax/test_transform_convert_layout.py | 85 +++++++++++++++++++ 2 files changed, 144 insertions(+), 1 deletion(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 0310c7f46b0d..493198fbd091 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -1805,12 +1805,70 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype, data_sinfo->vdevice); } -// TODO(relax-team): implement FRelaxInferLayout for repeat +InferLayoutOutput InferLayoutRepeat( + const Call& call, const ffi::Map>& desired_layouts, + const VarLayoutMap& var_layout_map) { + ICHECK(NoDesiredLayout(call, desired_layouts)); + + const auto* attrs = call->attrs.as(); + ICHECK(attrs != nullptr) << "Invalid Call"; + const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); + ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + + LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); + int ndim = tensor_sinfo->ndim; + + // Can't handle sub indexed layouts. 
+ if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { + existing_layout = LayoutDecision(InitialLayout(ndim)); + } + + // When axis is not specified, the output is 1D (flattened) + if (!attrs->axis.has_value()) { + return InferLayoutOutput({existing_layout}, {InitialLayoutDecision(1)}, Attrs(call->attrs)); + } + + // Transform the axis based on the layout + int axis = attrs->axis.value(); + if (axis < 0) { + axis += ndim; + } + + // Create a mapping from original layout to existing layout + std::string axis_str(ndim, '0'); + axis_str[axis] = '1'; + for (int i = 0, j = 0; i < ndim; ++i) { + if (axis_str[i] != '1') { + axis_str[i] = 'A' + j++; + } + } + + ffi::String new_axis_str = + TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout); + + int64_t new_axis = -1; + for (size_t i = 0; i < new_axis_str.size(); ++i) { + if (new_axis_str.at(i) == '1') { + new_axis = i; + break; + } + } + ICHECK_GE(new_axis, 0) << "Failed to find transformed axis"; + + ObjectPtr new_attrs = ffi::make_object(*attrs); + new_attrs->axis = new_axis; + + // When axis is specified, the layout is preserved + return InferLayoutOutput({existing_layout}, {existing_layout}, Attrs(new_attrs)); +} + TVM_REGISTER_OP("relax.repeat") .set_attrs_type() .set_num_inputs(1) .add_argument("data", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoRepeat) + .set_attr("FRelaxInferLayout", InferLayoutRepeat) .set_attr("FPurity", Bool(true)); /* relax.tile */ diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py index 83b81a6898a7..95f043ef6629 100644 --- a/tests/python/relax/test_transform_convert_layout.py +++ b/tests/python/relax/test_transform_convert_layout.py @@ -4992,5 +4992,90 @@ def main( verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) +def test_conv2d_repeat(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 
28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 8, 26, 26), "float32") = R.repeat(gv, repeats=2, axis=1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor(None, dtype="float32", ndim=4): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + lv2: R.Tensor((2, 26, 26, 8), dtype="float32") = R.repeat(gv, repeats=2, axis=3) + gv2: R.Tensor((2, 8, 26, 26), dtype="float32") = R.permute_dims( + lv2, axes=[0, 3, 1, 2] + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +def test_conv2d_repeat_flatten(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32") + ) -> R.Tensor((5408,), "float32"): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((5408,), "float32") = R.repeat(gv, repeats=1) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32") + ) -> R.Tensor((5408,), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1]) + lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1]) + gv: R.Tensor((2, 26, 
26, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + groups=1, + data_layout="NHWC", + kernel_layout="OHWI", + out_layout="NHWC", + out_dtype="float32", + ) + gv2: R.Tensor((5408,), dtype="float32") = R.repeat(gv, repeats=1) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + if __name__ == "__main__": tvm.testing.main() From 1c209e27b7b0c62fcb37968382ffcd1612319eab Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Sat, 20 Dec 2025 19:47:33 +0800 Subject: [PATCH 06/11] [Relax] Clean up scatter_elements unknown dtype handling (#18577) ## Why - LOG(WARNING) is the standard and correct approach throughout the TVM codebase - The existing pattern is used consistently in all relax ops (see test_op_manipulate.py, index.cc, etc.) - Added test coverage for previously untested scenarios --- src/relax/op/tensor/manipulate.cc | 2 -- tests/python/relax/test_op_manipulate.py | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 493198fbd091..1aab52ac56a5 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -2456,7 +2456,6 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) { auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name) { if (sinfo->IsUnknownDtype()) { - // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning? LOG(WARNING) << "Data type of " << name << " has not been specified. Assume it has an integer type."; } @@ -2473,7 +2472,6 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& } if (indices_sinfo->IsUnknownDtype()) { - // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning? 
LOG(WARNING) << "Data type of indice has not been specified. Assume it has an integer type."; } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { ctx->ReportFatal( diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index d39584e06ba8..6a73a84fd8cd 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -3417,6 +3417,20 @@ def test_scatter_elements_infer_struct_info(): relax.op.scatter_elements(d2, i3, u0, 0, "updates"), relax.TensorStructInfo(dtype="float32", ndim=-1), ) + # Test with unknown dtype for data + d_unknown = relax.Var("data", R.Tensor((4, 4))) + _check_inference( + bb, + relax.op.scatter_elements(d_unknown, i0, u0, 0, "updates"), + relax.TensorStructInfo((4, 4), dtype=""), + ) + # Test with unknown dtype for updates + u_unknown = relax.Var("updates", R.Tensor((2, 2))) + _check_inference( + bb, + relax.op.scatter_elements(d0, i0, u_unknown, 0, "updates"), + relax.TensorStructInfo((4, 4), dtype="float32"), + ) def test_scatter_elements_infer_struct_info_symbolic_shape(): From f4e28d3153323ad97a7e74740c9fb22300fd6cd0 Mon Sep 17 00:00:00 2001 From: Neo Chien <6762509+cchung100m@users.noreply.github.com> Date: Sun, 21 Dec 2025 20:14:21 +0800 Subject: [PATCH 07/11] [Relax] Chore: Fix the DeprecationWarning: invalid escape sequence \ (#18591) Hi @mshr-h @tlopex, This PR is trying to fix issue: DeprecationWarning: invalid escape sequence `\` Any suggestions would be appreciated if you are available. 
### Root Cause The backslashes(`\`) inside the docstring image ### Solution Use a raw docstring(`r"""`) Co-authored-by: cchung100m --- .../adreno/test_transform_annotate_custom_scope.py | 12 ++++++------ tests/python/relax/test_transform_convert_layout.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/python/relax/adreno/test_transform_annotate_custom_scope.py b/tests/python/relax/adreno/test_transform_annotate_custom_scope.py index 24b4cf66b888..2c0b7073a119 100644 --- a/tests/python/relax/adreno/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/adreno/test_transform_annotate_custom_scope.py @@ -808,7 +808,7 @@ def main( def test_residual_block(): - """ + r""" - some kind of residual block followed by convolution to have texture after residual block - scalar data type verification which should be mapped to global memory scope layout_transform (NCHW->NCHW4c) @@ -874,7 +874,7 @@ def main( def test_conv2d_conv2d_fallback_to_buffer_conv2d(): - """ + r""" layout_transform (NCHW->NCHW4c) | <- texture conv2d (1) <- textures as output @@ -931,7 +931,7 @@ def main( def test_conv2d_conv2d_conv2d_concat(): - """ + r""" layout_transform (NCHW->NCHW4c) | <- texture conv2d (1) <- textures as output @@ -991,7 +991,7 @@ def main( def test_pooling_branching_texture_params(): - """ + r""" Verification of the pooling and many branches having textures layout_transform (NCHW->NCHW4c) | <- texture @@ -1066,7 +1066,7 @@ def main( def test_injective_inputs1(): - """ + r""" Input / \ / | @@ -1133,7 +1133,7 @@ def main( def test_injective_nwo_inputs2(): - """ + r""" Input / \ | \ diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py index 95f043ef6629..a53b5db246e6 100644 --- a/tests/python/relax/test_transform_convert_layout.py +++ b/tests/python/relax/test_transform_convert_layout.py @@ -4585,7 +4585,7 @@ def main( def test_conv2d_conv2d_conv2d_concat(): - """ + r""" 
layout_transform (NCHW->NCHW4c) | <- texture conv2d (1) <- textures as output @@ -4713,7 +4713,7 @@ def main( def test_conv2d_conv2d_callback_to_buffer_conv2d_concat(): - """ + r""" layout_transform (NCHW->NCHW4c) | <- texture conv2d (1) <- textures as output @@ -4841,7 +4841,7 @@ def main( def test_pooling_branching_texture_params(): - """ + r""" Verification of the pooling and many branches having textures layout_transform (NCHW->NCHW4c) | <- texture From 69ccf8b17af764a7c1d636c104d80072a3e42171 Mon Sep 17 00:00:00 2001 From: Nguyen Duy Loc <77536430+locnd182644@users.noreply.github.com> Date: Mon, 22 Dec 2025 23:25:15 +0700 Subject: [PATCH 08/11] [Relax][Torch] AssertionError: Unsupported function types ['mean.default'] (#18574) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Happen error when create module from exported_program have torch.mean without dim. ## Reproduce - Module: ``` class MeanModule(nn.Module): def forward(self, x): return torch.mean(x) ... 
# Export → Relax ep = torch_export(m, (x,)) mod = from_exported_program(ep) ``` - Error log: ``` --------------------------------------------------------------------------- AssertionError Traceback (most recent call last) Cell In[2], line 13 11 # Export → Relax 12 ep = torch_export(m, (x,)) ---> 13 mod = from_exported_program(ep) 15 mod.show() 17 target = "llvm" File ~/Programming/tvm/python/tvm/relax/frontend/torch/exported_program_translator.py:1783, in from_exported_program(exported_program, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple, run_ep_decomposition) 1780 if run_ep_decomposition: 1781 exported_program = exported_program.run_decompositions() -> 1783 return ExportedProgramImporter().from_exported_program( 1784 exported_program, 1785 keep_params_as_input, 1786 unwrap_unit_return_tuple, 1787 no_bind_return_tuple, 1788 ) File ~/Programming/tvm/python/tvm/relax/frontend/torch/exported_program_translator.py:1642, in ExportedProgramImporter.from_exported_program(self, exported_program, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple) 1639 nodes: List[fx.Node] = exported_program.graph.nodes 1641 # Find all the missing function types -> 1642 self._check_unsupported_func_type(nodes) 1644 with self.block_builder.function( 1645 name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs 1646 ): 1647 output = None File ~/Programming/tvm/python/tvm/relax/frontend/torch/base_fx_graph_translator.py:182, in BaseFXGraphImporter._check_unsupported_func_type(self, nodes) 174 def _check_unsupported_func_type(self, nodes: List[fx.Node]): 175 missing_func_types = list( 176 { 177 node.target.__name__ (...) 180 } 181 ) --> 182 assert not missing_func_types, f"Unsupported function types {missing_func_types}" AssertionError: Unsupported function types ['mean.default'] ``` ## Resolve: - Add "mean.default" into create_convert_map in class ExportedProgramImporter. 
--- .../torch/exported_program_translator.py | 1 + .../test_frontend_from_exported_program.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 3d6a632fb20f..94df0282c870 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1371,6 +1371,7 @@ def create_convert_map( "any.dim": self._any, "any.dims": self._any, "mean.dim": self._mean, + "mean.default": self._mean, "prod.default": self._prod, "std.correction": self._std, "sum.default": self._sum, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 4a84b50cc9d9..7894a9fb6d81 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4911,6 +4911,10 @@ class MeanKeepDim(Module): def forward(self, input: torch.Tensor): return input.mean(-1, keepdim=True) + class MeanWithoutDim(Module): + def forward(self, input: torch.Tensor): + return input.mean() + @I.ir_module class Expected1: @R.function @@ -4935,9 +4939,22 @@ def main( R.output(gv) return gv + @I.ir_module + class Expected3: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.mean(inp_0, axis=None, keepdims=False) + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + example_args = (torch.randn(256, 256, dtype=torch.float32),) verify_model(Mean(), example_args, {}, Expected1) verify_model(MeanKeepDim(), example_args, {}, Expected2) + verify_model(MeanWithoutDim(), example_args, {}, Expected3) def test_sum(): From 3a1076565f95fdff4bfabe0a80385e0d3363d2c3 Mon Sep 17 00:00:00 2001 From: senhtry Date: Tue, 23 
Dec 2025 17:29:06 +0800 Subject: [PATCH 09/11] [CUDA][FFI] Add support for Programmatic Dependent Kernel Launch (PDL) in TVM CUDA FFI --- src/runtime/cuda/cuda_module.cc | 39 ++++++++++++++++++++++++++++----- src/runtime/file_utils.cc | 7 ++++++ src/runtime/meta_data.h | 4 ++-- 3 files changed, 43 insertions(+), 7 deletions(-) diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 3fee6b55f2e5..74f6bad2d95b 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -173,10 +173,12 @@ class CUDAWrappedFunc { public: // initialize the CUDA function. void Init(CUDAModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& launch_param_tags) { + size_t num_void_args, const std::vector& launch_param_tags, + bool has_programmatic_dependent_launch) { m_ = m; sptr_ = sptr; func_name_ = func_name; + has_programmatic_dependent_launch_ = has_programmatic_dependent_launch; std::fill(fcache_.begin(), fcache_.end(), nullptr); launch_param_config_.Init(num_void_args, launch_param_tags); } @@ -200,9 +202,33 @@ class CUDAWrappedFunc { } } CUstream strm = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); - CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), - wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), - wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); + + CUresult result; + + if (has_programmatic_dependent_launch_) { + CUlaunchConfig config; + CUlaunchAttribute attribute[1]; + attribute[0].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; + attribute[0].value.programmaticStreamSerializationAllowed = 1; + + config.attrs = attribute; + config.numAttrs = 1; + config.hStream = strm; + config.gridDimX = wl.grid_dim(0); + config.gridDimY = wl.grid_dim(1); + config.gridDimZ = wl.grid_dim(2); + config.blockDimX = wl.block_dim(0); + config.blockDimY = wl.block_dim(1); + config.blockDimZ = wl.block_dim(2); + 
config.sharedMemBytes = wl.dyn_shmem_size; + + result = cuLaunchKernelEx(&config, fcache_[device_id], void_args, nullptr); + } else { + result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), + wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, + strm, void_args, nullptr); + } + if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { const char* msg; cuGetErrorName(result, &msg); @@ -234,6 +260,8 @@ class CUDAWrappedFunc { mutable std::array fcache_; // launch parameters configuration LaunchParamConfig launch_param_config_; + // have pdl setting + bool has_programmatic_dependent_launch_; }; class CUDAPrepGlobalBarrier { @@ -271,7 +299,8 @@ ffi::Optional CUDAModuleNode::GetFunction(const ffi::String& name if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags, + info.has_programmatic_dependent_launch); return PackFuncVoidAddr(f, info.arg_types, info.arg_extra_tags); } diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index b3733ee6fdff..4c705e8a43a9 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -51,6 +51,8 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const { iarg_extra_tags[i] = static_cast(arg_extra_tags[i]); } writer->WriteObjectKeyValue("arg_extra_tags", iarg_extra_tags); + writer->WriteObjectKeyValue("has_programmatic_dependent_launch", + has_programmatic_dependent_launch); writer->EndObject(); } @@ -64,6 +66,9 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { &launch_param_tags); // for backward compatibility std::vector iarg_extra_tags; helper.DeclareOptionalField("arg_extra_tags", &iarg_extra_tags); + helper.DeclareOptionalField("has_programmatic_dependent_launch", + &has_programmatic_dependent_launch); + 
arg_extra_tags.resize(iarg_extra_tags.size()); for (size_t i = 0; i < arg_extra_tags.size(); ++i) { arg_extra_tags[i] = static_cast(iarg_extra_tags[i]); @@ -80,6 +85,7 @@ void FunctionInfo::Save(dmlc::Stream* writer) const { writer->Write(arg_types); writer->Write(launch_param_tags); writer->Write(arg_extra_tags); + writer->Write(has_programmatic_dependent_launch); } bool FunctionInfo::Load(dmlc::Stream* reader) { @@ -87,6 +93,7 @@ bool FunctionInfo::Load(dmlc::Stream* reader) { if (!reader->Read(&arg_types)) return false; if (!reader->Read(&launch_param_tags)) return false; if (!reader->Read(&arg_extra_tags)) return false; + if (!reader->Read(&has_programmatic_dependent_launch)) return false; return true; } diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 85b83289f4d3..b43c2c9a24b5 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -45,10 +45,8 @@ inline ffi::String get_name_mangled(const ffi::String& module_name, const ffi::S } namespace launch_param { - /*! \brief A tag to specify whether or not dynamic shared memory is used */ constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; - } // namespace launch_param /*! 
\brief function information needed by device */ @@ -60,6 +58,8 @@ struct FunctionInfo { enum class ArgExtraTags : int { kNone = 0, kTensorMap = 1 }; std::vector arg_extra_tags; + bool has_programmatic_dependent_launch; + void Save(dmlc::JSONWriter* writer) const; void Load(dmlc::JSONReader* reader); void Save(dmlc::Stream* writer) const; From c3d28b20f09afa565eda664300c9c1ca943d701f Mon Sep 17 00:00:00 2001 From: senhtry Date: Wed, 24 Dec 2025 11:21:15 +0800 Subject: [PATCH 10/11] tir: add launch param tag for programmatic dependent launch --- src/runtime/cuda/cuda_module.cc | 16 +++++----------- src/runtime/file_utils.cc | 7 ------- src/runtime/meta_data.h | 6 ++++-- src/runtime/thread_storage_scope.h | 6 ++++++ 4 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 74f6bad2d95b..a2abec98f8e6 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -173,12 +173,10 @@ class CUDAWrappedFunc { public: // initialize the CUDA function. 
void Init(CUDAModuleNode* m, ObjectPtr sptr, const std::string& func_name, - size_t num_void_args, const std::vector& launch_param_tags, - bool has_programmatic_dependent_launch) { + size_t num_void_args, const std::vector& launch_param_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; - has_programmatic_dependent_launch_ = has_programmatic_dependent_launch; std::fill(fcache_.begin(), fcache_.end(), nullptr); launch_param_config_.Init(num_void_args, launch_param_tags); } @@ -202,12 +200,11 @@ class CUDAWrappedFunc { } } CUstream strm = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); - CUresult result; - if (has_programmatic_dependent_launch_) { - CUlaunchConfig config; - CUlaunchAttribute attribute[1]; + if (launch_param_config_.use_programtic_dependent_launch()) { + CUlaunchConfig config{}; + CUlaunchAttribute attribute[1]{}; attribute[0].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; attribute[0].value.programmaticStreamSerializationAllowed = 1; @@ -260,8 +257,6 @@ class CUDAWrappedFunc { mutable std::array fcache_; // launch parameters configuration LaunchParamConfig launch_param_config_; - // have pdl setting - bool has_programmatic_dependent_launch_; }; class CUDAPrepGlobalBarrier { @@ -299,8 +294,7 @@ ffi::Optional CUDAModuleNode::GetFunction(const ffi::String& name if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; CUDAWrappedFunc f; - f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags, - info.has_programmatic_dependent_launch); + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags); return PackFuncVoidAddr(f, info.arg_types, info.arg_extra_tags); } diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 4c705e8a43a9..b3733ee6fdff 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -51,8 +51,6 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const { iarg_extra_tags[i] = static_cast(arg_extra_tags[i]); } 
writer->WriteObjectKeyValue("arg_extra_tags", iarg_extra_tags); - writer->WriteObjectKeyValue("has_programmatic_dependent_launch", - has_programmatic_dependent_launch); writer->EndObject(); } @@ -66,9 +64,6 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { &launch_param_tags); // for backward compatibility std::vector iarg_extra_tags; helper.DeclareOptionalField("arg_extra_tags", &iarg_extra_tags); - helper.DeclareOptionalField("has_programmatic_dependent_launch", - &has_programmatic_dependent_launch); - arg_extra_tags.resize(iarg_extra_tags.size()); for (size_t i = 0; i < arg_extra_tags.size(); ++i) { arg_extra_tags[i] = static_cast(iarg_extra_tags[i]); @@ -85,7 +80,6 @@ void FunctionInfo::Save(dmlc::Stream* writer) const { writer->Write(arg_types); writer->Write(launch_param_tags); writer->Write(arg_extra_tags); - writer->Write(has_programmatic_dependent_launch); } bool FunctionInfo::Load(dmlc::Stream* reader) { @@ -93,7 +87,6 @@ bool FunctionInfo::Load(dmlc::Stream* reader) { if (!reader->Read(&arg_types)) return false; if (!reader->Read(&launch_param_tags)) return false; if (!reader->Read(&arg_extra_tags)) return false; - if (!reader->Read(&has_programmatic_dependent_launch)) return false; return true; } diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index b43c2c9a24b5..1a44c38b7c24 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -45,8 +45,12 @@ inline ffi::String get_name_mangled(const ffi::String& module_name, const ffi::S } namespace launch_param { + /*! \brief A tag to specify whether or not dynamic shared memory is used */ constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; +/*! \brief A tag to specify whether or not use programatic dependent launch */ +constexpr const char* kUseProgramaticDependentLaunch = "tir.use_programtic_dependent_launch"; + } // namespace launch_param /*! 
\brief function information needed by device */ @@ -58,8 +62,6 @@ struct FunctionInfo { enum class ArgExtraTags : int { kNone = 0, kTensorMap = 1 }; std::vector arg_extra_tags; - bool has_programmatic_dependent_launch; - void Save(dmlc::JSONWriter* writer) const; void Load(dmlc::JSONReader* reader); void Save(dmlc::Stream* writer) const; diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 914fe67819de..e3b9dbf74c7c 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -247,6 +247,8 @@ class LaunchParamConfig { ICHECK_EQ(i, launch_param_tags.size() - 1) << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; use_dyn_shared_memory_ = true; + } else if (tag == launch_param::kUseProgramaticDependentLaunch) { + use_programmatic_dependent_launch_ = true; } else { ThreadScope ts = ThreadScope::Create(tag); arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); @@ -281,6 +283,8 @@ class LaunchParamConfig { // return the work dim size_t work_dim() const { return work_dim_; } + bool use_programtic_dependent_launch() const { return use_programmatic_dependent_launch_; } + private: /*! \brief base axis */ size_t base_; @@ -290,6 +294,8 @@ class LaunchParamConfig { std::vector arg_index_map_; /*! \brief Whether or not use dynamic shared memory. */ bool use_dyn_shared_memory_{false}; + /*! \brief Whether or not use programmatic dependent launch. 
*/ + bool use_programmatic_dependent_launch_{false}; }; } // namespace runtime From 490f6a0c82ba74332fe1314586146129df49cd94 Mon Sep 17 00:00:00 2001 From: senhtry Date: Wed, 24 Dec 2025 17:05:17 +0800 Subject: [PATCH 11/11] tir: add param tag for cuLaunchCooperativeKernel --- src/runtime/cuda/cuda_module.cc | 4 ++++ src/runtime/meta_data.h | 2 ++ src/runtime/thread_storage_scope.h | 6 ++++++ 3 files changed, 12 insertions(+) diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index a2abec98f8e6..f07996c68b36 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -220,6 +220,10 @@ class CUDAWrappedFunc { config.sharedMemBytes = wl.dyn_shmem_size; result = cuLaunchKernelEx(&config, fcache_[device_id], void_args, nullptr); + } else if (launch_param_config_.use_cooperative_launch()) { + result = cuLaunchCooperativeKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), + wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), + wl.block_dim(2), wl.dyn_shmem_size, strm, void_args); } else { result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 1a44c38b7c24..aceb97b58374 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -50,6 +50,8 @@ namespace launch_param { constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; /*! \brief A tag to specify whether or not use programatic dependent launch */ constexpr const char* kUseProgramaticDependentLaunch = "tir.use_programtic_dependent_launch"; +/*! 
\brief A tag to specify whether or not to use cooperative launch */ +constexpr const char* kUseCooperativeLaunch = "tir.use_cooperative_launch"; } // namespace launch_param diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index e3b9dbf74c7c..c2cd792220f5 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -249,6 +249,8 @@ class LaunchParamConfig { use_dyn_shared_memory_ = true; } else if (tag == launch_param::kUseProgramaticDependentLaunch) { use_programmatic_dependent_launch_ = true; + } else if (tag == launch_param::kUseCooperativeLaunch) { + use_cooperative_launch_ = true; } else { ThreadScope ts = ThreadScope::Create(tag); arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); @@ -285,6 +287,8 @@ class LaunchParamConfig { bool use_programtic_dependent_launch() const { return use_programmatic_dependent_launch_; } + bool use_cooperative_launch() const { return use_cooperative_launch_; } + private: /*! \brief base axis */ size_t base_; @@ -296,6 +300,8 @@ class LaunchParamConfig { bool use_dyn_shared_memory_{false}; /*! \brief Whether or not use programmatic dependent launch. */ bool use_programmatic_dependent_launch_{false}; + /*! \brief Whether or not to use cooperative launch. */ + bool use_cooperative_launch_{false}; }; } // namespace runtime