[BACKEND] Fix divisibility analysis for shift ops#4221
[BACKEND] Fix divisibility analysis for shift ops#4221ThomasRaoux merged 3 commits intotriton-lang:mainfrom
Conversation
87cd85f to
1dd1860
Compare
| auto shift = rhs.getConstantValue().has_value() | ||
| ? rhs.getConstantValue().value() | ||
| : rhs.getDivisibility(dim); | ||
| if (!rhs.getConstantValue().has_value()) |
There was a problem hiding this comment.
I agree. Any value that's not a power of 2 was wrong...
There was a problem hiding this comment.
I think this is similar to the discussion above but doing 1 << max(log_div_lhs - log_div_rhs, 0) at the end
There was a problem hiding this comment.
It's a bit different because rhs can be larger than the divisibility of rhs.
We want to get a pessimistic value of the possible divisibility
| auto shift = rhs.getConstantValue().has_value() | ||
| ? rhs.getConstantValue().value() | ||
| : rhs.getDivisibility(dim); | ||
| auto shift = rhs.getConstantValue().value_or(0); |
There was a problem hiding this comment.
As I mentioned, I'm worried about its a bit too conservative since rhs is usually not 0.
It's indeed tricky...
There was a problem hiding this comment.
For example, if you have %a << %b, where both a and b are unknown values but with divisibility 4 and 4, then we could estimate the divisibility of the result as 4 << 4.
There was a problem hiding this comment.
FYI, 0's divisibility is usually the maximum possible power of 2 integer of its type. Maybe we could do special handling here.
There was a problem hiding this comment.
For example, if you have %a << %b, where both a and b are unknown values but with divisibility 4 and 4, then we could estimate the divisibility of the result as 4 << 4.
I don't understand how that is correct. if %b is 0 then this is not true?
FYI, 0's divisibility is usually the maximum possible power of 2 integer of its type. Maybe we could do special handling here.
I think we would have to fix a bunch of other cases. For instance when we calculate the divisibility of an unknown %a in %a << 1 we will assume divisibility by 2 however the value may still be 0.
There was a problem hiding this comment.
As I mentioned, I'm worried about its a bit too conservative since rhs is usually not 0.
I agree, I couldn't find a better way that would always be correct
There was a problem hiding this comment.
If I've understood well the definition of divisibility, you want to compute the largest power of two that divides a << b.
If there's no overflow, this is the same as a * 2^b.
int64_t log_div_lhs;
if (lhs.getConstantValue().has_value()) {
auto lhs_v = *lhs.getConstantValue();
// Or implement ctz manually or smth
log_div_lhs = lhs_v ? __builtin_ctz(lhs_v) : 64;
} else {
log_div_lhs = log2Int(lhs.getDivisibility(dim));
}
int64_t log_div_rhs;
if (rhs.getConstantValue().has_value()) {
log_div_rhs = *rhs.getConstantValue();
} else {
log_div_rhs = rhs.getDivisibility(dim);
}
auto shift = int64_t{1} << max(log_div_lhs + log_div_rhs, 64);or something along this lines (not tested). If the result does not fit on 64 bits... then I guess this still gives you a lower bound I think, as we are just counting trailing zeros and then doing 2 ** # zeros.
There was a problem hiding this comment.
If I've understood well the definition of divisibility, you want to compute the largest power of two that divides a << b.
Correct
If there's no overflow, this is the same as a * 2^b.
yes
log_div_lhs = log2Int(lhs.getDivisibility(dim));
this assumes that lhs >= lhs.getDivisibility(dim) which I think is not always true as lhs can be 0.
There was a problem hiding this comment.
In that case, if lhs == 0, then I chose 1 << 64 as its divisibility, as @Jokeren proposed in the branch above that. If we have that, then this code would still be correct... but I guess we don't have it?
Edit. After talking with Thomas, the issue here is not the line pointeda bove, but
log_div_rhs = rhs.getDivisibility(dim);This breaks in the following example:
%a = call foo() : i32 // -> we know nothing divisibility = 1
%b = mul %a, 2 // -> divisibility = 2
%c = shl %d, %b
Here the divisibility of %c is 1 if %a is zero and 2 otherwise.
There was a problem hiding this comment.
if lhs == 0 then lhsDivisibility should already be MAX_INT, so I believe this case should already work.
1a4546b to
a7c1e4e
Compare
Divisibility does not ensure that a value is not 0 therefore we cannot use divisibility as a minimum shifted values.
Update Update Update Update Add a more meaningful check to make sure we are not merging blocks (#4186) This is a follow-up to #4176 (comment) I am now counting the number of blocks with (17) and without (31) block merging. I double checked to make sure this does not pass when we use an aggressive region simplification strategy. [AMD] Skip mfma layout in maybeDuplicate (#4170) The workaround introduced in #4048 "forgot" to skip mfma layout. [TEST] Merge duplicate `max_num_imprecise_acc` tests and improve code (#4191) [DOCS][NFC] Fix doc formatting problems (#4195) 1. f-string cannot be used as docstrings in Python. 2. URLs should follow the reStructuredText format. 3. Code snippets in a code block should be indented. Tested and passed on a local machine. [BACKEND] Fix regression in pipeliner pre-checks. (#4196) During some previous refactoring we changed the logic and started pipeling cases that had incompatible shared encoding. This was missed because one of the lit test had not been updated :( Remove tl.multiple_of call from tma persistent kernel (#4198) [AMD] Guard against null in `BypassEpilogueSMEM` (#4203) `val.getDefiningOp()` can return `nullptr`. In this case, we must fail the `BypassEpilogueSMEM` rewrite pass for the given op. This prevents run-time crashes. [FRONTEND][NFC] Fix type checking, conditional logic, and loop structures for improved readability and performance (#4208) Document TRITON_HOME (#4210) Document the existence of `TRITON_HOME` environment variable. The `TRITON_HOME` variable controls the location of the `.triton` directory that stores, among other things, the files downloaded during a `pip install -e python` virtualenv build. By default, this is located in the user's home directory, at `~/.triton`. I was trying to build Triton on my system on a large local disk, but with limited network home directory space, and the `pip` command kept failing with out of disk space errors. It turned out that during installation, large files were downloaded to the `~/.triton` directory causing failure. After checking that it was not `pip` doing this, I found the `TRITON_HOME` variable which allowed me to workaround the issue and build Triton successfully. After seconding #4007, I decided to contribute this documentation fix. Co-authored-by: sree <sree@buckyball> [BACKEND] Fix regression in i1 reduction (#4215) Recent refactoring broke i1 shared memory load. [BUILD] update URL for LLVM tarballs (#4216) [BACKEND] Fix divisibility analysis for shift ops (#4221) Divisibility does not ensure that a value is not 0 therefore we cannot use divisibility as a minimum shifted values. Support FP8 constant (#4222) To unblock the compilation of kernels like below which don't operate arithmetically on FP8. ``` @triton.jit def triton_poi_fused__scaled_mm__to_copy_constant_pad_nd_lift_fresh_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 400624 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex % 784 x1 = (xindex // 784) x2 = xindex tmp0 = x0 tmp1 = tl.full([1], 769, tl.int64) tmp2 = tmp0 < tmp1 tmp3 = tl.load(in_ptr0 + (x0 + (769*x1)), tmp2 & xmask, other=0.0) tmp4 = tmp3.to(tl.float8e4nv) tmp5 = tl.full(tmp4.shape, 0.0, tmp4.dtype) tmp6 = tl.where(tmp2, tmp4, tmp5) tl.store(out_ptr0 + (x2), tmp6, xmask) ``` [INTERPRETER] Implement implicit tensor conversion for assignment operators (#4214) Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update Update
Divisibility does not ensure that a value is not 0 therefore we cannot use divisibility as a minimum shifted values.
Divisibility does not ensure that a value is not 0 therefore we cannot use divisibility as a minimum shifted values.