[Enhancement] Introduce T.__ldg
#1414
src/target/codegen_cuda.cc

@@ -2354,6 +2354,23 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
    stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
           << ")), \"r\"((int)" << guard << ")\n";
    stream << ");\n";
  } else if (op->op.same_as(tl::__ldg())) {
    // Explicit read-only cached load. Preferred form: __ldg(BufferLoad(...)).
    // Fallback form: __ldg(buffer, index)
    const BufferLoadNode *bl = nullptr;
    if (!op->args.empty()) {
      bl = op->args[0].as<BufferLoadNode>();
    }
    if (bl == nullptr) {
      LOG(FATAL) << "T.__ldg expects a BufferLoad as the first argument.";
    }
    const BufferNode *buffer = bl->buffer.get();
    ICHECK_EQ(bl->indices.size(), 1)
        << "T.__ldg currently supports flattened 1D buffer accesses.";
    PrimExpr base = bl->indices[0];
    // Emit __ldg(&buffer_ref)
    auto buffer_ref = this->GetBufferRef(op->dtype, buffer, base);
    os << "__ldg(&(" << buffer_ref << "))";
Comment on lines +2357 to +2373
Contributor
🧩 Analysis chain
🏁 Scripts executed against tile-ai/tilelang:

# First, confirm the exact code at the specified lines
cat -n src/target/codegen_cuda.cc | sed -n '2357,2373p'
# Search for other __ldg usage patterns in the codebase
rg '__ldg' --type cc --type h -B 2 -A 5
# Check if there are existing dtype checks in similar CUDA operations
rg 'dtype.*bits\(\)' --type cc --type h -B 1 -A 1 | head -50
# Search for CUDA capability checks and SM arch support in the codebase
rg 'SM_[0-9]' --type cc --type h | head -20
# Look for TileLang CI configuration
fd -e yml -e yaml -e json | grep -i 'ci\|test' | head -20

🌐 Web query:
💡 Result: Official NVIDIA docs state __ldg supports a fixed set of built-in scalar and vector types.
Sources: NVIDIA CUDA Programming Guide / C++ Language Extensions (see Read-Only Data Cache / __ldg entries). [1][2]

# Search for __ldg usage and other dtype checks without type specification
rg '__ldg|dtype\.bits\(\)|is_scalar' -A 3 -B 1
# Look at the full context around the __ldg implementation
sed -n '2340,2380p' src/target/codegen_cuda.cc
# Search for dtype validation patterns in codegen_cuda.cc
sed -n '1,100p' src/target/codegen_cuda.cc | grep -i 'dtype\|check'
# Find other CUDA intrinsic handlers to see if they have dtype gating
rg 'LOG\(FATAL\)|ICHECK.*dtype' src/target/codegen_cuda.cc -B 2 -A 1 | head -60
# Look for any existing 32/64-bit checks in CUDA codegen
rg '\.bits\(\).*32|\.bits\(\).*64' -A 1 -B 1
# Check builtin.h and builtin.cc to understand the design intent
sed -n '30,50p' src/op/builtin.h | grep -A 15 "__ldg"
# Look for any architecture baseline in CI config or comments
fd -e yml -e yaml . | head -5 | xargs grep -l 'cuda\|gpu\|compute_capability' 2>/dev/null | head -3
# Check if there are baseline SM arch definitions in the codebase
rg 'compute_capability|sm_[0-9]|SM_[0-9]' --type cc --type h -i | head -10

Remove the misleading "fallback form" comment and add dtype validation for T.__ldg.

The comment at lines 2358–2359 claims a fallback form __ldg(buffer, index), but the codegen only accepts a BufferLoad as the first argument. Update the comment to reflect the actual implementation, replace LOG(FATAL) with ICHECK, and validate the dtype:

 } else if (op->op.same_as(tl::__ldg())) {
- // Explicit read-only cached load. Preferred form: __ldg(BufferLoad(...)).
- // Fallback form: __ldg(buffer, index)
+ // Explicit read-only cached load. Expects a flattened 1D BufferLoad.
const BufferLoadNode *bl = nullptr;
if (!op->args.empty()) {
bl = op->args[0].as<BufferLoadNode>();
}
- if (bl == nullptr) {
- LOG(FATAL) << "T.__ldg expects a BufferLoad as the first argument.";
- }
+ ICHECK(bl) << "T.__ldg expects a BufferLoad as the first argument.";
const BufferNode *buffer = bl->buffer.get();
ICHECK_EQ(bl->indices.size(), 1)
<< "T.__ldg currently supports flattened 1D buffer accesses.";
+ // Validate dtype: baseline support for 32/64-bit scalars.
+ ICHECK(op->dtype.is_scalar())
+ << "T.__ldg currently supports scalar element loads only, but got " << op->dtype;
+ ICHECK(op->dtype.bits() == 32 || op->dtype.bits() == 64)
+ << "T.__ldg only supports 32/64-bit scalar types for now, but got " << op->dtype;
PrimExpr base = bl->indices[0];
// Emit __ldg(&buffer_ref)
auto buffer_ref = this->GetBufferRef(op->dtype, buffer, base);
os << "__ldg(&(" << buffer_ref << "))";🤖 Prompt for AI Agents |
  } else if (op->op.same_as(builtin::reinterpret())) {
    DataType tgt_dtype = op->dtype;
    DataType src_dtype = op->args[0]->dtype;
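For orientation, here is a hand-written sketch of the kind of device code this branch is meant to produce for a 1D float32 load. This is hypothetical output, not text generated by the PR: the kernel name, signature, and index expression are illustrative, and the real pointer casts come from GetBufferRef.

// Sketch only: an elementwise kernel where x is read through the read-only
// data cache. The emitted expression has the shape __ldg(&(<buffer_ref>)).
extern "C" __global__ void ldg_example_kernel(const float *__restrict__ x,
                                              float *__restrict__ y) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // illustrative indexing
  y[i] = __ldg(&(x[i])) + 1.0f;                   // cached read-only load of x[i]
}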
@@ -0,0 +1,30 @@
import tilelang
import tilelang.language as T
import tilelang.testing


@tilelang.testing.requires_cuda
def test_language_ldg_codegen():
    N = 128

    @T.prim_func
    def main(
        x: T.Tensor((N,), "float32"),
        y: T.Tensor((N,), "float32"),
    ):
        with T.Kernel(N, threads=32) as pid:
            # Explicitly request read-only cache load for x[pid]
            y[pid] = T.__ldg(x[pid]) + 1.0

    # Compile for CUDA and retrieve generated CUDA source
    kernel = tilelang.compile(main, out_idx=[1], target="cuda")
    src = kernel.get_kernel_source()
    print(src)
    # Assert that codegen uses __ldg on CUDA backend
    # We look for the intrinsic call with address-of argument
    assert "__ldg(" in src, "Expected __ldg call in generated CUDA source"
    assert "__ldg(&" in src or "__ldg(&(" in src, "Expected address-of form in __ldg call"
Comment on lines 6 to 27
Contributor

Avoid the unconditional print(src); keep it commented out for debugging.

Suggested change
-    print(src)
+    # print(src)  # uncomment for debugging
if __name__ == "__main__":
    tilelang.testing.main()
@@ -96,6 +96,7 @@
)
from .logical import any_of, all_of  # noqa: F401
from .builtin import *  # noqa: F401
from .builtin import __ldg as __ldg  # noqa: F401
Contributor

Keep the explicit import, but drop the unused # noqa: F401 directive; Ruff reports it as an unused noqa (RUF100).

Suggested change
-from .builtin import __ldg as __ldg  # noqa: F401
+from .builtin import __ldg as __ldg

🧰 Tools
🪛 Ruff (0.14.8)
99-99: Unused noqa directive. Remove unused noqa directive. (RUF100)
from .utils import index_to_coordinates  # noqa: F401

@@ -59,6 +59,35 @@ def create_list_of_mbarrier(*args: Any) -> Call:
    raise TypeError("create_list_of_mbarrier expects a list or one or more arguments.")


def __ldg(load_or_buf: BufferLoad | tir.Buffer, index: PrimExpr | int | None = None) -> PrimExpr:
    """Explicitly load via CUDA read-only data cache.

    Prefer calling with a BufferLoad: `T.__ldg(x[i])` emits `__ldg(&x[i])` on CUDA.
    On non-CUDA backends, falls back to a regular load.

    Args:
        load_or_buf: A `BufferLoad` like `x[i]`, or a `Buffer`.
        index: Optional index when passing a `Buffer` directly.

    Returns:
        PrimExpr: The loaded value.
    """
    if isinstance(load_or_buf, BufferLoad):
        dtype = load_or_buf.dtype
        return tir.call_intrin(str(dtype), tir.op.Op.get("tl.__ldg"), load_or_buf)
    if isinstance(load_or_buf, tir.Buffer):
        if index is None:
            raise ValueError("T.__ldg(Buffer, index) requires an index when passing a Buffer.")
        idx = index
        if isinstance(index, (list, tuple)):
            if len(index) != 1:
                raise ValueError("T.__ldg currently supports 1D flattened indices.")
            idx = index[0]
        bl = BufferLoad(load_or_buf, [idx])
        return tir.call_intrin(str(load_or_buf.dtype), tir.op.Op.get("tl.__ldg"), bl)
    raise TypeError("T.__ldg expects a BufferLoad or a Buffer.")
Comment on lines +62 to +89
Contributor

Normalize the index to a PrimExpr before constructing the BufferLoad, so plain Python ints (and single-element lists/tuples) are converted explicitly:

 def __ldg(load_or_buf: BufferLoad | tir.Buffer, index: PrimExpr | int | None = None) -> PrimExpr:
 @@
     if isinstance(load_or_buf, tir.Buffer):
         if index is None:
             raise ValueError("T.__ldg(Buffer, index) requires an index when passing a Buffer.")
-        idx = index
+        idx = index
         if isinstance(index, (list, tuple)):
             if len(index) != 1:
                 raise ValueError("T.__ldg currently supports 1D flattened indices.")
             idx = index[0]
-        bl = BufferLoad(load_or_buf, [idx])
+        bl = BufferLoad(load_or_buf, [convert(idx)])
         return tir.call_intrin(str(load_or_buf.dtype), tir.op.Op.get("tl.__ldg"), bl)
     raise TypeError("T.__ldg expects a BufferLoad or a Buffer.")

🧰 Tools
🪛 Ruff (0.14.8)
80-80: Avoid specifying long messages outside the exception class (TRY003)
84-84: Avoid specifying long messages outside the exception class (TRY003)
88-88: Avoid specifying long messages outside the exception class (TRY003)
def get_mbarrier(*args):
    """Retrieve a memory barrier operation.
🧩 Analysis chain
🌐 Web query: CUDA __ldg intrinsic supported data types element size constraints

💡 Result:
Short answer — native CUDA overloads of __ldg exist only for scalar and small vector built-ins; other types must be handled via a template wrapper (or cast) and the hardware performs loads in natural sizes (1/2/4/8 and up to 16‑byte transactions). Native overload list and constraints:
Native CUDA __ldg overloads (built‑in types): char, short, int, long long, unsigned char, unsigned short, unsigned int, unsigned long long, float, double and the vector variants int2/int4, uint2/uint4, float2/float4, double2. For other/user types use a template wrapper (e.g., template <typename T> __device__ T ldg(const T *ptr) that calls __ldg for supported types and falls back to *ptr otherwise). [1]
Element/transaction size: GPU memory transactions ordinarily use 1, 2, 4, 8 byte accesses and can coalesce up to 16 bytes per thread (so user types should be sized and aligned to match 4/8/16‑byte boundaries for best behavior). __ldg maps to cacheable global loads (ld.global.nc) introduced on sm_35 and later. [2][3]
If you want exact behavior for a custom type, use the template overload approach (or memcpy/cast to a supported vector type of matching size/alignment) and ensure sizes are a multiple of 4 bytes (or aligned to natural alignment) to avoid unexpected misalignment/coalescing penalties. [1][2]
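To make that template-wrapper idea concrete, here is a minimal sketch for sm_35+ device code; the wrapper name ldg_or_load and the handful of overloads shown are illustrative, not part of this PR.

// Generic fallback: an ordinary global load for types with no native __ldg overload.
template <typename T>
__device__ __forceinline__ T ldg_or_load(const T *ptr) {
  return *ptr;
}

// Overloads for a few natively supported types route through the read-only
// data cache (compiles to ld.global.nc on sm_35 and later).
__device__ __forceinline__ float ldg_or_load(const float *ptr) { return __ldg(ptr); }
__device__ __forceinline__ double ldg_or_load(const double *ptr) { return __ldg(ptr); }
__device__ __forceinline__ float4 ldg_or_load(const float4 *ptr) { return __ldg(ptr); }

Call sites then use ldg_or_load(&x[i]) regardless of element type, and only the supported types take the cached path.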
Remove or correct the "32-bit element types" constraint in the documentation.
The doc claims __ldg is for "32-bit element types on supported architectures," but CUDA's native __ldg overloads actually support 8, 16, 32, and 64-bit scalar types (char, short, int, long long, float, double) plus vector variants (int2/4, uint2/4, float2/4, double2). More importantly, the CUDA codegen in src/target/codegen_cuda.cc does not validate the element type at all; it will emit __ldg for any dtype, including unsupported ones like fp16 or custom structs. Either enforce type validation against CUDA's actual supported types or update the doc to remove the "32-bit" restriction.
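A minimal sketch of the validation option, assuming TVM's DataType helpers already used elsewhere in this diff; the exact accepted set and the placement inside the tl::__ldg branch are illustrative, not a committed design.

// Reject dtypes with no native __ldg overload. This covers only the scalar
// part of the overload list above; vector variants (float2/float4, int2/int4,
// ...) would need additional cases.
DataType t = op->dtype;
bool ok_int = (t.is_int() || t.is_uint()) &&
              (t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
bool ok_float = t.is_float() && (t.bits() == 32 || t.bits() == 64);
ICHECK(t.is_scalar() && (ok_int || ok_float))
    << "T.__ldg: unsupported element type " << t
    << "; native __ldg overloads cover 8/16/32/64-bit integers and 32/64-bit floats.";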