Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transform/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ class StoragePlanRewriter : public StmtExprMutator {
bool IsSpecialTaggedMemory(const StorageScope &scope) {
return !scope.tag.empty() && scope.tag != ".dyn" &&
scope.tag != ".barrier" && scope.tag != ".workspace" &&
scope.tag != ".vtcm";
scope.tag != ".vtcm" && scope.tag != ".var";
}

// Allocate entry of node.
Expand Down
172 changes: 139 additions & 33 deletions tilelang/transform/add_bufstore_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,149 @@
from tvm.tir import BufferStore, For, AttrStmt, ForKind, Var, PrimFunc
from tvm.tir import (BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm)
from tvm.tir.stmt_functor import ir_transform, post_order_visit
from tvm.tir.transform import prim_func_pass


def AddWrapperForSingleBufStore():
"""
Creates a TVM pass that wraps single buffer stores with parallel loops.

This transformation adds T.Parallel wrappers around buffer stores that:
1. Access fragment buffers with index 0
2. Are not inside existing tile operations or thread bindings
3. Don't access fragment buffers with non-zero indices

Returns:
A prim_func_pass that applies the transformation
"""

def pass_fn(func: PrimFunc, mod, ctx):
pfor = 0
thread_binding_var = set()

def get_used_var(op):
used_var = set()

def visit_fn(x):
if isinstance(x, Var):
used_var.add(x)

post_order_visit(op, visit_fn)
return used_var

def is_tile_op_for(op: For):
return op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations

def pre_visit(stmt):
nonlocal pfor
if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent':
thread_binding_var.add(stmt.node.var)
if isinstance(stmt, For):
pfor += is_tile_op_for(stmt)

def post_visit(stmt):
nonlocal pfor
if isinstance(stmt, For):
pfor -= is_tile_op_for(stmt)
if isinstance(stmt, BufferStore):
used_var = get_used_var(stmt)
used_binding = used_var.intersection(thread_binding_var)
if not pfor and len(used_binding) == 0:
return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt)
# Counter for tracking nested tile operations
tile_operation_depth = 0
# Set of variables bound to threads
thread_binding_vars = set()

def get_used_variables(operation) -> set:
    """
    Gather every ``Var`` node reachable from the given operation.

    Args:
        operation: The TIR operation to walk

    Returns:
        Set containing each ``Var`` encountered during the post-order walk
    """
    seen_vars = set()

    def record_if_var(node):
        # Anything that is not a Var is simply skipped.
        if not isinstance(node, Var):
            return
        seen_vars.add(node)

    post_order_visit(operation, record_if_var)
    return seen_vars

def collect_buffer_accesses(statement) -> tuple[list[Buffer], list[Buffer]]:
    """
    Split the buffers touched by a statement into two groups by scope.

    Every buffer referenced by a ``BufferLoad`` or ``BufferStore`` inside the
    statement is collected once, then classified by the string returned from
    ``buffer.scope()``.

    Args:
        statement: The TIR statement to walk

    Returns:
        Tuple of (local_buffers, fragment_buffers). Buffers whose scope is
        exactly "local.fragment" go to the second list; any other scope
        starting with "local" goes to the first list; all others are dropped.
    """
    touched = set()

    def note_buffer(node):
        if isinstance(node, (BufferLoad, BufferStore)):
            touched.add(node.buffer)

    post_order_visit(statement, note_buffer)

    in_local = []
    in_fragment = []
    for buf in touched:
        scope_name = buf.scope()
        # The exact fragment scope is tested first because it also starts with "local".
        if scope_name == "local.fragment":
            in_fragment.append(buf)
            continue
        if scope_name.startswith("local"):
            in_local.append(buf)
    return in_local, in_fragment
Comment on lines +44 to +69
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The local_buffers list is populated in collect_buffer_accesses but its value is discarded by the caller at line 129. To improve code clarity and remove unused logic, consider refactoring collect_buffer_accesses to only compute and return fragment_buffers.


def collect_buffer_indices(statement) -> dict[Buffer, list[int]]:
    """
    Maps each buffer to the indices of every access to it in the statement.

    A buffer may be accessed several times in one statement (for example, it
    is both loaded from and stored to). The indices of all such accesses are
    accumulated, in post-order visiting order, into a single flat list per
    buffer. Recording only the most recently visited access would hide the
    indices of earlier accesses from any caller that validates them.

    Args:
        statement: The TIR statement to analyze

    Returns:
        Dictionary mapping each accessed buffer to a flat list holding the
        index expressions of all its BufferLoad/BufferStore accesses
    """
    buffer_to_indices = {}

    def visit_buffer_access(node):
        if isinstance(node, (BufferLoad, BufferStore)):
            # Extend rather than assign, so repeated accesses to the same
            # buffer do not overwrite each other.
            buffer_to_indices.setdefault(node.buffer, []).extend(node.indices)

    post_order_visit(statement, visit_buffer_access)
    return buffer_to_indices

def is_tile_operation_loop(loop: For) -> bool:
    """
    Reports whether a For loop counts as a tile operation.

    Args:
        loop: The For loop to inspect

    Returns:
        True when the loop kind is ForKind.PARALLEL, or when its annotations
        contain the key 'num_stages'; False otherwise
    """
    if loop.kind == ForKind.PARALLEL:
        return True
    return 'num_stages' in loop.annotations

def pre_visit(statement):
"""
Pre-order visitor that tracks thread bindings and tile operation depth.
"""
nonlocal tile_operation_depth

if isinstance(statement, AttrStmt) and statement.attr_key == 'thread_extent':
thread_binding_vars.add(statement.node.var)
elif isinstance(statement, For) and is_tile_operation_loop(statement):
tile_operation_depth += 1

def post_visit(statement):
"""
Post-order visitor that applies transformations and updates counters.
"""
nonlocal tile_operation_depth

if isinstance(statement, For) and is_tile_operation_loop(statement):
tile_operation_depth -= 1

elif isinstance(statement, BufferStore):
used_variables = get_used_variables(statement)
thread_bound_variables = used_variables.intersection(thread_binding_vars)

# Only transform if not inside tile operations and no thread bindings
if tile_operation_depth == 0 and len(thread_bound_variables) == 0:
# Skip if no fragment buffers are accessed
_, fragment_buffers = collect_buffer_accesses(statement)
if len(fragment_buffers) == 0:
return statement

# Validate fragment buffer indices - only index 0 is supported
buffer_indices = collect_buffer_indices(statement)
for buffer, indices in buffer_indices.items():
if buffer.scope() == "local.fragment":
for index in indices:
if isinstance(index, IntImm) and index != 0:
raise ValueError(
f"Fragment buffer access with non-zero index [{index}] is not supported. "
"Only fragment[0] access is allowed.")
Comment on lines +138 to +141
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The validation for fragment buffer indices seems too permissive. It only raises an error for constant non-zero indices, but allows variable indices to pass through. This could lead to incorrect transformations if a variable index evaluates to a non-zero value at runtime. According to the docstring "Don't access fragment buffers with non-zero indices", the check should be stricter to only allow provably zero indices.

Suggested change
if isinstance(index, IntImm) and index != 0:
raise ValueError(
f"Fragment buffer access with non-zero index [{index}] is not supported. "
"Only fragment[0] access is allowed.")
if not (isinstance(index, IntImm) and index.value == 0):
raise ValueError(
f"Fragment buffer access with non-constant or non-zero index [{index}] is not supported. "
"Only fragment[0] access is allowed.")

Comment on lines +134 to +141
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Integer comparison issue with IntImm.

The comparison index != 0 compares an IntImm object with an integer. This should compare the value instead.

                         if buffer.scope() == "local.fragment":
                             for index in indices:
-                                if isinstance(index, IntImm) and index != 0:
+                                if isinstance(index, IntImm) and index.value != 0:
                                     raise ValueError(
                                         f"Fragment buffer access with non-zero index [{index}] is not supported. "
                                         "Only fragment[0] access is allowed.")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
buffer_indices = collect_buffer_indices(statement)
for buffer, indices in buffer_indices.items():
if buffer.scope() == "local.fragment":
for index in indices:
if isinstance(index, IntImm) and index != 0:
raise ValueError(
f"Fragment buffer access with non-zero index [{index}] is not supported. "
"Only fragment[0] access is allowed.")
buffer_indices = collect_buffer_indices(statement)
for buffer, indices in buffer_indices.items():
if buffer.scope() == "local.fragment":
for index in indices:
if isinstance(index, IntImm) and index.value != 0:
raise ValueError(
f"Fragment buffer access with non-zero index [{index}] is not supported. "
"Only fragment[0] access is allowed.")
🧰 Tools
🪛 Ruff (0.12.2)

139-141: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around lines 134 to 141, the
comparison `index != 0` is comparing an IntImm object to an int; change the
check to compare the IntImm's numeric value (e.g., `index.value != 0`) and
update the error message to include the numeric value (use `index.value`) so the
condition and message operate on the actual integer rather than the IntImm
object.


# Wrap fragment[0] access with T.Parallel loop
return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, statement)

return statement

new_body = ir_transform(func.body, pre_visit, post_visit)

Expand Down
Loading