Skip to content

Conversation

@int3
Copy link
Contributor

@int3 int3 commented Aug 19, 2024

This is motivated by #4509. The crux of the problem is that the Triton code generator needs to inspect a function's arguments / attributes / types in order to determine how it should be called. This meant that "implementation details" like whether a function is a builtin needed to be exposed in the "interface" tl.extra.libdevice module, instead of just residing in tl.extra.cuda.libdevice. Moreover, this meant that libdevice functions marked as @core.extern in the interface could not be implemented via JitFunctions.

Allowing each backend to provide its own module map solves this problem as the code generator can inspect the actual function implementation.

@int3
Copy link
Contributor Author

int3 commented Aug 19, 2024

I initially tried to tackle #4509 with a purely "userspace" solution by changing @dispatch from a runtime-rebinding decorator to a declaration-time-rebinding one. But this wasn't possible because we don't know which backend will be used at declaration time. I then drafted https://github.com/int3/triton-cpu/tree/resolve which allows for user-defined dynamic symbol resolvers. Ultimately I settled on this PR because it seems simpler and doesn't actually change the semantics of the language.

@int3
Copy link
Contributor Author

int3 commented Aug 19, 2024

fixed test

This is motivated by triton-lang#4509. The crux of the problem is that the Triton
code generator needs to inspect a function's arguments / attributes /
types in order to determine how it should be called. This meant that
"implementation details" like whether a function is a builtin needed to
be exposed in the "interface" `tl.extra.libdevice` module, instead of
just residing in `tl.extra.cuda.libdevice`. Moreover, this meant that
libdevice functions marked as @core.extern in the interface could not be
implemented via JitFunctions.

Allowing each backend to provide its own module map solves this problem
as the code generator can inspect the actual function implementation.
@ThomasRaoux ThomasRaoux merged commit 2ea4890 into triton-lang:main Aug 22, 2024
Jokeren pushed a commit that referenced this pull request Aug 24, 2024
This is motivated by #4509. The crux of the problem is that the Triton
code generator needs to inspect a function's arguments / attributes /
types in order to determine how it should be called. This meant that
"implementation details" like whether a function is a builtin needed to
be exposed in the "interface" `tl.extra.libdevice` module, instead of
just residing in `tl.extra.cuda.libdevice`. Moreover, this meant that
libdevice functions marked as `@core.extern` in the interface could not
be implemented via JitFunctions.

Allowing each backend to provide its own module map solves this problem
as the code generator can inspect the actual function implementation.
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Aug 31, 2024
…#134774)

In triton-lang/triton#4539 the `make_ir` API was modified to accept a new `module_map` parameter. Update the Inductor callsite accordingly, preserving backwards compatibility following the existing code.

Fixes #134674

Pull Request resolved: #134774
Approved by: https://github.com/EikanWang, https://github.com/zou3519, https://github.com/jansel
tolleybot pushed a commit to tolleybot/pytorch that referenced this pull request Sep 14, 2024
…pytorch#134774)

In triton-lang/triton#4539 the `make_ir` API was modified to accept a new `module_map` parameter. Update the Inductor callsite accordingly, preserving backwards compatibility following the existing code.

Fixes pytorch#134674

Pull Request resolved: pytorch#134774
Approved by: https://github.com/EikanWang, https://github.com/zou3519, https://github.com/jansel
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
…pytorch#134774)

In triton-lang/triton#4539 the `make_ir` API was modified to accept a new `module_map` parameter. Update the Inductor callsite accordingly, preserving backwards compatibility following the existing code.

Fixes pytorch#134674

Pull Request resolved: pytorch#134774
Approved by: https://github.com/EikanWang, https://github.com/zou3519, https://github.com/jansel
davidberard98 added a commit to davidberard98/triton that referenced this pull request Nov 6, 2024
Context: In `CodeGenerator.__init__`, globals for a given triton function are modified to handle remapping the libdevice module to cuda or hip (from triton-lang#4539). In particular, this logic:

```python
for k, v in gscope.items():  # gscope is a dict of fn.__globals__
  ...
  self.gscope[k] = getattr(module_map[module_name], k)
```

was failing if you do this in the global scope: `from triton.language.extras.libdevice import fast_dividef as my_fast_dividef`.
peterbell10 pushed a commit that referenced this pull request Nov 6, 2024
…5081)

Context: in `CodeGenerator.__init__`, globals for a given triton
function are modified to handle remapping the libdevice module to cuda
or hip (from #4539). In
particular, this logic:

```python
for k, v in gscope.items():  # gscope is a dict of fn.__globals__
  ...
  self.gscope[k] = getattr(module_map[module_name], k)
```

was failing if you do this in the global scope: `from
triton.language.extras.libdevice import fast_dividef as
my_fast_dividef`.
bertmaher pushed a commit that referenced this pull request Nov 6, 2024
…5081)

Context: in `CodeGenerator.__init__`, globals for a given triton
function are modified to handle remapping the libdevice module to cuda
or hip (from #4539). In
particular, this logic:

```python
for k, v in gscope.items():  # gscope is a dict of fn.__globals__
  ...
  self.gscope[k] = getattr(module_map[module_name], k)
```

was failing if you do this in the global scope: `from
triton.language.extras.libdevice import fast_dividef as
my_fast_dividef`.
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
…riton-lang#5081)

Context: in `CodeGenerator.__init__`, globals for a given triton
function are modified to handle remapping the libdevice module to cuda
or hip (from triton-lang#4539). In
particular, this logic:

```python
for k, v in gscope.items():  # gscope is a dict of fn.__globals__
  ...
  self.gscope[k] = getattr(module_map[module_name], k)
```

was failing if you do this in the global scope: `from
triton.language.extras.libdevice import fast_dividef as
my_fast_dividef`.
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
…ng#4539)

This is motivated by triton-lang#4509. The crux of the problem is that the Triton
code generator needs to inspect a function's arguments / attributes /
types in order to determine how it should be called. This meant that
"implementation details" like whether a function is a builtin needed to
be exposed in the "interface" `tl.extra.libdevice` module, instead of
just residing in `tl.extra.cuda.libdevice`. Moreover, this meant that
libdevice functions marked as `@core.extern` in the interface could not
be implemented via JitFunctions.

Allowing each backend to provide its own module map solves this problem
as the code generator can inspect the actual function implementation.
liuyunqi20 pushed a commit to flagos-ai/FlagTree that referenced this pull request Oct 21, 2025
…5081)

Context: in `CodeGenerator.__init__`, globals for a given triton
function are modified to handle remapping the libdevice module to cuda
or hip (from triton-lang/triton#4539). In
particular, this logic:

```python
for k, v in gscope.items():  # gscope is a dict of fn.__globals__
  ...
  self.gscope[k] = getattr(module_map[module_name], k)
```

was failing if you do this in the global scope: `from
triton.language.extras.libdevice import fast_dividef as
my_fast_dividef`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants