Skip to content

Comments

[BACKEND] Fix divisibility analysis for shift ops#4221

Merged
ThomasRaoux merged 3 commits intotriton-lang:mainfrom
ThomasRaoux:div_analysis
Jun 27, 2024
Merged

[BACKEND] Fix divisibility analysis for shift ops#4221
ThomasRaoux merged 3 commits intotriton-lang:mainfrom
ThomasRaoux:div_analysis

Conversation

@ThomasRaoux
Copy link
Collaborator

@ThomasRaoux ThomasRaoux commented Jun 27, 2024

Divisibility does not ensure that a value is not 0 therefore we cannot use divisibility as a minimum shifted values.

@ThomasRaoux ThomasRaoux changed the title Fix divisibility analysis for shift ops [BACKEND] Fix divisibility analysis for shift ops Jun 27, 2024
@ThomasRaoux ThomasRaoux requested a review from Jokeren June 27, 2024 06:57
@ThomasRaoux ThomasRaoux marked this pull request as ready for review June 27, 2024 06:57
@ThomasRaoux ThomasRaoux requested a review from ptillet as a code owner June 27, 2024 06:57
auto shift = rhs.getConstantValue().has_value()
? rhs.getConstantValue().value()
: rhs.getDivisibility(dim);
if (!rhs.getConstantValue().has_value())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Any value that's not a power of 2 was wrong...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is similar to the discussion above but doing 1 << max(log_div_lhs - log_div_rhs, 0) at the end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit different because rhs can be larger than the divisibility of rhs.
We want to get a pessimistic value of the possible divisibility

auto shift = rhs.getConstantValue().has_value()
? rhs.getConstantValue().value()
: rhs.getDivisibility(dim);
auto shift = rhs.getConstantValue().value_or(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I mentioned, I'm worried about its a bit too conservative since rhs is usually not 0.
It's indeed tricky...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, if you have %a << %b, where both a and b are unknown values but with divisibility 4 and 4, then we could estimate the divisibility of the result as 4 << 4.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, 0's divisibility is usually the maximum possible power of 2 integer of its type. Maybe we could do special handling here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, if you have %a << %b, where both a and b are unknown values but with divisibility 4 and 4, then we could estimate the divisibility of the result as 4 << 4.

I don't understand how that is correct. if %b is 0 then this is not true?

FYI, 0's divisibility is usually the maximum possible power of 2 integer of its type. Maybe we could do special handling here.

I think we would have to fix a bunch of other cases. For instance when we calculate the divisibility of an unknown %a in %a << 1 we will assume divisibility by 2 however the value may still be 0.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I mentioned, I'm worried about its a bit too conservative since rhs is usually not 0.

I agree, I couldn't find a better way that would always be correct

Copy link
Contributor

@lezcano lezcano Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I've understood well the definition of divisibility, you want to compute the largest power of two that divides a << b.
If there's no overflow, this is the same as a * 2^b.

int64_t log_div_lhs;
if (lhs.getConstantValue().has_value()) {
    auto lhs_v = *lhs.getConstantValue();
    // Or implement ctz manually or smth
    log_div_lhs = lhs_v ? __builtin_ctz(lhs_v) : 64;
} else {
    log_div_lhs = log2Int(lhs.getDivisibility(dim));
}
int64_t log_div_rhs;
if (rhs.getConstantValue().has_value()) {
    log_div_rhs = *rhs.getConstantValue();
} else {
    log_div_rhs = rhs.getDivisibility(dim);
}
auto shift = int64_t{1} << max(log_div_lhs + log_div_rhs, 64);

or something along this lines (not tested). If the result does not fit on 64 bits... then I guess this still gives you a lower bound I think, as we are just counting trailing zeros and then doing 2 ** # zeros.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I've understood well the definition of divisibility, you want to compute the largest power of two that divides a << b.

Correct

If there's no overflow, this is the same as a * 2^b.

yes

log_div_lhs = log2Int(lhs.getDivisibility(dim));

this assumes that lhs >= lhs.getDivisibility(dim) which I think is not always true as lhs can be 0.

Copy link
Contributor

@lezcano lezcano Jun 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, if lhs == 0, then I chose 1 << 64 as its divisibility, as @Jokeren proposed in the branch above that. If we have that, then this code would still be correct... but I guess we don't have it?

Edit. After talking with Thomas, the issue here is not the line pointeda bove, but

    log_div_rhs = rhs.getDivisibility(dim);

This breaks in the following example:

%a = call foo() : i32 // -> we know nothing divisibility = 1
%b = mul %a, 2 // ->  divisibility = 2
%c = shl %d, %b 

Here the divisibility of %c is 1 if %a is zero and 2 otherwise.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if lhs == 0 then lhsDivisibility should already be MAX_INT, so I believe this case should already work.

@ThomasRaoux ThomasRaoux merged commit ab7b89b into triton-lang:main Jun 27, 2024
Jokeren pushed a commit that referenced this pull request Jul 1, 2024
Divisibility does not ensure that a value is not 0 therefore we cannot
use divisibility as a minimum shifted values.
Jokeren added a commit that referenced this pull request Jul 3, 2024
Update

Update

Update

Update

Add a more meaningful check to make sure we are not merging blocks (#4186)

This is a follow-up to
#4176 (comment)

I am now counting the number of blocks with (17) and without (31) block
merging. I double checked to make sure this does not pass when we use an
aggressive region simplification strategy.

[AMD] Skip mfma layout in maybeDuplicate (#4170)

The workaround introduced in
#4048 "forgot" to skip mfma
layout.

[TEST] Merge duplicate `max_num_imprecise_acc` tests and improve code (#4191)

[DOCS][NFC] Fix doc formatting problems (#4195)

1. f-string cannot be used as docstrings in Python.
2. URLs should follow the reStructuredText format.
3. Code snippets in a code block should be indented.

Tested and passed on a local machine.

[BACKEND] Fix regression in pipeliner pre-checks. (#4196)

During some previous refactoring we changed the logic and started
pipeling cases that had incompatible shared encoding. This was missed
because one of the lit test had not been updated :(

Remove tl.multiple_of call from tma persistent kernel (#4198)

[AMD] Guard against null in `BypassEpilogueSMEM` (#4203)

`val.getDefiningOp()` can return `nullptr`. In this case, we must fail
the `BypassEpilogueSMEM` rewrite pass for the given op. This prevents
run-time crashes.

[FRONTEND][NFC] Fix type checking, conditional logic, and loop structures for improved readability and performance (#4208)

Document TRITON_HOME (#4210)

Document the existence of `TRITON_HOME` environment variable.

The `TRITON_HOME` variable controls the location of the `.triton`
directory that stores, among other things, the files downloaded during a
`pip install -e python` virtualenv build. By default, this is located in
the user's home directory, at `~/.triton`.

I was trying to build Triton on my system on a large local disk, but
with limited network home directory space, and the `pip` command kept
failing with out of disk space errors. It turned out that during
installation, large files were downloaded to the `~/.triton` directory
causing failure.

After checking that it was not `pip` doing this, I found the
`TRITON_HOME` variable which allowed me to workaround the issue and
build Triton successfully. After seconding #4007, I decided to
contribute this documentation fix.

Co-authored-by: sree <sree@buckyball>

[BACKEND] Fix regression in i1 reduction (#4215)

Recent refactoring broke i1 shared memory load.

[BUILD] update URL for LLVM tarballs (#4216)

[BACKEND] Fix divisibility analysis for shift ops (#4221)

Divisibility does not ensure that a value is not 0 therefore we cannot
use divisibility as a minimum shifted values.

Support FP8 constant (#4222)

To unblock the compilation of kernels like below which don't operate
arithmetically on FP8.

```
@triton.jit
def triton_poi_fused__scaled_mm__to_copy_constant_pad_nd_lift_fresh_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 400624
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 784
    x1 = (xindex // 784)
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 769, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = tl.load(in_ptr0 + (x0 + (769*x1)), tmp2 & xmask, other=0.0)
    tmp4 = tmp3.to(tl.float8e4nv)
    tmp5 = tl.full(tmp4.shape, 0.0, tmp4.dtype)
    tmp6 = tl.where(tmp2, tmp4, tmp5)
    tl.store(out_ptr0 + (x2), tmp6, xmask)
```

[INTERPRETER] Implement implicit tensor conversion for assignment operators (#4214)

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update

Update
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
Divisibility does not ensure that a value is not 0 therefore we cannot
use divisibility as a minimum shifted values.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants