Conversation


@LeiWang1999 LeiWang1999 commented Feb 13, 2024

To optimize i4_to_f16 decoding, we can use advanced hardware instructions to do fast type conversion and alleviate the cost of decoding; in TVM, we can do that via tensorize.

To tensorize decoding, this PR extends the Call handling of ir_comparator, which is necessary because the decode block contains call expressions.

Moreover, the comparator currently simplifies the lhs expr, but the tensor intrin descs are not simplified; this inconsistency makes the comparison fail.
See PR #14108.

For example, we provide a test case for this situation:

from tvm.script import tir as T


def test_tensorize_arith_simplification():
    # fmt: off
    @T.prim_func
    def decode_i4s_to_int32_to_f16():
        B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
        B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
        for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
                for ax1_0 in range(32):
                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                        for ax0, ax1 in T.grid(1, 8):
                            with T.block("B_decode_local"):
                                v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
                                v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1)
                                T.reads(B_local[v0, v1 // 8])
                                T.writes(B_decode_local[v0, v1])
                                B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28))
    # fmt: on

The desc's indices should be simplified from [v1 // 8] and [v1 % 8] to [0] and [v1] to match the simplified lhs expr.
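As a quick sanity check of the decode expression in the test case above, here is a pure-Python sketch (the helper names `to_i32` and `decode_i4` are illustrative inventions, not part of this PR) that emulates the int32 shift-based sign extension and confirms it recovers signed 4-bit values from a packed int32:

```python
def to_i32(x):
    """Reinterpret the low 32 bits of x as a signed int32."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x & 0x80000000 else x

def decode_i4(storage, pos):
    """Emulate Cast("float16", shift_right(shift_left(bitwise_and(
    shift_right(storage, pos * 8 % 8 ... ), 15), 28), 28)) without the cast:
    extract nibble `pos`, then sign-extend via shift-left/shift-right."""
    nib = (to_i32(storage) >> (pos * 4)) & 15   # the 4-bit field
    return to_i32(nib << 28) >> 28              # arithmetic shift sign-extends

# Pack eight signed 4-bit values into one int32 word and decode them back.
values = [1, -1, 7, -8, 0, 3, -4, 5]
storage = 0
for i, v in enumerate(values):
    storage |= (v & 0xF) << (4 * i)

decoded = [decode_i4(storage, i) for i in range(8)]
assert decoded == values
```

This mirrors what the tensorized intrinsic must compute: the `<< 28 >> 28` pair turns the raw nibble 15 into -1, 8 into -8, and so on.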

To simplify the tensor intrin's desc, we wrap and reuse tir::transform::Simplify to support simplification of a single stmt.
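The index simplification this relies on can be illustrated with a toy model (the tuple-based expression "IR", the `simplify` function, and the `bounds` dict are all illustrative inventions, not TVM APIs): an expression-level rewrite is applied recursively, using known variable ranges, so that `v1 // 8` and `v1 % 8` collapse when `v1` is known to lie in `[0, 8)`:

```python
# Toy expression IR: tuples like ("floordiv", ("var", "v1"), 8).
# `bounds` maps a variable name to an exclusive upper bound (lower bound 0).

def simplify(expr, bounds):
    """Recursively simplify floordiv/floormod of a bounded variable."""
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    args = [simplify(a, bounds) for a in args]
    if op in ("floordiv", "floormod"):
        base, divisor = args
        if (isinstance(base, tuple) and base[0] == "var"
                and isinstance(divisor, int)
                and base[1] in bounds and bounds[base[1]] <= divisor):
            # With 0 <= v < divisor: v // divisor == 0 and v % divisor == v.
            return 0 if op == "floordiv" else base
    return (op, *args)

# The desc indices from this PR: with v1 iterating over [0, 8),
# [v1 // 8] simplifies to [0] and [v1 % 8] simplifies to [v1].
bounds = {"v1": 8}
assert simplify(("floordiv", ("var", "v1"), 8), bounds) == 0
assert simplify(("floormod", ("var", "v1"), 8), bounds) == ("var", "v1")
```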

@LeiWang1999 LeiWang1999 marked this pull request as ready for review February 13, 2024 13:59

LeiWang1999 commented Feb 13, 2024

cc @Hzfengsy and @Lunderberg, looks like PR #13299 provides a stmt_simplify declaration but does not provide an implementation.


tqchen commented Feb 13, 2024

cc @vinx13


class TensorIntrinSimplifier : public arith::IRMutatorWithAnalyzer {
 public:
  static PrimFunc Apply(PrimFunc func, arith::Analyzer* analyzer) {
Contributor

Instead of simplifying the body of the PrimFunc, can we instead simplify the entire PrimFunc? That way, dynamic expressions that are used in shapes are exposed to the analyzer as non-negative. (e.g. Using buffer of shape [n,m] implies that n >= 0 && m >= 0.)
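To illustrate this point with a minimal sketch (the `Analyzer` class below is a hypothetical model, not TVM's arith::Analyzer): once the analyzer learns n >= 0 and m >= 0 from a buffer shape [n, m], rewrites like max(n, 0) -> n become provable, whereas for an unconstrained variable the expression must be kept as-is:

```python
# Minimal model of an analyzer tracking non-negativity facts derived
# from buffer shapes, as suggested in the review comment above.

class Analyzer:
    def __init__(self):
        self.nonneg = set()  # variables proven >= 0

    def bind_shape(self, *dims):
        # A buffer of shape [n, m] implies n >= 0 and m >= 0.
        self.nonneg.update(dims)

    def simplify_max0(self, var):
        # max(var, 0) -> var is only sound when var is proven >= 0.
        return var if var in self.nonneg else ("max", var, 0)

ana = Analyzer()
ana.bind_shape("n", "m")
assert ana.simplify_max0("n") == "n"              # proven non-negative
assert ana.simplify_max0("k") == ("max", "k", 0)  # unknown: keep max
```

Simplifying only the body would drop the shape context, so such rewrites would be missed.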

Contributor Author

@LeiWang1999 LeiWang1999 Feb 14, 2024

As you mentioned in #13299 (comment), perhaps it's better to simplify at the prim_func level. I chose to implement a stmt simplifier because it may be more useful; the rationale is that a stmt is more fine-grained.

Moreover, in the context of a tensor desc in a tensorize schedule, the prim_func typically contains a single block without dynamic symbolic shapes, so I think a stmt simplifier is enough for this issue. But we can implement a prim_func one as well; should we keep both the stmt and primfunc simplifiers, or maintain only one of them?

Contributor

> As you mentioned in #13299 (comment), perhaps it's better to simplify at the prim_func level. I chose to implement a stmt simplifier because it may be more useful; the rationale is that a stmt is more fine-grained.

Good point. Thinking on it again in the morning, I think we should avoid having a simplify function for tir::Stmt altogether, precisely because it is more fine-grained: its existence would encourage simplifications to be performed on specific statements, even though those statements might not be the outermost.

> But we can implement a prim_func one as well; should we keep both the stmt and primfunc simplifiers, or maintain only one of them?

I think having a simplifier for a PrimFunc would be better, because it encourages developers to simplify with the full context of a statement. The functionality already exists here, and would just need a wrapper function to expose StmtSimplifier::Apply.

Contributor Author

Hi @Lunderberg, I think this PR is ready for review.


tqchen commented Mar 4, 2024

@vinx13 @spectrometerHBH do you mind taking a look at this PR?

@vinx13 vinx13 merged commit 7b7677f into apache:main Mar 8, 2024
Lunderberg pushed a commit to Lunderberg/tvm that referenced this pull request Mar 12, 2024
* support tensorize with simplified and call expr

* replace stmt simplifier with primfunc simplifier

* lint fix

* lint:remove white space

* lint: remove white space

* cpp lint fix

* lint: resolve include

* clang format lint fix
thaisacs pushed a commit to thaisacs/tvm that referenced this pull request Apr 3, 2024
LeiWang1999 pushed a commit to LeiWang1999/tvm that referenced this pull request May 14, 2024