-
Notifications
You must be signed in to change notification settings - Fork 459
[Bugfix] Disable Memory Info Analysis for local.var
#851
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,43 +1,149 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from tvm.tir import BufferStore, For, AttrStmt, ForKind, Var, PrimFunc | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from tvm.tir import (BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from tvm.tir.stmt_functor import ir_transform, post_order_visit | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from tvm.tir.transform import prim_func_pass | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def AddWrapperForSingleBufStore(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Creates a TVM pass that wraps single buffer stores with parallel loops. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| This transformation adds T.Parallel wrappers around buffer stores that: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| 1. Access fragment buffers with index 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||
| 2. Are not inside existing tile operations or thread bindings | ||||||||||||||||||||||||||||||||||||||||||||||||||
| 3. Don't access fragment buffers with non-zero indices | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| A prim_func_pass that applies the transformation | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def pass_fn(func: PrimFunc, mod, ctx): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| pfor = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||
| thread_binding_var = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_used_var(op): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| used_var = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def visit_fn(x): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(x, Var): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| used_var.add(x) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| post_order_visit(op, visit_fn) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return used_var | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def is_tile_op_for(op: For): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return op.kind == ForKind.PARALLEL or 'num_stages' in op.annotations | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def pre_visit(stmt): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| nonlocal pfor | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(stmt, AttrStmt) and stmt.attr_key == 'thread_extent': | ||||||||||||||||||||||||||||||||||||||||||||||||||
| thread_binding_var.add(stmt.node.var) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(stmt, For): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| pfor += is_tile_op_for(stmt) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def post_visit(stmt): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| nonlocal pfor | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(stmt, For): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| pfor -= is_tile_op_for(stmt) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(stmt, BufferStore): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| used_var = get_used_var(stmt) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| used_binding = used_var.intersection(thread_binding_var) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if not pfor and len(used_binding) == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return For(Var("_", "int"), 0, 1, ForKind.PARALLEL, stmt) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Counter for tracking nested tile operations | ||||||||||||||||||||||||||||||||||||||||||||||||||
| tile_operation_depth = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Set of variables bound to threads | ||||||||||||||||||||||||||||||||||||||||||||||||||
| thread_binding_vars = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_used_variables(operation) -> set: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Collects all variables used in the given operation. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| operation: The TIR operation to analyze | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Set of variables used in the operation | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| used_variables = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def visit_variable(node): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(node, Var): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| used_variables.add(node) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| post_order_visit(operation, visit_variable) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return used_variables | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def collect_buffer_accesses(statement) -> tuple[list[Buffer], list[Buffer]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Categorizes buffers accessed in the statement by their scope. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| statement: The TIR statement to analyze | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Tuple of (local_buffers, fragment_buffers) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| accessed_buffers = set() | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def visit_buffer_access(node): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(node, (BufferLoad, BufferStore)): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| accessed_buffers.add(node.buffer) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| post_order_visit(statement, visit_buffer_access) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| local_buffers = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| fragment_buffers = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| for buffer in accessed_buffers: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if buffer.scope() == "local.fragment": | ||||||||||||||||||||||||||||||||||||||||||||||||||
| fragment_buffers.append(buffer) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| elif buffer.scope().startswith("local"): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| local_buffers.append(buffer) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return local_buffers, fragment_buffers | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def collect_buffer_indices(statement) -> dict[Buffer, list[int]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Maps each buffer to its access indices. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| statement: The TIR statement to analyze | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Dictionary mapping buffers to their access indices | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| buffer_to_indices = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def visit_buffer_access(node): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(node, (BufferLoad, BufferStore)): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| buffer_to_indices[node.buffer] = node.indices | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| post_order_visit(statement, visit_buffer_access) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return buffer_to_indices | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def is_tile_operation_loop(loop: For) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Determines if a For loop is a tile operation. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| loop: The For loop to check | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| True if the loop is a tile operation (parallel or has num_stages annotation) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return loop.kind == ForKind.PARALLEL or 'num_stages' in loop.annotations | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def pre_visit(statement): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Pre-order visitor that tracks thread bindings and tile operation depth. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| nonlocal tile_operation_depth | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(statement, AttrStmt) and statement.attr_key == 'thread_extent': | ||||||||||||||||||||||||||||||||||||||||||||||||||
| thread_binding_vars.add(statement.node.var) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(statement, For) and is_tile_operation_loop(statement): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| tile_operation_depth += 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def post_visit(statement): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Post-order visitor that applies transformations and updates counters. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| nonlocal tile_operation_depth | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(statement, For) and is_tile_operation_loop(statement): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| tile_operation_depth -= 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(statement, BufferStore): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| used_variables = get_used_variables(statement) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| thread_bound_variables = used_variables.intersection(thread_binding_vars) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Only transform if not inside tile operations and no thread bindings | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if tile_operation_depth == 0 and len(thread_bound_variables) == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Skip if no fragment buffers are accessed | ||||||||||||||||||||||||||||||||||||||||||||||||||
| _, fragment_buffers = collect_buffer_accesses(statement) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if len(fragment_buffers) == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return statement | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Validate fragment buffer indices - only index 0 is supported | ||||||||||||||||||||||||||||||||||||||||||||||||||
| buffer_indices = collect_buffer_indices(statement) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| for buffer, indices in buffer_indices.items(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if buffer.scope() == "local.fragment": | ||||||||||||||||||||||||||||||||||||||||||||||||||
| for index in indices: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(index, IntImm) and index != 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Fragment buffer access with non-zero index [{index}] is not supported. " | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "Only fragment[0] access is allowed.") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+138
to
+141
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The validation for fragment buffer indices seems too permissive. It only raises an error for constant non-zero indices, but allows variable indices to pass through. This could lead to incorrect transformations if a variable index evaluates to a non-zero value at runtime. According to the docstring "Don't access fragment buffers with non-zero indices", the check should be stricter to only allow provably zero indices.
Suggested change
Comment on lines
+134
to
+141
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Integer comparison issue with IntImm. The comparison if buffer.scope() == "local.fragment":
for index in indices:
- if isinstance(index, IntImm) and index != 0:
+ if isinstance(index, IntImm) and index.value != 0:
raise ValueError(
f"Fragment buffer access with non-zero index [{index}] is not supported. "
"Only fragment[0] access is allowed.")📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.12.2)139-141: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| # Wrap fragment[0] access with T.Parallel loop | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return For(Var("_", "int32"), 0, 1, ForKind.PARALLEL, statement) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| return statement | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| new_body = ir_transform(func.body, pre_visit, post_visit) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
local_bufferslist is populated incollect_buffer_accessesbut its value is discarded by the caller at line 129. To improve code clarity and remove unused logic, consider refactoringcollect_buffer_accessesto only compute and returnfragment_buffers.