Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/transform/lower_tile_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -998,11 +998,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
bool has_non_local = false;
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
if (const auto *load = obj.as<BufferLoadNode>()) {
if (!IsLocalBuffer(load->buffer) && !IsFragmentBuffer(load->buffer)) {
if (!IsLocalBuffer(load->buffer, /*allow_var*/ true) &&
!IsFragmentBuffer(load->buffer)) {
has_non_local = true;
}
} else if (const auto *store = obj.as<BufferStoreNode>()) {
if (!IsLocalBuffer(store->buffer) && !IsFragmentBuffer(store->buffer)) {
if (!IsLocalBuffer(store->buffer, /*allow_var*/ true) &&
!IsFragmentBuffer(store->buffer)) {
has_non_local = true;
}
}
Expand Down
16 changes: 16 additions & 0 deletions testing/python/language/test_tilelang_language_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,21 @@ def test_parallel_dynamic_extent():
torch.testing.assert_close(out, reference, atol=1e-5, rtol=1e-5)


@tilelang.jit
def _parallel_vectorize_local_and_var():
with T.Kernel(1) as _:
x = T.alloc_fragment([256], T.float32)
y = T.alloc_fragment([256], T.float32)
z = T.alloc_var(T.float32)
for i in T.parallel(256):
y[i] = x[i] * z

Comment on lines +69 to +77
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, find and read the test file to verify the actual casing
fd test_tilelang_language_parallel.py --type f

Repository: tile-ai/tilelang

Length of output: 118


🏁 Script executed:

# Read the specific lines mentioned in the review
if [ -f "testing/python/language/test_tilelang_language_parallel.py" ]; then
  echo "=== Lines 69-77 (the disputed code) ===" && \
  sed -n '69,77p' testing/python/language/test_tilelang_language_parallel.py && \
  echo -e "\n=== Lines 18, 33, 36 (for comparison) ===" && \
  sed -n '18p;33p;36p' testing/python/language/test_tilelang_language_parallel.py
fi

Repository: tile-ai/tilelang

Length of output: 538


🏁 Script executed:

# Search for T.parallel and T.Parallel definitions in the codebase
rg "def parallel|def Parallel" --type py -C 2

Repository: tile-ai/tilelang

Length of output: 10557


🏁 Script executed:

# Check how T is imported and what it refers to
rg "^import T|^from.*import.*T|^T = " testing/python/language/test_tilelang_language_parallel.py -C 1

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check the imports at the top of the test file
head -30 testing/python/language/test_tilelang_language_parallel.py

Repository: tile-ai/tilelang

Length of output: 783


🏁 Script executed:

# Search for T usage and imports more broadly
rg "^import|^from" testing/python/language/test_tilelang_language_parallel.py | head -20

Repository: tile-ai/tilelang

Length of output: 156


🏁 Script executed:

# Check what T module contains and where Parallel is accessible from
cat -n tilelang/language/loop.py | head -50

Repository: tile-ai/tilelang

Length of output: 2267


🏁 Script executed:

# Check if parallel is aliased to Parallel or available in the same module
rg "parallel|Parallel" tilelang/language/__init__.py -C 3

Repository: tile-ai/tilelang

Length of output: 375


🏁 Script executed:

# Check what else is exported from __init__.py, especially any lowercase parallel
rg "parallel" tilelang/language/__init__.py

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check if there's an alias or if parallel is defined elsewhere
rg "^parallel" tilelang/language/ -A 2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Let's also check the specific loop.py file more thoroughly for any lowercase parallel
rg -i "parallel" tilelang/language/loop.py | grep -v "^\s*#"

Repository: tile-ai/tilelang

Length of output: 999


🏁 Script executed:

# Check if parallel exists in tir/ir.py and ast/ir.py that was mentioned in the earlier grep
cat -n tilelang/language/tir/ir.py | grep -A 10 "^.*def parallel"

Repository: tile-ai/tilelang

Length of output: 436


🏁 Script executed:

# Check if tir/ir.py or parallel (lowercase) is imported in __init__.py
cat tilelang/language/__init__.py

Repository: tile-ai/tilelang

Length of output: 4252


🏁 Script executed:

# Verify the full tir/ir.py file to understand its structure
wc -l tilelang/language/tir/ir.py
head -100 tilelang/language/tir/ir.py

Repository: tile-ai/tilelang

Length of output: 2726


🏁 Script executed:

# Check if parallel is explicitly exported from tir module
rg "from.*tir" tilelang/language/__init__.py -C 3

Repository: tile-ai/tilelang

Length of output: 569


🏁 Script executed:

# Verify the function signatures one more time to be absolutely certain
grep -A 5 "def parallel(" tilelang/language/tir/ir.py | head -10
grep -A 5 "def Parallel(" tilelang/language/loop.py | head -10

Repository: tile-ai/tilelang

Length of output: 454


🏁 Script executed:

# Check if there are other uses of T.parallel vs T.Parallel in the test file to see the pattern
rg "T\.parallel|T\.Parallel" testing/python/language/test_tilelang_language_parallel.py -n

Repository: tile-ai/tilelang

Length of output: 227


Use T.Parallel() for consistency with the rest of the file.

Line 75 uses T.parallel() (the low-level TVM IR API), while lines 18, 33, and 36 use T.Parallel() (the TileLang high-level API). Both are valid, but the file should use T.Parallel() consistently for readability and to maintain the TileLang abstraction level.

🤖 Prompt for AI Agents
In `@testing/python/language/test_tilelang_language_parallel.py` around lines 69 -
77, The test uses the low-level T.parallel() in
_parallel_vectorize_local_and_var while other tests use the TileLang high-level
T.Parallel(); update the for loop in _parallel_vectorize_local_and_var to use
T.Parallel() instead of T.parallel() so the file remains consistent with the
TileLang API (search for the function _parallel_vectorize_local_and_var and
replace the T.parallel(...) usage with T.Parallel(...)).


def test_parallel_vectorize_var():
source = _parallel_vectorize_local_and_var.get_kernel_source()
# do not vectorize if the loop only contains local/fragment and var buffer access
assert "float2" not in source


if __name__ == "__main__":
tilelang.testing.main()
Loading