Skip to content

Commit

Permalink
Use FloatImm (not UIntImm) to hold immediates of custom datatypes
Browse files Browse the repository at this point in the history
This change switches from using UIntImm to FloatImm for storing immediates of
custom datatypes. The value of the number is stored in a double, which should be
enough precision for now, for most custom types we will explore in the immediate
future.

In line with this change, we change the datatype lowering so that FloatImms are
lowered to UInts of the appropriate size. Originally, this was going to be done
by allowing the user to register a double->uint_<storage size>_t conversion
which would be called at compile time to convert the value from the FloatImm to
a UInt and store it in a UIntImm. After discussions with Tianqi, we decided to
take the simpler route, and lower FloatImms just as we lower all other ops: by
replacing them with Call nodes. In this case, presumably the user will Call out
to a conversion function in their datatype library.

The justification for this decision is due to the functionality added in apache#1486.
This pull request adds the ability to load LLVM bytecode in at compile time.
This applies in our case as follows:
 1. The user writes their custom datatype programs and registers their lowering
    functions in the same way we've been doing it so far. All operations over
    custom datatypes are lowered to Calls to the datatype library.
 2. The user compiles their datatype library to LLVM bytecode.
 3. At TVM compile time, the user loads the LLVM bytecode. Depending on how the
    datatype library is written, Clang should be able to perform constant
    folding over the custom datatype immediates, even if their conversions are
    done with calls to the library.

Additionally adds test to test the FloatImm codepath.
  • Loading branch information
gussmith23 committed May 2, 2019
1 parent c6f48dd commit fee8f8b
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 6 deletions.
4 changes: 2 additions & 2 deletions python/tvm/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ._ffi.function import register_func as _register_func
from . import make as _make
from .api import convert
from .expr import Call as _Call, Cast as _Cast
from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
from ._ffi.runtime_ctypes import TVMType as _TVMType
from . import _api_internal

Expand Down Expand Up @@ -97,7 +97,7 @@ def lower(op):
dtype = "uint" + str(t.bits)
if t.lanes > 1:
dtype += "x" + str(t.lanes)
if isinstance(op, _Cast):
if isinstance(op, (_Cast, _FloatImm)):
return _make.Call(dtype, extern_func_name, convert([op.value]),
_Call.Extern, None, 0)
return _make.Call(dtype, extern_func_name, convert([op.a, op.b]),
Expand Down
9 changes: 9 additions & 0 deletions src/codegen/datatype/registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t
return runtime::Registry::Get(ss.str());
}

const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code) {
std::ostringstream ss;
ss << "tvm.datatype.lower.";
ss << target;
ss << ".FloatImm.";
ss << datatype::Registry::Global()->GetTypeName(type_code);
return runtime::Registry::Get(ss.str());
}

uint64_t ConvertConstScalar(uint8_t type_code, double value) {
std::ostringstream ss;
ss << "tvm.datatype.convertconstscalar.float.";
Expand Down
11 changes: 7 additions & 4 deletions src/codegen/datatype/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ namespace datatype {
* ensuring that neither conflict with existing types.
* 2. Use TVM_REGISTER_GLOBAL to register the lowering functions needed to
* lower the custom datatype. In general, these will look like:
* For Casts: tvm.datatype.lower.Cast.<target>.<type>.<src_type>
* Example: tvm.datatype.lower.Cast.llvm.myfloat.float for a Cast from
* For Casts: tvm.datatype.lower.<target>.Cast.<type>.<src_type>
* Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from
* float to myfloat.
* Example: tvm.datatype.lower.add.llvm.myfloat
* For other ops: tvm.datatype.lower.<op>.<target>.<type>
* For other ops: tvm.datatype.lower.<target>.<op>.<type>
* Examples: tvm.datatype.lower.llvm.Add.myfloat
* tvm.datatype.lower.llvm.FloatImm.posit
*/
class Registry {
public:
Expand Down Expand Up @@ -79,6 +80,8 @@ const runtime::PackedFunc *GetCastLowerFunc(const std::string &target,
uint8_t type_code,
uint8_t src_type_code);

const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code);

#define DEFINE_GET_LOWER_FUNC_(OP) \
inline const runtime::PackedFunc* Get##OP##LowerFunc(const std::string& target, \
uint8_t type_code) { \
Expand Down
26 changes: 26 additions & 0 deletions src/pass/lower_custom_datatypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,31 @@ class LoadLowerer : public IRMutator {
}
};

/*!
* \brief Mutator for lowering immediates of custom datatypes
*
* As with the other ops, immediates are lowered using a user-provided function. We separate them
* into their own IRMutator as they must be mutated after all other ops.
*/
class FloatImmLowerer : public IRMutator {
public:
explicit FloatImmLowerer(const std::string& target) : target_(target) {}

inline Expr Mutate_(const FloatImm* imm, const Expr& e) final {
auto type_code = imm->type.code();
if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
auto lower = datatype::GetFloatImmLowerFunc(target_, type_code);
CHECK(lower) << "FloatImm lowering function for target " << target_ << " type "
<< static_cast<unsigned>(type_code) << " not found";
return (*lower)(e);
}
return e;
}

private:
std::string target_;
};

LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) {
auto n = make_node<LoweredFuncNode>(*f.operator->());
// We lower in stages. First, we lower all operations (e.g. casts, binary ops,
Expand All @@ -167,6 +192,7 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) {
n->body = CustomDatatypesLowerer(target).Mutate(n->body);
n->body = AllocateLowerer().Mutate(n->body);
n->body = LoadLowerer().Mutate(n->body);
n->body = FloatImmLowerer(target).Mutate(n->body);
return LoweredFunc(n);
}

Expand Down
32 changes: 32 additions & 0 deletions tests/python/unittest/test_custom_datatypes_mybfloat16.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def setup():
tvm.datatype.register_op(
tvm.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm",
"bfloat")
tvm.datatype.register_op(
tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"), "FloatImm",
"llvm", "bfloat")


def test_bfloat_add_and_cast_1():
Expand Down Expand Up @@ -122,7 +125,36 @@ def test_bfloat_add_and_cast_2():
assert np.array_equal(z_expected, z.asnumpy())


def test_bfloat_add_and_cast_FloatImm():
X = tvm.placeholder((3, ), name="X")
Z = topi.cast(
topi.add(
topi.cast(X, dtype="custom[bfloat]16"),
tvm.expr.FloatImm("custom[bfloat]16", 1.5)),
dtype="float")

# Create schedule and lower, manually lowering datatypes. Once datatype
# lowering is integrated directly into TVM's lower/build process, we won't
# need to do this manually.
s = tvm.create_schedule([Z.op])
flist = tvm.lower(s, [X, Z])
flist = [flist]
flist = [ir_pass.LowerCustomDatatypes(func, tgt) for func in flist]
built_cast = tvm.build(flist[0], target=tgt)

ctx = tvm.context(tgt, 0)

x = tvm.nd.array(np.array([0.0, 1.0, 1.5]).astype("float32"), ctx=ctx)
z_expected = np.array([1.5, 2.5, 3.0]).astype("float32")
z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx)

built_cast(x, z)

assert np.array_equal(z_expected, z.asnumpy())


if __name__ == "__main__":
setup()
test_bfloat_add_and_cast_1()
test_bfloat_add_and_cast_2()
test_bfloat_add_and_cast_FloatImm()

0 comments on commit fee8f8b

Please sign in to comment.