diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 6e2209d51950..291b79ea557d 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -533,6 +533,21 @@ TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true); TVM_DLL Map> SuggestLayoutTransforms( const Function& fn, Array write_buffer_transformations); +/* \brief Collect variables whose value can be computed at compile-time + * + * If a function has the `kNumInput` attribute, then the first + * `kNumInput` parameters are provided at run-time, while all + * remaining parameters may be known at compile-time. This utility + * collects all variable bindings that only depend, directly or + * indirectly, on the parameters known at compile-time. + * + * \param func The relax::Function to analyze + * + * \return The set of variables that can be computed at compile-time, + * in order of their occurrence within the function. + */ +TVM_DLL Array ComputableAtCompileTime(const Function& func); + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py index d8454a02cc84..06b4f6432681 100644 --- a/python/tvm/relax/analysis/__init__.py +++ b/python/tvm/relax/analysis/__init__.py @@ -21,6 +21,7 @@ all_global_vars, all_vars, bound_vars, + computable_at_compile_time, contains_impure_call, definable_tir_vars_in_struct_info, defined_symbolic_vars, diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 38f5ea2fea0e..83286c09803a 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -528,3 +528,28 @@ def detect_recursion(mod: tvm.IRModule) -> List[List[GlobalVar]]: with any other, it will be a singleton in this list. """ return _ffi_api.detect_recursion(mod) # type: ignore + + +def computable_at_compile_time(func: Function) -> List[Var]: + """Collect variables whose value can be computed at compile-time + + If a function has the `kNumInput` attribute, then the first + `kNumInput` parameters are provided at run-time, while all + remaining parameters may be known at compile-time. This utility + collects all variable bindings that only depend, directly or + indirectly, on the parameters known at compile-time. + + Parameters + ---------- + func: Function + + The `relax.Function` to analyze + + Returns + ------- + ret: List[Var] + + The set of variables that can be computed at compile-time, in + order of their occurrence within the function. + """ + return _ffi_api.computable_at_compile_time(func) # type: ignore diff --git a/src/relax/analysis/computable_at_compile_time.cc b/src/relax/analysis/computable_at_compile_time.cc new file mode 100644 index 000000000000..5ee336ff008f --- /dev/null +++ b/src/relax/analysis/computable_at_compile_time.cc @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file computable_at_compile_time.cc + * + * \brief Utilities for identifying potentially compile-time variables + */ + +#include +#include + +#include "../../support/ordered_set.h" + +namespace tvm { +namespace relax { + +namespace { +class CompileTimeCollector : ExprVisitor { + public: + static Array Collect(const Function& func) { + CompileTimeCollector visitor; + visitor(func); + return Array(visitor.known_relax_vars_.begin(), visitor.known_relax_vars_.end()); + } + + private: + void VisitExpr_(const FunctionNode* func) override { + if (auto opt_num_input = func->attrs.GetAttr(attr::kNumInput)) { + size_t num_input = opt_num_input.value()->value; + for (size_t i = num_input; i < func->params.size(); i++) { + MarkAsKnown(func->params[i]); + } + } + + ExprVisitor::VisitExpr_(func); + } + + void VisitBinding(const Binding& binding) override { + Expr value = GetBoundValue(binding); + bool can_compute_at_compile_time = [&]() { + for (const auto& relax_var : FreeVars(value)) { + if (!known_relax_vars_.count(relax_var)) { + return false; + } + } + for (const auto& tir_var : FreeSymbolicVars(value)) { + if (!known_tir_vars_.count(tir_var)) { + return false; + } + } + + return true; + }(); + + if (can_compute_at_compile_time) { + MarkAsKnown(binding->var); + } + + ExprVisitor::VisitBinding(binding); + } + + void MarkAsKnown(const Var& var) { + known_relax_vars_.insert(var); + for (const auto& tir_var : DefinableTIRVarsInStructInfo(GetStructInfo(var))) { + known_tir_vars_.insert(tir_var); + } + } + + support::OrderedSet known_relax_vars_; + std::unordered_set known_tir_vars_; +}; +} // namespace + +Array ComputableAtCompileTime(const Function& func) { + return CompileTimeCollector::Collect(func); +} + +TVM_REGISTER_GLOBAL("relax.analysis.computable_at_compile_time") + .set_body_typed(ComputableAtCompileTime); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis_computable_at_compile_time.py b/tests/python/relax/test_analysis_computable_at_compile_time.py new file mode 100644 index 000000000000..be78648e96c7 --- /dev/null +++ b/tests/python/relax/test_analysis_computable_at_compile_time.py @@ -0,0 +1,243 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List + +import tvm +import tvm.testing +from tvm.script import relax as R, tir as T + + +def _analyze_func(func: tvm.relax.Function) -> List[str]: + return [var.name_hint for var in tvm.relax.analysis.computable_at_compile_time(func)] + + +def test_no_num_input_attribute(): + """Without the "num_input" attribute, all params are runtime""" + + @R.function + def func(A: R.Tensor([16], "int32"), B: R.Tensor([16], "int32")): + C = R.add(A, B) + return C + + assert _analyze_func(func) == [] + + +def test_compile_time_param(): + """Parameters after "num_input" are known at compile-time""" + + @R.function + def func(A: R.Tensor([16], "int32"), B: R.Tensor([16], "int32")): + R.func_attr({"num_input": 1}) + return () + + assert _analyze_func(func) == ["B"] + + +def test_binding_using_one_param(): + """Bindings may be computable at compile-time""" + + @R.function + def func(A: R.Tensor([16], "int32"), B: R.Tensor([16], "int32")): + R.func_attr({"num_input": 1}) + C = R.add(B, B) + D = R.add(A, C) + return D + + assert _analyze_func(func) == ["B", "C"] + + +def test_binding_using_multiple_params(): + """Compile-time bindings may use multiple parameters""" + + @R.function + def func(A: R.Tensor([16], "int32"), B: R.Tensor([16], "int32"), C: R.Tensor([16], "int32")): + R.func_attr({"num_input": 1}) + D = R.add(B, C) + E = R.add(A, D) + return E + + assert _analyze_func(func) == ["B", "C", "D"] + + +def test_compile_time_binding_after_run_time(): + """Compile-time bindings may occur after run-time + + A binding being computable at compile-time only depends on the + arguments used for it. A value that is computable at compile-time + may occur after a value that is only computable at run-time. + """ + + @R.function + def func(A: R.Tensor([16], "int32"), B: R.Tensor([16], "int32")): + R.func_attr({"num_input": 1}) + C = R.add(A, A) + D = R.add(B, B) + E = R.add(C, D) + return E + + assert _analyze_func(func) == ["B", "D"] + + +def test_sequential_compile_time_bindings(): + """Compile-time bindings may occur after run-time + + A compile-time value may depend on variables defined within the + function, so long as those variables are themselves computable at + compile-time. + """ + + @R.function + def func(A: R.Tensor([16], "int32"), B: R.Tensor([16], "int32")): + R.func_attr({"num_input": 1}) + C = R.add(B, B) + D = R.add(C, C) + E = R.add(D, D) + F = R.add(E, A) + return F + + assert _analyze_func(func) == ["B", "C", "D", "E"] + + +def test_dataflow_vars(): + """Compile-time bindings may occur in dataflow blocks""" + + @R.function + def func(A: R.Tensor([16], "int32"), B: R.Tensor([16], "int32")): + R.func_attr({"num_input": 1}) + with R.dataflow(): + C = R.add(B, B) + D = R.add(C, C) + E = R.add(D, D) + F = R.add(E, A) + R.output(F) + return F + + assert _analyze_func(func) == ["B", "C", "D", "E"] + + +def test_compile_time_symbolic_shape(): + """Compile-time bindings may contain symbolic shapes""" + + @R.function + def func(A: R.Tensor([1], "int32"), B: R.Tensor(["n"], "int32")): + R.func_attr({"num_input": 1}) + n = T.int64() + + C: R.Tensor([n], "int32") = R.add(B, B) + D: R.Tensor([], "int32") = R.max(C, axis=0) + E: R.Tensor([1], "int32") = R.add(A, D) + return E + + assert _analyze_func(func) == ["B", "C", "D"] + + +def test_symbolic_variables_from_match_binding(): + """Symbolic vars may be inferred from compile-time bindings""" + + @R.function + def func(A: R.Tensor(ndim=1, dtype="int32"), B: R.Tensor(ndim=1, dtype="int32")): + R.func_attr({"num_input": 1}) + n = T.int64() + m = T.int64() + + A2 = R.match_cast(A, R.Tensor([n], "int32")) + B2 = R.match_cast(B, R.Tensor([m], "int32")) + + C = R.add(B2, B2) + D = R.max(C, axis=0) + E = R.max(A2, axis=0) + F = R.add(D, E) + return F + + assert _analyze_func(func) == ["B", "B2", "C", "D"] + + +def test_compile_time_expressions_may_not_use_runtime_symbolic_variables(): + """Symbolic vars may be inferred from compile-time bindings + + Here, `C` uses the symbolic variable `m`, which can be inferred + from the shape of `B` and is known at compile-time. However, `D` + uses the symbolic variable `n`, which cannot be inferred without + first knowing `A`, and is therefore unknown at compile-time. + """ + + @R.function + def func(A: R.Tensor(["n"], "int32"), B: R.Tensor(["m"], "int32")): + R.func_attr({"num_input": 1}) + n = T.int64() + m = T.int64() + + C = R.ones([m], "int32") + D = R.ones([n], "int32") + + E = (C, D) + return E + + assert _analyze_func(func) == ["B", "C"] + + +def test_compile_time_expressions_may_infer_same_variable_as_run_time(): + """Symbolic vars may be inferred from compile-time bindings + + A symbolic variable may be inferrable from multiple sources. + Here, while `n` can be inferred from the runtime parameter `A`, it + can also be inferred from the compile-time parameter `B`. + """ + + @R.function + def func(A: R.Tensor(["n"], "int32"), B: R.Tensor(["n"], "int32")): + R.func_attr({"num_input": 1}) + n = T.int64() + + C = R.ones([n], "int32") + D = R.ones([n], "int32") + + E = (C, D) + return E + + assert _analyze_func(func) == ["B", "C", "D", "E"] + + +def test_compile_time_expressions_may_use_variables_from_match_cast(): + """Symbolic vars may be inferred from compile-time bindings + + Here, `C` uses the symbolic variable `m`, which can be inferred + from the shape of `B` and is known at compile-time. However, `D` + uses the symbolic variable `n`, which cannot be inferred without + first knowing `A`, and is therefore unknown at compile-time. + """ + + @R.function + def func(A: R.Tensor(["n"], "int32"), B: R.Tensor(ndim=1, dtype="int32")): + R.func_attr({"num_input": 1}) + n = T.int64() + m = T.int64() + + B2 = R.match_cast(B, R.Tensor([m], "int32")) + + C = R.ones([m], "int32") + D = R.ones([n], "int32") + + E = (C, D) + return E + + assert _analyze_func(func) == ["B", "B2", "C"] + + +if __name__ == "__main__": + tvm.testing.main()