[FRONTEND] Fix expand_dims and tl.full to handle scalar tensors #2275

Merged
ptillet merged 4 commits into triton-lang:main from peterbell10:triton-scalars on Sep 12, 2023

@peterbell10
Contributor

This fixes a few bugs related to scalar tensors:

  • tl.full([], fill_value, dtype) fails with TypeError('0d block_type is forbidden')
  • scalar[None] fails with TypeError("'constexpr' object is not iterable")
  • scalar[None, None] fails with AttributeError("'dtype' object has no attribute 'shape'")
  • scalar.shape returns [1] instead of 0-dim []
    • Also related, tl.zeros_like(scalar) returns a 1d tensor instead of another scalar

Currently `tl.full([], 0, tl.int32)` fails with:
```
TypeError('0d block_type is forbidden')
```
This fixes it to create a tensor with scalar type.

Currently calling:
```
tl.sum(x)[None]
```
results in
```
TypeError("'constexpr' object is not iterable")
```
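
The first error is consistent with how Python passes subscripts: `x[None]` hands `__getitem__` a single object, while `x[None, None]` hands it a tuple. The sketch below is a plain-Python illustration, not Triton's code. `ShapeOnly` and its `__getitem__` are hypothetical stand-ins that track only a shape, and wrapping a lone index in a tuple is an assumed remedy rather than the change made in this PR.

```python
class ShapeOnly:
    """Hypothetical stand-in for a tensor that records only its shape."""

    def __init__(self, shape=()):
        self.shape = list(shape)

    def __getitem__(self, index):
        # x[None] passes the bare object None; x[None, None] passes a tuple.
        # Wrap a lone index so the loop below never iterates a non-iterable.
        if not isinstance(index, tuple):
            index = (index,)
        new_shape = list(self.shape)
        for axis, item in enumerate(index):
            if item is None:
                # None inserts a new axis of size 1 at this position.
                new_shape.insert(axis, 1)
            elif item != slice(None):
                raise ValueError(f"unsupported index: {item!r}")
        return ShapeOnly(new_shape)


scalar = ShapeOnly()              # 0-dim, shape []
print(scalar[None].shape)         # [1]
print(scalar[None, None].shape)   # [1, 1]
```

Both forms of indexing go through the same loop once the index is normalized, which is why the scalar and multi-axis cases can share one code path.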

Furthermore, once that is fixed, it still fails in `semantic.expand_dims` with
```
AttributeError("'dtype' object has no attribute 'shape'")
```
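
The second error suggests that the type attached to a scalar is a plain element type with no `shape`, unlike a block type. Below is a minimal plain-Python sketch of that distinction, under the assumption that a scalar should simply report an empty shape. `ElemType`, `BlockType`, `shape_of`, and this `expand_dims` are hypothetical names used for illustration and are not Triton's implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ElemType:
    """Hypothetical scalar element type: has a name but no shape."""
    name: str


@dataclass(frozen=True)
class BlockType:
    """Hypothetical block type: an element type plus a shape."""
    elem: ElemType
    shape: tuple


def shape_of(ty):
    # Treat a scalar type as 0-dim instead of reading a missing attribute.
    return list(ty.shape) if isinstance(ty, BlockType) else []


def expand_dims(ty, axis):
    # Insert a size-1 axis; works for both scalar and block inputs.
    shape = shape_of(ty)
    shape.insert(axis, 1)
    elem = ty.elem if isinstance(ty, BlockType) else ty
    return BlockType(elem, tuple(shape))


int32 = ElemType("int32")
print(shape_of(int32))                               # []
print(expand_dims(int32, 0).shape)                   # (1,)
print(expand_dims(expand_dims(int32, 0), 0).shape)   # (1, 1)
```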
# test broadcast
# ---------------
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16)
def test_broadcast(dtype, device):
Contributor Author

Unrelated change, but this test overwrites the test above, and it looks like it might have been a bad merge conflict resolution. I've removed the shadowed test, which is identical other than hard-coding device="cuda".
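
The shadowing described here is ordinary Python behavior: a second `def` with the same name rebinds the module-level name, so a test collector that looks names up in the module only finds the later function. A small illustration with a hypothetical test name and bodies (not the actual tests from the repository):

```python
def test_example():
    return "first definition"


def test_example():  # same name: rebinds test_example, hiding the one above
    return "second definition"


# Only the later definition is reachable through the module namespace.
print(test_example())  # second definition
```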

@peterbell10
Contributor Author

Looks like tl.multiple_of(scalar, 8) is failing because it expects a list of hints, one for each rank of the tensor. I don't think that really makes sense; the divisibility hint should probably be a scalar value. Does that sound right?

@peterbell10
Contributor Author

Actually, there is special handling of SplatOp in the axis analysis that expects a single element, so it's okay for multiple_of to get more values than the tensor's rank here.

@ptillet ptillet merged commit ab9da3b into triton-lang:main Sep 12, 2023
@peterbell10 peterbell10 deleted the triton-scalars branch September 14, 2023 02:04
alexander-zinoviev pushed a commit to alexander-zinoviev/triton that referenced this pull request Sep 21, 2023
…on-lang#2275)

pingzhuu pushed a commit to siliconflow/triton that referenced this pull request Apr 2, 2024
…on-lang#2275)
