Skip to content

Commit 63ca050

Browse files
author
Min Chen
committed
[Arith] Add IntegerSetNode to represent Presburger Set
1 parent f4b53fb commit 63ca050

File tree

10 files changed

+620
-102
lines changed

10 files changed

+620
-102
lines changed

cmake/config.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ set(USE_MICRO_STANDALONE_RUNTIME OFF)
144144
# - /path/to/llvm-config: enable specific LLVM when multiple llvm-dev is available.
145145
set(USE_LLVM OFF)
146146

147+
# Whether use MLIR to help analyze, requires USE_LLVM is enabled
148+
# Possible values: ON/OFF
149+
set(USE_MLIR OFF)
150+
147151
#---------------------------------------------
148152
# Contrib libraries
149153
#---------------------------------------------

cmake/modules/LLVM.cmake

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN})
3535
message(STATUS "Set TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION})
3636
# Set flags that are only needed for LLVM target
3737
add_definitions(-DTVM_LLVM_VERSION=${TVM_LLVM_VERSION})
38+
if (${TVM_MLIR_VERSION})
39+
add_definitions(-DTVM_MLIR_VERSION=${TVM_MLIR_VERSION})
40+
endif()
3841
tvm_file_glob(GLOB COMPILER_LLVM_SRCS src/target/llvm/*.cc)
3942
list(APPEND TVM_LINKER_LIBS ${LLVM_LIBS})
4043
list(APPEND COMPILER_SRCS ${COMPILER_LLVM_SRCS})

cmake/utils/FindLLVM.cmake

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,16 @@ macro(find_llvm use_llvm)
142142
string(REPLACE "$" ${__llvm_prefix} __lib_with_prefix "${__flag}")
143143
list(APPEND LLVM_LIBS "${__lib_with_prefix}")
144144
endforeach()
145+
if (${USE_MLIR})
146+
if (EXISTS "${__llvm_libdir}/libMLIRPresburger.a")
147+
if (EXISTS "${__llvm_libdir}/libMLIRSupport.a")
148+
message(STATUS "Found MLIR")
149+
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRPresburger.a")
150+
list(APPEND LLVM_LIBS "${__llvm_libdir}/libMLIRSupport.a")
151+
set(TVM_MLIR_VERSION ${TVM_LLVM_VERSION})
152+
endif()
153+
endif()
154+
endif()
145155
separate_arguments(__llvm_system_libs)
146156
foreach(__flag IN ITEMS ${__llvm_system_libs})
147157
# If the library file ends in .lib try to

include/tvm/node/reflection.h

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,107 @@ class AttrVisitor {
7272
//! \endcond
7373
};
7474

75+
// Attr getter.
76+
class AttrGetter : public AttrVisitor {
77+
public:
78+
const String& skey;
79+
runtime::TVMRetValue* ret;
80+
81+
AttrGetter(const String& skey, runtime::TVMRetValue* ret) : skey(skey), ret(ret) {}
82+
83+
bool found_ref_object{false};
84+
85+
void Visit(const char* key, double* value) final {
86+
if (skey == key) *ret = value[0];
87+
}
88+
void Visit(const char* key, int64_t* value) final {
89+
if (skey == key) *ret = value[0];
90+
}
91+
void Visit(const char* key, uint64_t* value) final {
92+
ICHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
93+
<< "cannot return too big constant";
94+
if (skey == key) *ret = static_cast<int64_t>(value[0]);
95+
}
96+
void Visit(const char* key, int* value) final {
97+
if (skey == key) *ret = static_cast<int64_t>(value[0]);
98+
}
99+
void Visit(const char* key, bool* value) final {
100+
if (skey == key) *ret = static_cast<int64_t>(value[0]);
101+
}
102+
void Visit(const char* key, void** value) final {
103+
if (skey == key) *ret = static_cast<void*>(value[0]);
104+
}
105+
void Visit(const char* key, DataType* value) final {
106+
if (skey == key) *ret = value[0];
107+
}
108+
void Visit(const char* key, std::string* value) final {
109+
if (skey == key) *ret = value[0];
110+
}
111+
112+
void Visit(const char* key, runtime::NDArray* value) final {
113+
if (skey == key) {
114+
*ret = value[0];
115+
found_ref_object = true;
116+
}
117+
}
118+
void Visit(const char* key, runtime::ObjectRef* value) final {
119+
if (skey == key) {
120+
*ret = value[0];
121+
found_ref_object = true;
122+
}
123+
}
124+
};
125+
126+
class NodeAttrSetter : public AttrVisitor {
127+
public:
128+
std::string type_key;
129+
std::unordered_map<std::string, runtime::TVMArgValue> attrs;
130+
131+
void Visit(const char* key, double* value) final { *value = GetAttr(key).operator double(); }
132+
void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).operator int64_t(); }
133+
void Visit(const char* key, uint64_t* value) final { *value = GetAttr(key).operator uint64_t(); }
134+
void Visit(const char* key, int* value) final { *value = GetAttr(key).operator int(); }
135+
void Visit(const char* key, bool* value) final { *value = GetAttr(key).operator bool(); }
136+
void Visit(const char* key, std::string* value) final {
137+
*value = GetAttr(key).operator std::string();
138+
}
139+
void Visit(const char* key, void** value) final { *value = GetAttr(key).operator void*(); }
140+
void Visit(const char* key, DataType* value) final { *value = GetAttr(key).operator DataType(); }
141+
void Visit(const char* key, runtime::NDArray* value) final {
142+
*value = GetAttr(key).operator runtime::NDArray();
143+
}
144+
void Visit(const char* key, ObjectRef* value) final {
145+
*value = GetAttr(key).operator ObjectRef();
146+
}
147+
148+
runtime::TVMArgValue GetAttr(const char* key) {
149+
auto it = attrs.find(key);
150+
if (it == attrs.end()) {
151+
LOG(FATAL) << type_key << ": require field " << key;
152+
}
153+
runtime::TVMArgValue v = it->second;
154+
attrs.erase(it);
155+
return v;
156+
}
157+
};
158+
159+
// List names;
160+
class AttrDir : public AttrVisitor {
161+
public:
162+
std::vector<std::string>* names;
163+
164+
void Visit(const char* key, double* value) final { names->push_back(key); }
165+
void Visit(const char* key, int64_t* value) final { names->push_back(key); }
166+
void Visit(const char* key, uint64_t* value) final { names->push_back(key); }
167+
void Visit(const char* key, bool* value) final { names->push_back(key); }
168+
void Visit(const char* key, int* value) final { names->push_back(key); }
169+
void Visit(const char* key, void** value) final { names->push_back(key); }
170+
void Visit(const char* key, DataType* value) final { names->push_back(key); }
171+
void Visit(const char* key, std::string* value) final { names->push_back(key); }
172+
void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); }
173+
void Visit(const char* key, runtime::ObjectRef* value) final { names->push_back(key); }
174+
};
175+
75176
/*!
76177
* \brief Virtual function table to support IR/AST node reflection.
77178
*

python/tvm/arith/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .int_set import (
2020
IntSet,
2121
IntervalSet,
22+
IntegerSet,
2223
estimate_region_lower_bound,
2324
estimate_region_strict_bound,
2425
estimate_region_upper_bound,

python/tvm/arith/int_set.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,23 @@ def __init__(self, min_value, max_value):
8181
self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value)
8282

8383

84+
@tvm._ffi.register_object("arith.IntegerSet")
85+
class IntegerSet(IntSet):
86+
"""Represent of Presburger Set
87+
88+
Parameters
89+
----------
90+
constraint : PrimExpr
91+
The constraint expression.
92+
93+
vars : List[PrimExpr]
94+
The domain vars of Presburger Set.
95+
"""
96+
97+
def __init__(self, constraint, vars):
98+
self.__init_handle_by_constructor__(_ffi_api.IntegerSet, constraint, vars)
99+
100+
84101
def estimate_region_lower_bound(region, var_dom, predicate):
85102
"""Analyze the region with affine map, given the domain of variables and their predicate
86103
Some subregion may be discarded during the lower-bound analysis.

0 commit comments

Comments
 (0)