[BACKEND] Improve printf. #2532

Merged: jlebar merged 1 commit into main from jlebar/improve-printf on Oct 25, 2023

jlebar commented Oct 23, 2023

[BACKEND] Improve printf.

Previously, we printed all of a GPU thread's values in a single printf()
call, and this, plus the user-specified prefix, was all we printed.

This caused a few problems.

  • nvptx printf can only handle 32 arguments; if you pass more than
    that, it prints garbage. So if a thread had more than 32 values, you
    couldn't print them (issue #2486, "issues with print function").

  • The order of the values within the Triton program (GPU thread block)
    is an implementation detail -- it depends on the layout the compiler
    assigns to a tensor. So this also prevented you from interpreting
    the printed output.
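
As a rough illustration of the second point, the sketch below uses two toy layouts to show how the same tensor can be split across threads in different orders. The functions `blocked_layout` and `strided_layout` are hypothetical and far simpler than the layouts the Triton compiler actually assigns; they only show why a thread's values, printed in storage order, say little about tensor order.

```python
def blocked_layout(num_elems, num_threads):
    """Toy layout (an assumption): each thread owns one contiguous chunk."""
    per_thread = num_elems // num_threads
    return {
        t: list(range(t * per_thread, (t + 1) * per_thread))
        for t in range(num_threads)
    }


def strided_layout(num_elems, num_threads):
    """Toy layout (an assumption): elements are dealt round-robin to threads."""
    return {t: list(range(t, num_elems, num_threads)) for t in range(num_threads)}


# Same 8-element tensor, 2 threads, two different per-thread orders.
blocked = blocked_layout(8, 2)
strided = strided_layout(8, 2)

for name, layout in (("blocked", blocked), ("strided", strided)):
    for thread, indices in layout.items():
        print(f"{name}: thread {thread} holds tensor indices {indices}")
```

Without the index printed next to each value, a reader cannot tell which of these orders produced a given line of output.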

To address this, we now print the Triton pid and multi-dimensional
tensor index for each value. Each value also gets its own line, which
avoids passing too many args to printf.

Example output:

```
pid (0, 1, 2) idx (36, 127) x: 42
```
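
To make the line format concrete, here is a minimal host-side sketch that builds lines of the same shape as the example above. It is plain Python, not Triton's implementation; `format_print_lines` is a hypothetical helper, and it assumes the values are stored in row-major order.

```python
from itertools import product


def format_print_lines(pid, shape, values, label):
    """Return one 'pid (...) idx (...) label: value' line per tensor element.

    Hypothetical helper for illustration; assumes `values` is row-major.
    """
    pid_str = ", ".join(str(p) for p in pid)
    lines = []
    for idx in product(*(range(n) for n in shape)):
        # Flatten the multi-dimensional index to look up the value.
        flat = 0
        for i, n in zip(idx, shape):
            flat = flat * n + i
        idx_str = ", ".join(str(i) for i in idx)
        lines.append(f"pid ({pid_str}) idx ({idx_str}) {label}: {values[flat]}")
    return lines


lines = format_print_lines(pid=(0, 1, 2), shape=(2, 2), values=[10, 11, 12, 13], label="x")
for line in lines:
    print(line)
```

On a real GPU the lines arrive in whatever order the threads happen to print them; this sketch emits them in index order only because it runs sequentially.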

If you want to observe all the values in a tensor in order, you can grep
and then sort the output.
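
The grep-and-sort step could also be done in a few lines of Python. The sketch below assumes the output matches the example format exactly (integer pid and idx components, a label, then the value); `parse_line` and `sort_printed` are hypothetical names, not part of Triton.

```python
import re

# Matches lines such as "pid (0, 1, 2) idx (36, 127) x: 42".
LINE_RE = re.compile(r"pid \(([\d, ]+)\) idx \(([\d, ]+)\) (.*): (.*)")


def parse_line(line):
    """Return (pid, idx, label, value_text), or None for unrelated lines."""
    m = LINE_RE.match(line)
    if m is None:
        return None
    pid = tuple(int(p) for p in m.group(1).split(","))
    idx = tuple(int(i) for i in m.group(2).split(","))
    return pid, idx, m.group(3), m.group(4)


def sort_printed(lines, label):
    """Keep lines for `label` (the grep step), then sort by pid and idx."""
    parsed = [p for p in map(parse_line, lines) if p is not None and p[2] == label]
    return sorted(parsed, key=lambda p: (p[0], p[1]))


raw = [
    "pid (0, 0, 0) idx (1, 0) x: 3",
    "pid (0, 0, 0) idx (0, 1) x: 2",
    "unrelated log line",
    "pid (0, 0, 0) idx (0, 0) x: 1",
    "pid (0, 0, 0) idx (0, 0) y: 9",
]
ordered = sort_printed(raw, "x")
for pid, idx, _, value in ordered:
    print(pid, idx, value)
```

Sorting on integer tuples rather than on the raw text avoids the usual pitfall where a plain textual sort places index 10 before index 2.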

We also make a UX enhancement to print: The printed label always ends
with ": "; you don't have to add it yourself.
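
One way such a label rule could be written is sketched below. `normalize_label` is a hypothetical helper, and its handling of labels that already end in a colon is an assumption; the PR description only states that the printed label always ends with ": ".

```python
def normalize_label(label):
    """Hypothetical helper: return `label` ending in ': ' exactly once."""
    stripped = label.rstrip()
    # Assumption: a trailing colon supplied by the caller is not duplicated.
    if stripped.endswith(":"):
        stripped = stripped[:-1].rstrip()
    return stripped + ": "


for label in ("x", "x:", "x: "):
    print(repr(normalize_label(label)))
```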

Fixes #2486.

PR chain

  1. 👉 [BACKEND] Improve printf. #2532 👈 YOU ARE HERE

// Cache the llvm Value for each format string to avoid making duplicates in
// the module. Sorry about `mutable`, but the top-level matchAndRewrite
// function is const, so it virally comes down here.
mutable llvm::StringMap<Value> formatStrCache;

Caching will break if the new print call is not dominated by the value that has been cached. For instance, it will always fail if we are printing from different functions. In general, caching during rewrite patterns is not trivial. We could cache only the global value, but this is probably a premature optimization.

addStringToModule adds the string as a global, so shouldn't this work? https://github.com/openai/triton/blob/020f43d5a37c2f1cce5c70fcda17ec57b000755c/lib/Conversion/TritonGPUToLLVM/Utility.cpp#L354

(The reason I wanted to cache this is otherwise we get BLOCK_SIZE copies of the format-string. Or otherwise we have to rewrite it so we generate the format-string only once per block, but that's also bug-prone and harder to read.)

But it creates operations that live inside a block:

Value stringStart =
    rewriter.create<LLVM::GEPOp>(UnknownLoc::get(ctx), ptr_ty(i8_ty),
                                 globalPtr, SmallVector<Value>({zero, zero}));
return stringStart;

If those are used by an instruction that isn't dominated by the cached ops, it would break SSA. Can you change the code so that this gets called outside the loop?

Indeed you're completely correct. Done, thanks.
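
The dominance concern in this thread can be illustrated with a toy model. The sketch below is plain Python, not the MLIR API: `ToyModule`, `get_or_create_global`, `emit_string_ptr`, and `lower_print` are all hypothetical. It shares only the module-level global (the part the reviewer says is safe to cache) and builds the pointer op once per print, before the per-value loop, so every use sits in the same function as its definition.

```python
class ToyModule:
    """Toy stand-in for an IR module; not the MLIR API."""

    def __init__(self):
        self.globals = {}  # format string -> global symbol, shared module-wide
        self.ops = []      # (function name, op text) in emission order


def get_or_create_global(module, fmt):
    # Globals live at module scope, so sharing one across functions is safe.
    if fmt not in module.globals:
        module.globals[fmt] = f"@printf_fmt_{len(module.globals)}"
    return module.globals[fmt]


def emit_string_ptr(module, function, fmt):
    # The pointer op lives inside `function`, so a fresh one is built per call
    # instead of reusing an op that may belong to another function.
    global_sym = get_or_create_global(module, fmt)
    op = (function, f"gep {global_sym}")
    module.ops.append(op)
    return op


def lower_print(module, function, fmt, num_values):
    # Hoisted: one pointer op before the loop, reused by every printf below it.
    ptr = emit_string_ptr(module, function, fmt)
    for i in range(num_values):
        module.ops.append((function, f"printf {ptr[1]} value{i}"))


module = ToyModule()
lower_print(module, "kernel_a", "x: %d", num_values=3)
lower_print(module, "kernel_b", "x: %d", num_values=2)

gep_ops = [op for op in module.ops if op[1].startswith("gep")]
print(len(module.globals), "global,", len(gep_ops), "pointer ops")
```

In this model the format string is stored once per module and there is one pointer op per print, rather than one copy per printed value, which is the duplication the author wanted to avoid.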

Comment thread lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp Outdated

Jokeren commented Oct 23, 2023

Another comment regarding the output format. Is it possible to make the format consistent with the output of the assert which also annotates block ids and thread ids?

@Jokeren Jokeren changed the title Improve printf. [BACKEND] Improve printf. Oct 23, 2023

jlebar commented Oct 23, 2023

> Another comment regarding the output format. Is it possible to make the format consistent with the output of the assert which also annotates block ids and thread ids?

The assert message e.g.

/root/code/triton/python/test/unit/language/assert_helper.py:51: kernel_device_assert: block: [0,0,0], thread: [31,0,0] Assertion `x != 0` failed.

is actually coming from nvptx. The only part that we control is x != 0.

The block/thread-id here is CUDA, not Triton pid. AIUI Triton pid is only equivalent to the block id pre-Hopper; on Hopper, the Triton pid seems to be equal to the cluster ID.

Since we don't have control over the assertion message and since we can't get it to print out the actual Triton PID, I'm inclined not to try to match it. WDYT, @Jokeren ?

@jlebar jlebar changed the base branch from jlebar/dead-vprintf to main October 23, 2023 22:46
@jlebar jlebar force-pushed the jlebar/improve-printf branch from 5424cc4 to 9794ae7 Compare October 23, 2023 22:46
@jlebar jlebar changed the base branch from main to jlebar/asan-build October 23, 2023 22:46
@jlebar jlebar force-pushed the jlebar/asan-build branch from 8e47509 to 295550e Compare October 23, 2023 22:51
@jlebar jlebar force-pushed the jlebar/improve-printf branch 2 times, most recently from ad55ad0 to 28ae3e6 Compare October 23, 2023 23:40

Jokeren commented Oct 24, 2023

> The block/thread-id here is CUDA, not Triton pid. AIUI Triton pid is only equivalent to the block id pre-Hopper; on Hopper, the Triton pid seems to be equal to the cluster ID.

OK, that sounds reasonable.

Base automatically changed from jlebar/asan-build to jlebar/dead-vprintf October 24, 2023 02:42
@jlebar jlebar force-pushed the jlebar/dead-vprintf branch from f3c919c to 8b50cfd Compare October 24, 2023 03:07
@jlebar jlebar changed the base branch from jlebar/dead-vprintf to jlebar/asan-build October 24, 2023 03:07
@jlebar jlebar changed the base branch from jlebar/asan-build to main October 24, 2023 03:10
@jlebar jlebar force-pushed the jlebar/improve-printf branch from 28ae3e6 to 4844a4a Compare October 24, 2023 03:10
@jlebar jlebar changed the base branch from main to jlebar/dead-vprintf October 24, 2023 03:10

ptillet commented Oct 24, 2023

thanks @jlebar ! It is a great improvement to our printer :)

@jlebar jlebar requested review from ThomasRaoux and removed request for ptillet October 24, 2023 04:51
@jlebar jlebar force-pushed the jlebar/dead-vprintf branch from 05876d1 to 6308407 Compare October 24, 2023 05:09
@jlebar jlebar force-pushed the jlebar/improve-printf branch from 4844a4a to 529141a Compare October 24, 2023 05:09

ThomasRaoux left a comment

Looks great!

@jlebar jlebar force-pushed the jlebar/improve-printf branch from 529141a to 00d8de7 Compare October 24, 2023 17:30
@jlebar jlebar force-pushed the jlebar/dead-vprintf branch from 6308407 to 7fcaed7 Compare October 24, 2023 17:30
@jlebar jlebar force-pushed the jlebar/improve-printf branch from 00d8de7 to 75fa8a2 Compare October 24, 2023 19:16
@jlebar jlebar force-pushed the jlebar/dead-vprintf branch from 7fcaed7 to da2abfa Compare October 24, 2023 19:16
Base automatically changed from jlebar/dead-vprintf to main October 25, 2023 00:22
@jlebar jlebar force-pushed the jlebar/improve-printf branch from 75fa8a2 to 74564a5 Compare October 25, 2023 04:17
@jlebar jlebar enabled auto-merge (squash) October 25, 2023 04:17
@jlebar jlebar merged commit e70e11e into main Oct 25, 2023
@jlebar jlebar deleted the jlebar/improve-printf branch October 25, 2023 08:47
pingzhuu pushed a commit to siliconflow/triton that referenced this pull request Apr 2, 2024