Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ class TargetNode : public Object {
/*! \return The device type for this target */
TVM_DLL int GetTargetDeviceType() const;

/*!
 * \brief Check if the target contains a key
 *
 * Each entry of `TargetNode::keys` is compared for equality against
 * \p query_key.
 *
 * \param query_key The string name of the key to be checked
 *
 * \return True if the target's `TargetNode::keys` contains the
 * specified key, False otherwise.
 */
TVM_DLL bool HasKey(const std::string& query_key) const;

/*!
* \brief Returns a human readable representation of \p Target which includes all fields,
* especially the host. Useful for diagnostic messages and debugging.
Expand Down
5 changes: 5 additions & 0 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,11 @@ int TargetNode::GetTargetDeviceType() const {
return kind->default_device_type;
}

// Scan `keys` in order and report whether any entry equals `query_key`.
bool TargetNode::HasKey(const std::string& query_key) const {
  for (const auto& candidate : keys) {
    if (candidate == query_key) {
      // First match is enough; no need to look at the remaining keys.
      return true;
    }
  }
  return false;
}

String TargetNode::ToDebugString() const {
std::ostringstream os;
os << "Target(";
Expand Down
7 changes: 5 additions & 2 deletions src/tir/transforms/combine_context_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

#include <unordered_map>

#include "ir_utils.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -102,8 +104,9 @@ namespace transform {

Pass CombineContextCall() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = ContextCallCombiner().Combine(std::move(n->body));
if (IsHostFunc(f).value_or(false)) {
f.CopyOnWrite()->body = ContextCallCombiner().Combine(f->body);
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {});
Expand Down
10 changes: 10 additions & 0 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,16 @@ std::pair<int32_t, int32_t> GetWmmaFragmentDimSize(const std::string& shape_str,
return std::pair<int32_t, int32_t>(0, 0);
}

std::optional<bool> IsHostFunc(const PrimFunc& func) {
if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
return true;
} else if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
return target.value()->HasKey("cpu");
} else {
return std::nullopt;
}
}

namespace transform {
Pass ConvertSSA() {
auto pass_func = [](IRModule mod, PassContext ctx) {
Expand Down
12 changes: 12 additions & 0 deletions src/tir/transforms/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <tvm/tir/op.h>

#include <limits>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
Expand Down Expand Up @@ -351,6 +352,17 @@ CollectStorageAlignAnnotation(const Stmt& body);
std::pair<int32_t, int32_t> GetWmmaFragmentDimSize(const std::string& shape_str,
const std::string& scope);

/*! \brief Check if a PrimFunc is a host function
 *
 * A function with a nonzero tvm::tir::attr::kIsHostFunc attribute is
 * reported as a host function. Otherwise, if the function has a
 * tvm::attr::kTarget attribute, the result is whether that target
 * has the "cpu" key.
 *
 * \param func The function to be inspected
 *
 * \return True if the function is known to run on the host, false if
 * the function is known to run on the device. If it cannot be
 * determined (e.g. a function without a tvm::attr::kTarget
 * attribute), returns std::nullopt.
 */
std::optional<bool> IsHostFunc(const PrimFunc& func);

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_
102 changes: 82 additions & 20 deletions tests/python/unittest/test_tir_transform_combine_context_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,95 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm import te
import tvm.testing

from tvm.script import tir as T, ir as I


def _device_context(dev_type, dev_id):
    """Wrap an extern ``device_context(dev_type, dev_id)`` call in a
    ``tir.tvm_thread_context`` call and return the resulting expression."""
    extern_ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id)
    return tvm.tir.Call("handle", "tir.tvm_thread_context", [extern_ctx])


class TestCombineContextsInLoop(tvm.testing.CompareBeforeAfter):
"""Device contexts should be hoisted and merged"""

# The transformation under test.
transform = tvm.tir.transform.CombineContextCall()

# Input function: annotated with an "llvm" target; every "fadd" call is
# handed a freshly built device context (device ids 0, 1 and 0 again).
def before(self):
@T.prim_func
def func(dev_type: T.int32, n: T.int32):
T.func_attr({"target": T.target("llvm")})
A = T.allocate([n], "float32", "global")
for i in range(n):
T.call_extern(
"int32",
"fadd",
_device_context(dev_type, 0),
A,
)
for j in range(10):
T.call_extern(
"int32",
"fadd",
_device_context(dev_type, 1),
A,
)
T.call_extern(
"int32",
"fadd",
_device_context(dev_type, 0),
A,
)

return func

# Expected output: each distinct device_context call is bound once, ahead
# of the allocation (ctx_cache_ for id 0, ctx_cache__1 for id 1), and the
# "fadd" calls receive those cached handles instead.
def expected(dev_type: T.int32, n: T.int32):
T.func_attr({"target": T.target("llvm")})
ctx_cache_: T.handle = T.call_extern("handle", "device_context", dev_type, 0)
ctx_cache__1: T.handle = T.call_extern("handle", "device_context", dev_type, 1)
A = T.allocate([n], "float32", "global")
for i in range(n):
T.call_extern("int32", "fadd", ctx_cache_, A)
for j in range(10):
T.call_extern("int32", "fadd", ctx_cache__1, A)
T.call_extern("int32", "fadd", ctx_cache_, A)

def test_for():
dev_type = te.var("dev_type")

def device_context(dev_id):
ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id)
return tvm.tir.Call("handle", "tir.tvm_thread_context", [ctx])
class TestCombineContextsInLoopWithoutTarget(TestCombineContextsInLoop):
"""CombineContextCall only updates host-side functions"""

ib = tvm.tir.ir_builder.create()
n = te.var("n")
A = ib.allocate("float32", n, name="A", scope="global")
with ib.for_range(0, n, name="i") as i:
ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A.asobject().data))
with ib.for_range(0, 10, name="j") as j:
ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(1), A.asobject().data))
ib.emit(tvm.tir.call_extern("int32", "fadd", device_context(0), A.asobject().data))
body = ib.get()
mod = tvm.IRModule({"func": tvm.tir.PrimFunc([dev_type, n], body)})
def before(self):
@T.prim_func
def func(dev_type: T.int32, n: T.int32):
A = T.allocate([n], "float32", "global")
for i in range(n):
T.call_extern(
"int32",
"fadd",
_device_context(dev_type, 0),
A,
)
for j in range(10):
T.call_extern(
"int32",
"fadd",
_device_context(dev_type, 1),
A,
)
T.call_extern(
"int32",
"fadd",
_device_context(dev_type, 0),
A,
)

mod = tvm.tir.transform.CombineContextCall()(mod)
return func

assert mod["func"].body.value.dtype == "handle"
assert mod["func"].body.body.value.dtype == "handle"
expected = before


if __name__ == "__main__":
test_for()
tvm.testing.main()