[BACKEND] Improve printf. #2532
// Cache the llvm Value for each format string to avoid making duplicates in
// the module. Sorry about `mutable`, but the top-level matchAndRewrite
// function is const, so it virally comes down here.
mutable llvm::StringMap<Value> formatStrCache;
Caching will break if the new print function is not dominated by the value that has been cached. For instance, it will always fail if we are printing from different functions. In general, caching during rewrite patterns is not trivial. We could cache only the global value, but this is probably a premature optimization.
addStringToModule adds the string as a global, so shouldn't this work? https://github.com/openai/triton/blob/020f43d5a37c2f1cce5c70fcda17ec57b000755c/lib/Conversion/TritonGPUToLLVM/Utility.cpp#L354
(The reason I wanted to cache this is otherwise we get BLOCK_SIZE copies of the format-string. Or otherwise we have to rewrite it so we generate the format-string only once per block, but that's also bug-prone and harder to read.)
But it creates operations that are within a block:
    Value stringStart =
        rewriter.create<LLVM::GEPOp>(UnknownLoc::get(ctx), ptr_ty(i8_ty),
                                     globalPtr, SmallVector<Value>({zero, zero}));
    return stringStart;
If those are used by an instruction that isn't dominated by the cached ops, it would break SSA. Can you change the code so that this gets called outside the loop?
Indeed you're completely correct. Done, thanks.
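A minimal sketch of the restructuring requested above, written in plain C++ rather than MLIR: the format string is created once, before the per-element loop, and reused inside it. The names `makeFormatString`, `printElements`, and `formatStringsBuilt` are hypothetical and are not Triton functions. The sketch does not model SSA dominance; it only shows that hoisting avoids both the duplicate copies and the need for a cache.

```cpp
#include <cassert>
#include <cstdio>
#include <string>
#include <vector>

// Counts how many times a format string is materialized, so the effect of
// hoisting is observable. Hypothetical; not part of Triton.
static int formatStringsBuilt = 0;

// Hypothetical stand-in for the code that creates the format string.
// Assumes `label` contains no '%' characters.
std::string makeFormatString(const std::string &label) {
  ++formatStringsBuilt;
  return label + ": %d\n";
}

// Creates the format string once, before the loop, then reuses it for every
// element. Returns the text that would be printed.
std::string printElements(const std::string &label,
                          const std::vector<int> &values) {
  const std::string fmt = makeFormatString(label); // hoisted out of the loop
  std::string out;
  char buf[256];
  for (int v : values) {
    int n = std::snprintf(buf, sizeof(buf), fmt.c_str(), v);
    // Guard against formatting errors and truncation.
    assert(n >= 0 && static_cast<std::size_t>(n) < sizeof(buf));
    out += buf;
  }
  return out;
}
```

Calling `printElements("x", {1, 2, 3})` builds the format string a single time regardless of how many values are printed.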
Another comment regarding the output format. Is it possible to make the format consistent with the output of the assert, which also annotates block ids and thread ids?
The assert message e.g. is actually coming from nvptx. The only part that we control is The block/thread-id here is CUDA, not Triton pid. AIUI Triton pid is only equivalent to the block id pre-Hopper; on Hopper, the Triton pid seems to be equal to the cluster ID. Since we don't have control over the assertion message and since we can't get it to print out the actual Triton PID, I'm inclined not to try to match it. WDYT, @Jokeren?
OK, that sounds reasonable.
Thanks, @jlebar! It is a great improvement to our printer :)
Previously, we printed all of a GPU thread's values in a single printf() call, and this, plus the user-specified prefix, was all we printed.

This caused a few problems.

- nvptx printf can only handle 32 arguments; if you pass more than that, it prints garbage. So if a thread had more than 32 values, you couldn't print them, issue #2486.
- The order of the values within the Triton program (GPU thread block) is an implementation detail -- it depends on the layout the compiler assigns to a tensor. So this also prevented you from interpreting the printed output.

To address this, we now print the Triton pid and multi-dimensional Tensor index for each value. And each value gets its own line to avoid passing too many args to printf.

Example output:

```
pid (0, 1, 2) idx (36, 127) x: 42
```

If you want to observe all the values in a tensor in order, you can grep and then sort the output.

We also make a UX enhancement to print: The printed label always ends with ": "; you don't have to add it yourself.

Fixes #2486.

GPC: improve-printf
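To make the per-value line format and the "grep and then sort" step concrete, here is a small standalone C++ sketch. It is an illustration only, not the code in this PR: `Record`, `formatTuple`, `formatLine`, and `sortRecords` are hypothetical names, and the line shape is taken from the example output above.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <tuple>
#include <vector>

// One printed value: the Triton pid, the tensor index, and the value itself.
// Hypothetical type for illustration.
struct Record {
  std::vector<int> pid;
  std::vector<int> idx;
  int value;
};

// Renders a list of integers as "(a, b, c)".
std::string formatTuple(const std::vector<int> &v) {
  std::string s = "(";
  for (std::size_t i = 0; i < v.size(); ++i) {
    if (i != 0)
      s += ", ";
    s += std::to_string(v[i]);
  }
  return s + ")";
}

// Renders one line per value. The ": " after the label is appended here, so
// the caller passes only the bare label.
std::string formatLine(const Record &r, const std::string &label) {
  assert(!r.pid.empty() && !r.idx.empty());
  return "pid " + formatTuple(r.pid) + " idx " + formatTuple(r.idx) + " " +
         label + ": " + std::to_string(r.value);
}

// Orders records by pid, then by idx, comparing components numerically.
void sortRecords(std::vector<Record> &records) {
  std::sort(records.begin(), records.end(),
            [](const Record &a, const Record &b) {
              return std::tie(a.pid, a.idx) < std::tie(b.pid, b.idx);
            });
}
```

One design note: a plain textual sort of the printed lines would place `idx (10, 0)` before `idx (2, 0)`, so comparing the index components as numbers, as this sketch does, is likely what you want when reading a tensor back in order.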