diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 000000000..d05c8a1cd --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,13 @@ +# Copilot Instructions for TileLang Repository + +## Review Guidelines +- Ensure all Python code follows PEP 8 standards. +- Check for proper documentation in docstrings. +- Verify that changes to CI workflows improve efficiency without breaking existing functionality. +- For C++ code, ensure clang-format compliance. +- Suggest improvements for performance in tensor operations. + +## Additional Contexts +- TileLang is built on TVM and targets Tenstorrent hardware. +- Prioritize security in dependency updates. +- Encourage use of type hints in Python code. diff --git a/.github/workflows/tenstorrent-ci.yml b/.github/workflows/tenstorrent-ci.yml new file mode 100644 index 000000000..a02139144 --- /dev/null +++ b/.github/workflows/tenstorrent-ci.yml @@ -0,0 +1,218 @@ +name: Tenstorrent Backend CI + +on: + pull_request: + paths: + - 'tilelang/engine/tt/**' + - 'testing/python/tt/**' + - 'tilelang/utils/target.py' + - '.github/workflows/tenstorrent-ci.yml' + push: + branches: + - main + - 'ws1-**' + +# Auto-cancel superseded runs on the same branch/PR +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + +env: + PYTHON_VERSION: '3.10' + PIP_DISABLE_PIP_VERSION_CHECK: '1' + +jobs: + lint-and-format: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v5 + with: + fetch-depth: 0 + submodules: recursive + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: 'pip' + # Include any file that can affect the lock of Python deps + cache-dependency-path: | + requirements-lint.txt + requirements-dev.txt + pyproject.toml + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-lint.txt + + - name: Run format.sh check + run: | + bash format.sh + # Fail if the formatter changed files + if ! git diff --quiet; then + echo "Code formatting issues found." + git diff + exit 1 + fi + + build-and-test: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v5 + with: + fetch-depth: 0 + submodules: recursive + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: 'pip' + cache-dependency-path: | + requirements-test.txt + requirements-build.txt + requirements.txt + pyproject.toml + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y \ + build-essential \ + cmake \ + ninja-build \ + llvm \ + libedit-dev \ + libxml2-dev \ + zlib1g-dev + + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-test.txt + + # Compiler cache persisted via GitHub Actions cache + - name: Enable ccache + uses: hendrikmuhs/ccache-action@v1.2.19 + with: + # Key on OS and build config; ccache handles file-level caching internally + key: ${{ runner.os }}-ccache-llvm-${{ hashFiles('CMakeLists.txt') }}-v1 + max-size: 2G + create-symlink: true + + - name: Generate TVM cache key + id: tvm-cache-key + run: | + # Cache key based on TVM submodule commit hash + TVM_COMMIT=$(git rev-parse HEAD:3rdparty/tvm) + echo "tvm_commit=$TVM_COMMIT" >> $GITHUB_OUTPUT + echo "TVM submodule at commit: $TVM_COMMIT" + + - name: Restore TVM build cache + id: cache-tvm-restore + uses: actions/cache/restore@v4 + with: + path: | + build/tvm/ + build/libtilelang*.so + build/3rdparty/ + key: tvm-build-llvm-${{ steps.tvm-cache-key.outputs.tvm_commit }}-${{ runner.os }} + restore-keys: | + tvm-build-llvm-${{ steps.tvm-cache-key.outputs.tvm_commit }}- + tvm-build-llvm- + + - name: Build TileLang with LLVM (ccache-enabled) + run: | + mkdir -p build + cd build + # Create config.cmake for TVM + cp ../3rdparty/tvm/cmake/config.cmake . + echo "set(USE_LLVM ON)" >> config.cmake + echo "set(USE_CUDA OFF)" >> config.cmake + # Configure with CMake; enable ccache via compiler launcher + cmake .. \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + # Build (use fewer jobs to avoid OOM on GitHub runners) + cmake --build . --config Release -j 2 + + - name: Install TileLang + run: | + # Copy built libraries to tilelang/lib + mkdir -p tilelang/lib + cp build/*.so tilelang/lib/ || true + # Install TVM Python package + # Set TVM_LIBRARY_PATH so TVM can find the built libraries + export TVM_LIBRARY_PATH=$(pwd)/build/tvm + cd 3rdparty/tvm/python + pip install -e . + cd ../../.. + # Install TileLang + export USE_LLVM=true + pip install -e . + + - name: Print ccache stats + if: always() + run: | + if command -v ccache >/dev/null 2>&1; then + ccache -s + else + echo "ccache not found; skipping stats." + fi + + - name: Run Tenstorrent target registration tests + run: | + export LD_LIBRARY_PATH=$(pwd)/build/tvm:$LD_LIBRARY_PATH + cd testing/python/tt + pytest test_target_registration.py -v --tb=short + continue-on-error: true # Don't fail if TVM isn't fully available + + - name: Run all Python tests (if TVM available) + run: | + export LD_LIBRARY_PATH=$(pwd)/build/tvm:$LD_LIBRARY_PATH + cd testing/python + pytest tt/ -v --tb=short -k "not gpu" || echo "Some tests skipped (no GPU)" + continue-on-error: true + + - name: Save TVM build cache + uses: actions/cache/save@v4 + if: always() && steps.cache-tvm-restore.outputs.cache-hit != true + with: + path: | + build/tvm/ + build/libtilelang*.so + build/3rdparty/ + key: tvm-build-llvm-${{ steps.tvm-cache-key.outputs.tvm_commit }}-${{ runner.os }} + + static-analysis: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v5 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: 'pip' + cache-dependency-path: | + requirements-mypy.txt + pyproject.toml + + - name: Install mypy + run: pip install -r requirements-mypy.txt + + - name: Type check Tenstorrent backend + run: mypy tilelang/engine/tt/ --ignore-missing-imports || true + continue-on-error: true \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..eb98b1bb9 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,247 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Overview + +TileLang is a domain-specific language for developing high-performance GPU/CPU kernels (GEMM, FlashAttention, etc.) and accelerators including **Tenstorrent AI architecture**. Built on TVM with a Pythonic syntax, it enables productivity without sacrificing low-level optimizations. + +This repository (`tilelang-tt`) is a **public fork** focused on adding first-class **Tenstorrent TT-Metalium backend** support alongside existing NVIDIA CUDA, AMD ROCm, and Huawei Ascend targets. + +## Build System + +### Environment Variables + +- `USE_LLVM=true` - Enable LLVM backend (CPU-only builds, required for Tenstorrent CI) +- `USE_ROCM=true` - Enable AMD ROCm backend (requires `ROCM_HOME`) +- `USE_CUDA=true` - Default; requires `CUDA_HOME` (automatically detected) +- `DEBUG_MODE=true` - Build with debug symbols and logging +- `WITH_COMMITID=true` - Include git commit ID in wheel filename (default for non-PyPI builds) +- `PYPI_BUILD=true` - Build for PyPI distribution (clean version strings) + +### Building TileLang + +**Standard CUDA build:** +```bash +python setup.py build_ext --inplace +pip install -e . +``` + +**LLVM-only build (for CPU/Tenstorrent development):** +```bash +USE_LLVM=true pip install -e . +``` + +**ROCm build:** +```bash +USE_ROCM=true pip install -e . +``` + +The build system: +- Uses CMake + Ninja for C++/CUDA compilation +- Automatically downloads LLVM 10.0.1 if system llvm-config unavailable +- Compiles TVM from `3rdparty/tvm` submodule (unless `TVM_PREBUILD_PATH` set) +- Generates `libtvm.so`, `libtvm_runtime.so`, `libtilelang.so`, `libtilelang_module.so` +- Supports incremental builds via ccache (CI uses this heavily) + +### Testing + +**Run all tests:** +```bash +pytest testing/python/ -v +``` + +**Run Tenstorrent tests:** +```bash +LD_LIBRARY_PATH=build/tvm pytest testing/python/tt/test_target_registration.py -v +``` + +**Run specific test category:** +```bash +pytest testing/python/kernel/ -v # Kernel tests +pytest testing/python/language/ -v # Language tests +pytest testing/python/autotune/ -v # Autotuner tests +``` + +Note: Set `LD_LIBRARY_PATH` to include `build/tvm` for tests to find TVM shared libraries. + +### Code Formatting + +**Check formatting:** +```bash +bash format.sh +``` + +This runs: +- `yapf` for Python formatting +- `ruff` for Python linting +- `codespell` for spelling checks +- `clang-format` for C++ code (if `.clang-format` exists) + +**Auto-format (if supported):** +The format script will show diffs; manually apply changes or use auto-formatting tools. + +## Code Architecture + +### Tenstorrent Backend Design + +**Goal:** Map TileLang's GPU-style grid kernels to Tenstorrent's persistent, tile-based execution model. + +**Key concept:** Users write grid-style kernels with `T.Kernel(grid_x, grid_y)` using block indices `(bx, by)`. The backend generates a **persistent outer loop** for each core that iterates over assigned tiles, recovering `(bx, by)` from a static schedule. + +**Components (from README):** + +1. **Annotations API** (`python/tilelang_tt/annotations.py`): + - `T.annotate_tt_schedule()` - Control static scheduling (contiguous/strided/rect) + - `T.annotate_tt_sharding()` - Specify tensor sharding/layout on TT cores + +2. **Compiler Passes** (`src/tt/passes/`): + - `GridToPersistentTT` - Wraps grid kernel body in per-core scheduler loop + - `TTShardToCoreMap` - Translates sharding annotations to CoreRangeSet + - `TilePadTT` - Handles non-tile-multiple shapes (32×32 tiles) + - `MemorySpaceLowerTT` - Lower DRAM↔L1 moves, circular buffers + - `TensorizeTT` - Map tile operations to TT micro-kernels + +3. **Codegen** (`src/tt/codegen/`): + - `EmitTTKernels` - Generate compute/reader/writer C++ kernels and host stubs + +4. **Target Registration** (`tilelang/engine/tt/`): + - Target registration hooks for TVM integration + - Engine adapter for Tenstorrent runtime + +### Directory Structure + +``` +tilelang-tt/ +├── 3rdparty/ +│ ├── tvm/ # TVM submodule (compiler infrastructure) +│ ├── cutlass/ # NVIDIA CUTLASS for CUDA kernels +│ └── composable_kernel/ # AMD CK for ROCm kernels +├── src/ +│ ├── ir.cc # IR definitions +│ ├── layout/ # Layout transformations +│ ├── op/ # Operator implementations +│ ├── runtime/ # CUDA runtime utilities +│ ├── target/ # Code generators (CUDA, HIP, WebGPU, C++) +│ │ ├── codegen_cuda.cc +│ │ ├── codegen_hip.cc +│ │ ├── rt_mod_cuda.cc # CUDA runtime module +│ │ └── rt_mod_hip.cc # ROCm runtime module +│ ├── tl_templates/ # Kernel templates +│ └── transform/ # IR transformation passes +├── tilelang/ +│ ├── engine/ # Backend engines +│ │ └── tt/ # Tenstorrent engine (in development) +│ ├── language/ # TileLang DSL (Python API) +│ ├── autotuner/ # Auto-tuning framework +│ ├── jit/ # JIT compilation +│ │ └── adapter/cython/ # Cython wrapper for performance +│ ├── primitives/ # Primitive operations +│ └── transform/ # Python-level transforms +├── testing/python/ +│ ├── tt/ # Tenstorrent tests +│ ├── kernel/ # Kernel tests +│ ├── language/ # Language tests +│ └── autotune/ # Autotuner tests +├── examples/ # Example kernels (GEMM, attention, etc.) +└── docs/tenstorrent/ # Tenstorrent backend documentation +``` + +## Coding Standards + +### Python (from copilot-instructions.md) + +- Follow PEP 8 standards +- Use type hints for all functions +- Include docstrings for public APIs +- Security-conscious dependency updates + +### C++ + +- Follow clang-format rules (run `format.sh`) +- Ensure compatibility with TVM coding style + +## CI/CD + +### Workflows + +1. **`tenstorrent-ci.yml`** - Tenstorrent backend CI: + - Triggers on PRs modifying `tilelang/engine/tt/`, `testing/python/tt/`, or workflow files + - Runs on GitHub-hosted runners (Ubuntu + Python 3.10) + - Uses LLVM backend (not CUDA) for lightweight CPU-only tests + - **Caching strategy:** + - TVM build cache (keyed by submodule commit) - saves ~5-6 min + - ccache (keyed by CMakeLists.txt) - fast incremental compilation + - pip packages (keyed by requirements files) + - Jobs: lint-and-format, build-and-test, static-analysis (mypy) + - Tests currently `continue-on-error: true` (backend incomplete) + +2. **`ci.yml`** - Main CI: + - Self-hosted NVIDIA runners + - Full CUDA build and test suite + +3. **`amd_ci.yml`** - AMD ROCm CI + +### Running CI Locally + +```bash +# Lint and format +bash format.sh + +# Build and test (mimics Tenstorrent CI) +USE_LLVM=true pip install -e . +LD_LIBRARY_PATH=build/tvm pytest testing/python/tt/ -v +``` + +## Development Workflow + +### For Tenstorrent Backend Development + +1. **Branch naming:** Use `ws1-*` prefix for workstream 1 tasks (auto-triggers CI) + +2. **Key files to modify:** + - `tilelang/engine/tt/` - Python-level target registration and engine + - `src/tt/` - C++ passes and codegen (when ready for Phase 0) + - `testing/python/tt/` - Tests for Tenstorrent backend + +3. **Testing strategy:** + - Start with target registration tests (`test_target_registration.py`) + - Add compile-only tests before hardware tests + - Use "dry-run" mode to emit kernel sources without execution + +4. **Documentation:** + - Update `docs/tenstorrent/` with design decisions + - Follow phased approach (Phase 0: GEMM, Phase 1: SDPA, Phase 2: Ergonomics) + +## Key Technical Details + +### Tenstorrent Execution Model + +- **Persistent kernels:** Each core runs a long-lived kernel iterating over assigned tiles +- **Tile size:** 32×32 elements (dtype determines bytes per tile) +- **Memory hierarchy:** DRAM ↔ L1 circular buffers ↔ Compute +- **Static partitioning:** Host assigns `(start_id, count)` per core before launch + +### Default Behavior (Backward Compatibility) + +When no Tenstorrent annotations provided: +- Schedule: `policy="contiguous"`, `order="row_major"` +- Layout: Row-major 32×32 DRAM tilization +- L1 circular buffers auto-generated around `T.copy` sites + +This allows existing GPU-style kernels to run on TT with minimal changes (subject to tile padding). + +## Related Documentation + +- [GPU vs Tenstorrent Architecture](docs/tenstorrent/GPU_vs_Tenstorrent.md) +- [Kernel Authoring Comparison](docs/tenstorrent/kernel_authoring_comparison.md) +- [CI Documentation](docs/tenstorrent/CI.md) +- [Installation Guide](docs/get_started/Installation.md) + +## Important Notes + +- **LLVM vs CUDA builds:** For Tenstorrent development, use `USE_LLVM=true` to avoid CUDA dependency +- **System LLVM preferred:** CI uses system LLVM (via apt) to avoid libtinfo.so.5 linking issues with downloaded LLVM 10.0.1 +- **TVM library path:** Always set `LD_LIBRARY_PATH=build/tvm` when running tests +- **Submodules:** Run `git submodule update --init --recursive` after fresh clone +- **Cython JIT adapter:** Auto-compiles on first use with caching in `.cycache/` diff --git a/README.md b/README.md index 1603ea9c4..f10229f17 100644 --- a/README.md +++ b/README.md @@ -8,239 +8,411 @@ -Tile Language (**tile-lang**) is a concise domain-specific language designed to streamline the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). By employing a Pythonic syntax with an underlying compiler infrastructure on top of [TVM](https://tvm.apache.org/), tile-lang allows developers to focus on productivity without sacrificing the low-level optimizations necessary for state-of-the-art performance. +Tile Language (**tile-lang**) is a concise domain-specific language designed to streamline the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention) as well as accelerators such as [Tenstorrent AI architecture](https://github.com/tenstorrent/tt-metal/blob/main/METALIUM_GUIDE.md) and Huawei Ascend NPU. By employing a Pythonic syntax with an underlying compiler infrastructure on top of [TVM](https://tvm.apache.org/), tile-lang allows developers to focus on productivity without sacrificing the low-level optimizations necessary for state-of-the-art performance. -## Latest News -- 09/29/2025 🎉: Thrilled to announce that ​​AscendC​​ and ​Ascend​NPU IR​​ backends targeting Huawei Ascend chips are now supported! -Check out the preview here: -🔗 [link](https://github.com/tile-ai/tilelang-ascend). -This includes implementations across two branches: -[ascendc_pto](https://github.com/tile-ai/tilelang-ascend) and -[npuir](https://github.com/tile-ai/tilelang-ascend/tree/npuir). -Feel free to explore and share your feedback! -- 07/04/2025 🚀: Introduced `T.gemm_sp` for 2:4 sparse tensor core support, check out [Pull Request #526](https://github.com/tile-ai/tilelang/pull/526) for details. -- 06/05/2025 ✨: Added [NVRTC Backend](https://github.com/tile-ai/tilelang/pull/461) to significantly reduce compilation time for cute templates! -- 04/14/2025 🚀: Added high-performance FlashMLA implementation for AMD MI300X, achieving performance parity with hand-optimized assembly kernels of Aiter! See [example_mla_amd](./examples/deepseek_mla/amd/README.md) for details. -- 03/03/2025 🚀: Added high-performance MLA Decoding support using only 80 lines of Python code, achieving performance on par with FlashMLA on H100 (see [example_mla_decode.py](./examples/deepseek_mla/example_mla_decode.py))! We also provide [documentation](./examples/deepseek_mla/README.md) explaining how TileLang achieves this. -- 02/15/2025 ✨: Added WebGPU Codegen support, see [Pull Request #86](https://github.com/tile-ai/tilelang/pull/86)! -- 02/12/2025 ✨: Excited to announce the release of [v0.1.0](https://github.com/tile-ai/tilelang/releases/tag/v0.1.0)! -- 02/10/2025 🚀: Added debug tools for TileLang—`T.print` for printing variables/buffers ([docs](https://tilelang.com/tutorials/debug_tools_for_tilelang.html)) and a memory layout plotter ([examples/plot_layout](./examples/plot_layout)). -- 01/20/2025 ✨: We are excited to announce that tile-lang, a dsl for high performance AI workloads, is now open source and available to the public! +# TileLang → Tenstorrent (TT-Metalium) Backend -## Tested Devices -Although tile-lang aims to be portable across a range of Devices, it has been specifically tested and validated on the following devices: for NVIDIA GPUs, this includes the H100 (with Auto TMA/WGMMA support), A100, V100, RTX 4090, RTX 3090, and RTX A6000; for AMD GPUs, it includes the MI250 (with Auto MatrixCore support) and the MI300X (with Async Copy support). +**Status:** Draft proposal for community discussion +**Goal:** Add a first‑class **Tenstorrent TT‑Metalium** backend to TileLang, alongside the existing NVIDIA (CUDA), AMD (HIP), and Ascend targets. -## OP Implementation Examples -**tile-lang** provides the building blocks to implement a wide variety of operators. Some examples include: +This README doubles as a **technical plan** and a **call for contributions**. The intended path is to begin in a **public fork** and upstream in stages once CI and core features stabilize. -- [Matrix Multiplication](./examples/gemm/) -- [Dequantization GEMM](./examples/dequantize_gemm/) -- [Flash Attention](./examples/flash_attention/) -- [Flash Linear Attention](./examples/linear_attention/) -- [Flash MLA Decoding](./examples/deepseek_mla/) -- [Native Sparse Attention](./examples/deepseek_nsa/) +**Related docs** -Within the `examples` directory, you will also find additional complex kernels—such as convolutions, forward/backward passes for FlashAttention, more operators will continuously be added. +- **Architecture comparison:** [GPU vs. Tenstorrent (TT‑Metalium)](docs/tenstorrent/GPU_vs_Tenstorrent.md) +- **Kernel authoring patterns:** [GPU vs. Tenstorrent — Kernel Authoring](docs/tenstorrent/kernel_authoring_comparison.md) +--- -## Benchmark Summary - -TileLang achieves exceptional performance across a variety of computational patterns. Comprehensive benchmark scripts and settings are available at [tilelang-benchmark](https://github.com/tile-ai/tilelang-benchmark). Below are selected results showcasing its capabilities: +## Table of Contents + +- [Motivation](#motivation) +- [Background: Persistent Kernels & Tiles on Tenstorrent](#background-persistent-kernels--tiles-on-tenstorrent) +- [Key Idea: Grid‑to‑Persistent Mapping](#key-idea-grid-to-persistent-mapping) +- [User‑Facing Annotations](#user-facing-annotations) + - [Static Schedule Annotations](#static-schedule-annotations) + - [Sharding & Layout Annotations](#sharding--layout-annotations) + - [Defaults & Backward Compatibility](#defaults--backward-compatibility) +- [End‑to‑End Examples](#end-to-end-examples) + - [GEMM (no annotations → defaults)](#gemm-no-annotations--defaults) + - [Attention (with schedule & layout hints)](#attention-with-schedule--layout-hints) +- [Compiler & Codegen Plan (TVM/TileLang)](#compiler--codegen-plan-tvmtilelang) + - [Phase 0 — MVP (GEMM, Elementwise)](#phase-0--mvp-gemm-elementwise) + - [Phase 1 — SDPA, Dequant‑GEMM, Reuse/Multicast](#phase-1--sdpa-dequant-gemm-reusemulticast) + - [Phase 2 — Ergonomics, Safety, Diagnostics](#phase-2--ergonomics-safety-diagnostics) +- [Runtime Integration & Build](#runtime-integration--build) +- [Developer Workflow & Repository Layout](#developer-workflow--repository-layout) +- [Risks & Mitigations](#risks--mitigations) +- [Call for Contributions](#call-for-contributions) +- [Appendix](#appendix) + - [Why the Defaults Are Safe](#why-the-defaults-are-safe) + - [Attribute & API Sketch](#attribute--api-sketch) + - [Open Questions](#open-questions) + - [License](#license) -- MLA Decoding Performance on H100 +--- -
-
- mla decode performance bs64 on H100 -
-
- mla decode performance bs128 on H100 -
-
- -- Flash Attention Performance on H100 +## Motivation -
operator performance on H100 -
+- **Tenstorrent’s execution model is persistent**: each selected core runs a long‑lived kernel and iterates over a statically assigned set of **tiles** (typically 32×32 elements), while dedicated reader/compute/writer stages move tiles between **DRAM ↔ L1** and perform compute. +- **TileLang already supports GPU‑style grid kernels** (`bx, by`) and layout hints. We propose a backend that **automatically converts grid kernels into persistent TT kernels** by generating an **outer per‑core scheduler loop** inside the compute kernel. +- Users keep writing **grid‑style** kernels. When targeting TT, the backend injects a static, per‑core loop that visits the blocks (tiles) assigned to that core. Optional **annotations** let users choose the static schedule and **TT sharding/layout**. **Sane defaults** ensure most GPU‑style kernels “just work”. -- Matmul Performance on GPUs (RTX 4090, A100, H100, MI300X) +--- -
- gemm fp16 performance on Gpus -
+## Background: Persistent Kernels & Tiles on Tenstorrent -- Dequantize Matmul Performance on A100 +- **Static partitioning:** The host partitions the global tile space into per‑core chunks (e.g., `(start_id, count)`), then launches one persistent kernel per participating core. +- **Tiles:** Compute operates on **tile‑formatted** tensors (e.g., 32×32). Tiles may **reside in DRAM**; reader kernels stream tiles into L1 circular buffers; compute kernels consume them; writer kernels commit results back to DRAM. +- **Program model:** A host **Program** creates kernels on a **CoreRange / CoreRangeSet**, wires circular buffers, sets runtime args, and enqueues work. -
- dequantize gemv performance on A100 -
+--- -## Installation -### Method 1: Install with Pip +## Key Idea: Grid‑to‑Persistent Mapping -The quickest way to get started is to install the latest release from PyPI: +**Write once (GPU‑style) in TileLang:** -```bash -pip install tilelang +```python +with T.Kernel(grid_x=Nt, grid_y=Mt, threads=(...)) as (bx, by): + compute_one_block(bx, by) # body indexes by bx/by; no TT specifics ``` -Alternatively, you can install directly from the GitHub repository: +**Generated for TT (inside the compute kernel):** -```bash -pip install git+https://github.com/tile-ai/tilelang +```cpp +// Runtime args per core: start_id, count, grid_x (Nt), grid_y (Mt), etc. +for (uint32_t i = 0; i < count; ++i) { // persistent outer loop + uint32_t tid = start_id + i; // row-major block id + uint32_t by = tid / grid_x; // recover (bx, by) + uint32_t bx = tid % grid_x; + compute_one_block(bx, by); // same inner body as GPU-style kernel +} ``` -Or install locally: +This preserves the developer’s **grid mental model** while embracing TT’s **persistent, statically scheduled** execution. + +--- + +## User‑Facing Annotations -```bash -# install required system dependencies -sudo apt-get update -sudo apt-get install -y python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev +### Static Schedule Annotations -pip install -e . -v # remove -e option if you don't want to install in editable mode, -v for verbose output +Control how the global 2‑D block grid (`grid_x × grid_y`) is **partitioned across cores** and iterated **inside** the per‑core outer loop. + +```python +T.annotate_tt_schedule( + policy="contiguous", # "contiguous" | "strided" | "rect" + order="row_major", # "row_major" | "block_linear(k)" + rect=(by0, bx0, H, W), # for policy="rect" + stride=(first, step), # for policy="strided" + chunk_k_tiles=None, # optional: K-panel chunking for GEMM + qk_chunk_tiles=None, # optional: K/V chunking for Attention +) ``` -### Method 2: Build from Source -We currently provide three ways to install **tile-lang** from source: - - [Install from Source (using your own TVM installation)](./docs/get_started/Installation.md#method-1-install-from-source-using-your-own-tvm-installation) - - [Install from Source (using the bundled TVM submodule)](./docs/get_started/Installation.md#method-2-install-from-source-using-the-bundled-tvm-submodule) - - [Install Using the Provided Script](./docs/get_started/Installation.md#method-3-install-using-the-provided-script) +- **contiguous** (default): even, contiguous intervals `(start_id, count)` per core. +- **strided**: `tid = first + n*step` sequence per core; useful for load balancing irregular blocks. +- **rect**: assign **rectangles** of blocks to cores/groups; pairs well with reuse/multicast. +- **order**: default `row_major`, with optional `block_linear(k)` for cache/NoC locality. +- **chunk knobs**: feed into reader/compute loops (e.g., `Kt` for GEMM, `Sk` chunks for SDPA). -### Method 3: Install with Nightly Version +### Sharding & Layout Annotations -For users who want access to the latest features and improvements before official releases, we provide nightly builds of **tile-lang**. +Describe how tensors are **tilized**, **sharded across cores**, and **placed** (DRAM/L1). Extends TileLang’s layout hints with **TT‑specific sharding**. -```bash -pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu121/ -# or pip install tilelang --find-links https://tile-ai.github.io/whl/nightly/cu121/ +```python +T.annotate_tt_sharding({ + A: T.TTShard(axis=0, tiles=("rows", 32), placement="DRAM", + order="row_major", faces="16x16"), + B: T.TTShard(axis=1, tiles=("cols", 32), placement="DRAM", + order="row_major"), + C: T.TTShard(axis=(0, 1), tiles=("rows","cols", 32), placement="DRAM"), +}) ``` -> **Note:** Nightly builds contain the most recent code changes but may be less stable than official releases. They're ideal for testing new features or if you need a specific bugfix that hasn't been released yet. +- **axis**: which dimension(s) are sharded into tiles across cores. +- **tiles**: 32×32 by default; dtype determines bytes per tile. +- **placement**: `"DRAM"` for persistent tensors; temporaries use **L1** circular buffers automatically. +- **order** / **faces**: row/col tile orders; optional faces/packing hints if needed. + +### Defaults & Backward Compatibility -## Quick Start +If **no annotations** are provided: -In this section, you'll learn how to write and execute a straightforward GEMM (matrix multiplication) kernel using tile-lang, followed by techniques for layout optimizations, pipelining, and L2-cache–friendly swizzling. +- **Schedule default:** `policy="contiguous"`, `order="row_major"`. +- **Layout default:** **row‑major 32×32 DRAM tilization**; L1 CBs are synthesized around `T.copy` sites. +- Result: **existing GPU‑style kernels run unchanged** on TT (subject to tile padding rules). -### GEMM Example with Annotations (Layout, L2 Cache Swizzling, and Pipelining, etc.) +--- -Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware. +## End‑to‑End Examples + +### GEMM (no annotations → defaults) ```python -import tilelang import tilelang.language as T +BLOCK = 32 + +@T.prim_func +def gemm(A: T.Buffer((M, K), "bf16"), + B: T.Buffer((K, N), "bf16"), + C: T.Buffer((M, N), "bf16")): + Mt, Nt, Kt = T.ceildiv(M, BLOCK), T.ceildiv(N, BLOCK), T.ceildiv(K, BLOCK) + with T.Kernel(grid_x=Nt, grid_y=Mt, threads=(32, 4)) as (bx, by): + i0, j0 = by * BLOCK, bx * BLOCK + Cacc = T.alloc_fragment((BLOCK, BLOCK), "bf16"); T.fill(Cacc, 0) + for kk in range(Kt): + Ablk = T.alloc_shared((BLOCK, BLOCK), "bf16") + Bblk = T.alloc_shared((BLOCK, BLOCK), "bf16") + T.copy(T.region(A[i0, kk*BLOCK], "r", BLOCK, BLOCK), Ablk) + T.copy(T.region(B[kk*BLOCK, j0], "r", BLOCK, BLOCK), Bblk) + T.gemm(Ablk, Bblk, Cacc) + T.copy(Cacc, T.region(C[i0, j0], "w", BLOCK, BLOCK)) +``` + +**TT mapping generated by backend:** + +- Per core runtime args `(start_id, count, grid_x=Nt, grid_y=Mt, Kt, …)`. +- Compute kernel outer loop iterates `i in [0..count)` and recovers `(bx,by)` from `start_id+i`. +- Reader/Writer kernels move DRAM tiles to/from L1 CBs; compute kernel calls TT tile primitives in the K‑panel loop. + +### Attention (with schedule & layout hints) + +```python +# Schedule & layout annotations (optional – can be omitted) +T.annotate_tt_schedule(policy="contiguous", order="row_major", qk_chunk_tiles=16) +T.annotate_tt_sharding({ + Q: T.TTShard(axis=0, tiles=("rows",32), placement="DRAM"), + K: T.TTShard(axis=0, tiles=("rows",32), placement="DRAM"), + V: T.TTShard(axis=0, tiles=("rows",32), placement="DRAM"), + O: T.TTShard(axis=0, tiles=("rows",32), placement="DRAM"), +}) + +@T.prim_func +def sdpa(Q, K, V, O, scale: T.float32, causal: T.int32): + Sq_t = T.ceildiv(Sq, 32) # Q tiles + Sk_t = T.ceildiv(Sk, 32) # K/V tiles + BH = B * H # fused batch×heads + + # grid = (Sq_t, BH); bx = q-tile, by = (b,h) + with T.Kernel(grid_x=Sq_t, grid_y=BH, threads=(...)) as (bx, by): + # streaming softmax state for (by, bx) + for k0 in range(0, Sk_t, 16): # comes from qk_chunk_tiles + # read Q(bx), K/V(k0 : k0+chunk) + # scores = Q @ K^T (tile GEMMs) → update (m,l) + # O(bx) += P @ V + # write O(bx) +``` + +**TT mapping generated by backend:** + +- Outer per‑core loop over `tid in [start_id, start_id+count)`, with `by = tid / grid_x`, `bx = tid % grid_x`. +- Reader streams K/V in chunks (`qk_chunk_tiles`), compute updates streaming softmax, writer stores outputs. + +--- + +## Compiler & Codegen Plan (TVM/TileLang) -# @tilelang.jit(target="cuda") -# target currently can be "cuda" or "hip" or "cpu". -# if not specified, it will be inferred from the input tensors during compile time -@tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +> We integrate via TVM’s **BYOC** (external codegen), keeping the TT backend cleanly modular. - @T.prim_func - def matmul_relu_kernel( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), - ): - # Initialize Kernel Context - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) +### Phase 0 — MVP (GEMM, Elementwise) - # Enable rasterization for better L2 cache locality (Optional) - # T.use_swizzle(panel_size=10, enable=True) +1. **`GridToPersistentTT` (new pass)** + - **In:** TIR/TileLang PrimFunc using `T.Kernel(grid_x, grid_y)` and `bx/by`. + - **Out:** Function wrapped in a **per‑core outer loop** driven by the selected schedule. + - **Spec:** + - Compute `total = grid_x * grid_y`; materialize policy = contiguous/strided/rect. + - Replace `bx/by` with expressions of `tid` recovered inside the loop. + - Attach PrimFunc attrs: + - `tt.grid = (grid_y, grid_x)` + - `tt.schedule = {policy, order, rect?, stride?, chunk_k_tiles?, qk_chunk_tiles?}` + - `tt.runtime_args = ["start_id","count", …]` + - **Error cases:** missing `grid_x/grid_y`; unsupported nest shapes; negative extents. - # Clear local accumulation - T.clear(C_local) +2. **`TTShardToCoreMap` (new pass)** + - **In:** TT sharding/layout annotations. + - **Out:** Concrete **CoreRangeSet** and per‑tensor sharding metadata. + - **Spec:** + - Translate high‑level `TTShard` into `(axis, tilization, order, placement)` + core ranges. + - Attach `tt.core_ranges`, `tt.shards` to buffers/PrimFunc. + - **Error cases:** non‑tile‑multiple shapes (defer to `TilePadTT`), inconsistent placements. - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - # Copy tile of A - # This is a sugar syntax for parallelized copy - T.copy(A[by * block_M, ko * block_K], A_shared) +3. **`TilePadTT` (new pass)** + - **In:** Tensors with extents not multiple of 32 on tiled axes. + - **Out:** Insert pad/unpad around compute or request zero‑fill tails in readers/writers. + - **Spec:** dtype‑aware tile bytes; optionally fuse pad into reader; mark effective shape. - # Copy tile of B - T.copy(B[ko * block_K, bx * block_N], B_shared) +4. **`MemorySpaceLowerTT` (new pass)** + - **In:** TIR with `T.copy` & shared/fragment allocations. + - **Out:** Explicit **DRAM↔L1** moves, **circular buffer** descriptors, syncs. + - **Spec:** + - Map `T.alloc_shared` → L1 CB segments; compute depths from schedule/chunk knobs. + - Lower copies to reader/writer enqueue ops; add attrs `tt.cb.{depth,format,bytes}`. - # Perform a tile-level GEMM on the shared buffers - # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs - T.gemm(A_shared, B_shared, C_local) - - # relu - for i, j in T.Parallel(block_M, block_N): - C_local[i, j] = T.max(C_local[i, j], 0) +5. **`TensorizeTT` (new pass)** + - **In:** Canonical tile GEMM/epilogue patterns. + - **Out:** Calls to TT tile micro‑kernels (e.g., `matmul_tiles`). + - **Spec:** pattern match, replace with intrinsic calls; keep fallbacks if not matched. - # Copy result back to global memory - T.copy(C_local, C[by * block_M, bx * block_N]) +6. **`EmitTTKernels` (codegen)** + - **Out:** + - **Compute kernel** C++ source with the generated **outer scheduler loop** + intrinsic calls. + - **Reader/Writer kernels** C++ sources with DRAM address math from `(bx,by)` or rectangles. + - **Host stub** that builds the Program, creates kernels on **CoreRange/CoreRangeSet**, allocates CBs, sets **runtime args** (`start_id`, `count`, `grid`, `Kt`/chunk), and enqueues. - return matmul_relu_kernel +7. **Runtime glue** + - Produce a `tvm.runtime.Module` that compiles the host stub and kernels, resolves TT‑Metalium SDK, and exposes a callable `run(...)`. + - CMake guards: `-DTL_TT_BACKEND=ON`, `TT_METAL_HOME` discovery; non‑TT builds remain unaffected. +### Phase 1 — SDPA, Dequant‑GEMM, Reuse/Multicast -M = 1024 # M = T.symbolic("m") if you want to use dynamic shape -N = 1024 -K = 1024 -block_M = 128 -block_N = 128 -block_K = 32 +8. **`SDPAFusionTT` (new pass)** + - Fuse `Q·Kᵀ → softmax → P·V` into a streaming loop over **(B×H, Q tiles)** with **K‑chunking**. + - Emit per‑core persistent outer loop; map `qk_chunk_tiles` into reader/compute loops. -# 1. Define the kernel (matmul) and compile/lower it into an executable module -matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) +9. **`TTMulticastReuse` (opt pass)** + - Where layout implies neighbor reuse (A/B in GEMM, Q or K in SDPA), introduce sender/receiver ranges and multicast paths; synthesize variant readers/writers per range. -# 3. Test the kernel in Python with PyTorch data -import torch +10. **`RasterizationTT` (opt pass)** + - Switch `tid → (by,bx)` mapping to `block_linear(k)` or other locality‑aware orders. -# Create random input tensors on the GPU -a = torch.randn(M, K, device="cuda", dtype=torch.float16) -b = torch.randn(K, N, device="cuda", dtype=torch.float16) -c = torch.empty(M, N, device="cuda", dtype=torch.float16) +### Phase 2 — Ergonomics, Safety, Diagnostics -# Run the kernel through the Profiler -matmul_relu_kernel(a, b, c) +11. **Legalize & Guards** + - Insert masks/tails where partial tiles are unavoidable; fall back to scalar or smaller vectors. -print(c) -# Reference multiplication using PyTorch -ref_c = torch.relu(a @ b) +12. **Diagnostics** + - Validate shard/schedule feasibility; emit actionable errors. + - Dump `tt.plan.json` containing per‑core `(start_id, count)` or rectangle maps for inspection. -# Validate correctness -torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) -print("Kernel output matches PyTorch reference.") +--- + +## Runtime Integration & Build -# 4. Retrieve and inspect the generated CUDA source (optional) -# cuda_source = jit_kernel.get_kernel_source() -# print("Generated CUDA kernel:\n", cuda_source) +- Integrate as a **BYOC external codegen** module (e.g., `tilelang_tt`) with clean boundaries. +- Build only when `TL_TT_BACKEND=ON` and TT SDK is discoverable. +- Provide a **“dry‑run”** mode that emits the host/kernel sources and `tt.plan.json` without executing (useful for CI without hardware). + +--- -# 5.Profile latency with kernel -profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) +## Developer Workflow & Repository Layout -latency = profiler.do_bench() +**Phase 1 (public fork):** start at `tile-ai/tilelang-tt` (or similar) -print(f"Latency: {latency} ms") ``` +tilelang-tt/ +├─ python/tilelang_tt/annotations.py # annotate_tt_schedule / annotate_tt_sharding +├─ src/tt/passes/*.cc # GridToPersistentTT, TTShardToCoreMap, ... +├─ src/tt/codegen/*.cc # EmitTTKernels + host stubs +├─ include/tilelang_tt/*.h +├─ cmake/TTMetal.cmake # SDK discovery +├─ tests/tt/*.py # compile-only & dry-run tests +└─ docs/ # design notes, tt.plan.json examples +``` + +- Keep **vendor SDK deps** behind CMake options; never block other backends. +- Land **Phase 0** (GEMM) with compile‑time tests and at least one **hardware smoke test**. +- Publish **design docs** and **plans** per pass; keep PRs small and reviewable. + +**Phase 2 (upstream):** open a TileLang **RFC PR** to integrate as an official backend once: +- CI is green (build‑only + optional HIL), +- the API surface (annotations & attrs) is stable, +- core operators (GEMM, elementwise) and at least one **attention** path are in. -### Dive Deep into TileLang Beyond GEMM +--- + +## Risks & Mitigations -In addition to GEMM, we provide a variety of examples to showcase the versatility and power of TileLang, including: +| Risk | Mitigation | +|---|---| +| Shapes not multiple of tile size | `TilePadTT` + reader/writer tails; clear diagnostics. | +| Backend drift / SDK changes | Version‑gated CMake; isolate TT APIs in one module. | +| CI without TT hardware | “Dry‑run” build that prints generated sources + `tt.plan.json`. | +| Over‑eager tensorization | Keep fallbacks; allow `--disable-tt-tensorize` for debugging. | + +--- -- [Dequantize GEMM](./examples/dequantize_gemm/): Achieve high-performance dequantization by **fine-grained control over per-thread operations**, with many features now adopted as default behaviors in [BitBLAS](https://github.com/microsoft/BitBLAS), which utilizing magic layout transformation and intrins to accelerate dequantize gemm. -- [FlashAttention](./examples/flash_attention/): Enable cross-operator fusion with simple and intuitive syntax, and we also provide an example of auto tuning. -- [LinearAttention](./examples/linear_attention/): Examples include RetNet and Mamba implementations. -- [Convolution](./examples/convolution/): Implementations of Convolution with IM2Col. +## Call for Contributions -## Upcoming Features +We’re looking for collaborators in these areas: -Check our [tilelang v0.2.0 release plan](https://github.com/tile-ai/tilelang/issues/79) for upcoming features. +- **Pass implementation:** `GridToPersistentTT`, `MemorySpaceLowerTT`, `TensorizeTT`. +- **Kernel stencils:** robust **reader / compute / writer** templates for GEMM & SDPA. +- **Sharding heuristics:** sensible defaults for **CoreRangeSet** selection per device. +- **Testing:** correctness (NumPy/PyTorch refs), perf baselines, CI scaffold (dry‑run + optional HIL). +- **Docs & examples:** dequant‑GEMM, Flash/MLA‑style attention with `qk_chunk_tiles`. + +Please open issues/PRs in the fork and tag **`area:tt-backend`**. Include hardware/driver details where relevant. --- -TileLang has now been used in project [BitBLAS](https://github.com/microsoft/BitBLAS) and [AttentionEngine](https://github.com/microsoft/AttentionEngine). +## Appendix + +### Why the Defaults Are Safe + +- **Schedule:** `contiguous + row_major` matches the standard static split used in multi‑core matmul tutorials—each core gets a contiguous range of tile IDs. +- **Layout:** **Row‑major 32×32 tilization in DRAM** aligns with TT’s common tile format; L1 circular buffers are synthesized automatically around copy sites. -## Join the Discussion +### Attribute & API Sketch -Welcome to join our Discord community for discussions, support, and collaboration! +**Python (user annotations)** -[![Join our Discord](https://img.shields.io/badge/Discord-Join%20Us-blue?logo=discord&style=for-the-badge)](https://discord.gg/TUrHyJnKPG) +```python +# Scheduling +T.annotate_tt_schedule(policy="contiguous", + order="row_major", + rect=None, + stride=None, + chunk_k_tiles=None, + qk_chunk_tiles=None) + +# Sharding / layout +T.annotate_tt_sharding({ + TensorA: T.TTShard(axis=0, tiles=("rows", 32), placement="DRAM"), + TensorB: T.TTShard(axis=1, tiles=("cols", 32), placement="DRAM"), +}) +``` + +**PrimFunc / Buffer attributes (internal)** + +```text +tt.grid = (grid_y, grid_x) +tt.schedule = { policy, order, rect?, stride?, chunk_k_tiles?, qk_chunk_tiles? } +tt.core_ranges = CoreRangeSet(...) +tt.shards = { buffer_name: { axis, tiles, placement, order, faces? } } +tt.runtime_args = ["start_id","count", ...] +tt.cb = { name: { depth, format, l1_bytes } } +``` + +**`tt.plan.json` (debug dump)** + +```json +{ + "grid": [Mt, Nt], + "policy": "contiguous", + "mapping": [ + {"core": [y,x], "start_id": 0, "count": 128}, + {"core": [y,x], "start_id": 128,"count": 128} + ] +} +``` + +### Open Questions + +- Do we expose **CoreRangeSet selection** in Python, or compute it from sharding and device defaults? +- Preferred **default CB depths** per op and dtype? (derive from chunk sizes?) +- How soon to enable **multicast / reuse** by default for attention/GEMM rectangles? +- Which **TT devices** and SDK versions to qualify first (e.g., Wormhole/Blackhole)? + +### License + +This backend will be contributed under the same license as TileLang. Vendor SDK headers/libraries remain under their respective licenses. + +--- -## Acknowledgements +**Next steps:** +- Create the public fork, land **Phase 0** (GEMM) with compile‑time CI + optional hardware smoke tests. +- Iterate on annotations/spec, then open an upstream **RFC PR** to integrate as an official backend. -We would like to express our gratitude to the [TVM](https://github.com/apache/tvm) community for their invaluable contributions. The initial version of this project was mainly developed by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410) with supervision from Prof. [Zhi Yang](https://yangzhihome.github.io) at Peking University. Part of this work was carried out during an internship at Microsoft Research, where Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang offered valuable advice and support. We deeply appreciate their mentorship and contributions. diff --git a/docs/tenstorrent/CI.md b/docs/tenstorrent/CI.md new file mode 100644 index 000000000..f6a378f6a --- /dev/null +++ b/docs/tenstorrent/CI.md @@ -0,0 +1,168 @@ +# Tenstorrent Backend CI + +This document describes the CI setup for the Tenstorrent backend in TileLang. + +## Overview + +The Tenstorrent backend CI is defined in `.github/workflows/tenstorrent-ci.yml` and runs on: +- Pull requests that modify Tenstorrent-related files +- Pushes to `main` and `ws1-**` branches + +## Jobs + +### 1. Lint and Format (`lint-and-format`) + +**Environment:** Ubuntu runner with Python 3.10 + +**Purpose:** Ensure code formatting and style consistency + +**Steps:** +1. Checkout repository with submodules +2. Set up Python with pip caching (caches `requirements-lint.txt` dependencies) +3. Install lint dependencies: yapf, ruff, codespell, clang-format +4. Run `format.sh` to check formatting compliance + - If formatting issues are found, the job fails and shows the diff + +**Caching:** +- Pip packages are cached based on `requirements-lint.txt` hash +- Subsequent runs with unchanged dependencies skip pip installation + +### 2. Build and Test (`build-and-test`) + +**Environment:** Ubuntu runner with Python 3.10 + +**Purpose:** Build TileLang with LLVM backend and run Tenstorrent tests + +**Note:** Currently builds with LLVM backend (not CUDA) since we only run CPU tests at this stage. This keeps the CI lightweight and fast. GPU/CUDA testing will be added in future when needed. + +**Steps:** +1. Checkout repository with submodules +2. Set up Python with pip caching (caches `requirements-test.txt` dependencies) +3. Install system dependencies: cmake, ninja, llvm, build-essential, libedit-dev, libxml2-dev, zlib1g-dev +4. Install Python dependencies from requirements-test.txt +5. **Enable ccache:** + - Uses `hendrikmuhs/ccache-action` for compiler caching + - Cache key based on CMakeLists.txt hash + OS + version + - Max size: 2G + - Creates symlinks for automatic use by CMake +6. **TVM Build Caching:** + - Generate cache key based on TVM submodule commit hash + - Restore cached TVM build artifacts if available (uses `actions/cache/restore@v4`) + - Caches: `build/tvm/` (contains libtvm*.so), `build/libtilelang*.so`, and `build/3rdparty/` + - Save TVM artifacts after build completes (uses `actions/cache/save@v4` with `if: always()`) + - Cache is saved even if job fails, preventing redundant TVM rebuilds + - Only rebuilds TVM when the submodule is updated +7. Build TileLang with LLVM backend (ccache-enabled) + - Uses Ninja build system with ccache as compiler launcher + - Limited to 2 parallel jobs to avoid OOM on GitHub runners + - LLVM backend is sufficient for CPU-only testing + - Uses system LLVM packages instead of downloading LLVM 10.0.1 +8. Install TileLang and TVM Python packages + - Install TVM Python package from `3rdparty/tvm/python` with `TVM_LIBRARY_PATH` set + - Install TileLang with `USE_LLVM=true` to enable LLVM backend + - setup.py checks for nvcc availability before trying to use it + - Gracefully skips CUDA version detection if nvcc is not found +9. Print ccache statistics (with availability check) +10. Run Tenstorrent target registration tests + - Sets `LD_LIBRARY_PATH` to include `build/tvm` for TVM library discovery + - Continue-on-error enabled for graceful handling +11. Run all Tenstorrent Python tests (CPU-only) + - Sets `LD_LIBRARY_PATH` to include `build/tvm` for TVM library discovery + - Continue-on-error enabled for graceful handling + +**Caching Strategy:** +- **ccache (compiler cache):** Keyed by CMakeLists.txt hash + OS + version + - Caches compiled object files for fast recompilation + - 2G maximum size +- **TVM build artifacts:** Keyed by TVM submodule commit + OS + - Dramatically reduces build time (TVM build is expensive) + - Only invalidates when TVM submodule is updated + - Saved even on job failure to prevent rebuilding on retry +- **Pip packages:** Keyed by requirements-test.txt hash + - Reuses cached pytest and other test dependencies + +### 3. Static Analysis (`static-analysis`) + +**Environment:** Ubuntu runner with Python 3.10 + +**Purpose:** Type checking with mypy + +**Steps:** +1. Checkout repository +2. Set up Python with pip caching (caches `requirements-mypy.txt` dependencies) +3. Install mypy from requirements-mypy.txt +4. Run mypy on `tilelang/engine/tt/` (currently set to continue-on-error) + +**Caching:** +- Pip packages are cached based on `requirements-mypy.txt` hash +- Ensures consistent caching behavior across CI runs + +## Caching Summary + +The CI uses multiple layers of caching for efficiency: + +| Job | What's Cached | Cache Key | Benefit | +|-----|---------------|-----------|---------| +| lint-and-format | Pip packages | requirements-lint.txt hash | Fast linter installation | +| build-and-test | TVM build artifacts | TVM submodule commit + OS | Avoid rebuilding TVM (~30+ min), saved even on failure | +| build-and-test | ccache compiler cache | CMakeLists.txt hash + OS + version | Fast recompilation of unchanged files | +| build-and-test | Pip packages | requirements-test.txt hash | Fast pytest install | +| static-analysis | Pip packages | requirements-mypy.txt hash | Fast mypy installation | + +## Running Locally + +To ensure your changes will pass CI: + +```bash +# Run formatting checks +bash format.sh + +# If format.sh makes changes, review and commit them +git diff +git add . +git commit -m "Apply formatting" + +# Run tests (requires TileLang built with TVM) +cd testing/python/tt +pytest test_target_registration.py -v +``` + +## Triggering CI + +CI runs automatically on: +- Pull requests modifying: + - `tilelang/engine/tt/**` + - `testing/python/tt/**` + - `tilelang/utils/target.py` + - `.github/workflows/tenstorrent-ci.yml` +- Pushes to `main` or `ws1-**` branches + +## Performance Notes + +- **First run:** ~6-7 minutes (builds TVM from scratch with ccache) +- **Subsequent runs (TVM cache hit):** ~30-60 seconds (skips TVM build, uses ccache for incremental builds) +- **Cache storage:** GitHub Actions provides up to 10GB cache per repository +- **Cache eviction:** GitHub evicts caches not accessed in 7 days +- **ccache effectiveness:** Dramatically reduces compilation time for unchanged files +- **TVM cache effectiveness:** Eliminates ~5-6 minutes of TVM rebuild when submodule unchanged + +## Key Design Decisions + +1. **System LLVM vs Downloaded LLVM:** Uses system LLVM packages (installed via apt) instead of downloading LLVM 10.0.1. This avoids compatibility issues with newer Ubuntu versions, which do not include `libtinfo.so.5` by default—causing runtime linking errors when using the downloaded LLVM 10.0.1 binaries. + +2. **Separate TVM Python Installation:** TVM Python package is installed separately before TileLang to ensure proper library path configuration. + +3. **LD_LIBRARY_PATH for Tests:** Tests require `LD_LIBRARY_PATH` to be set to `build/tvm` so Python can find the TVM shared libraries at runtime. + +4. **Cache Split (Restore/Save):** Using separate `actions/cache/restore` and `actions/cache/save` with `if: always()` ensures TVM cache is saved even when the job fails, preventing redundant rebuilds on retry. + +5. **Continue-on-error for Tests:** Tests are marked with `continue-on-error: true` because the Tenstorrent target registration in TVM is incomplete. As a result, tests are expected to fail until the backend implementation and target registration are finished. + +## Future Improvements + +Potential optimizations: +- Add CUDA build and GPU testing when needed (will require NVIDIA container or GPU runners) +- Custom Docker image with pre-built TVM (eliminates TVM build entirely) +- Parallel test execution with pytest-xdist +- Separate workflow for expensive builds (only on main/release branches) +- Remove continue-on-error once Tenstorrent backend is fully implemented diff --git a/docs/tenstorrent/GPU_vs_Tenstorrent.md b/docs/tenstorrent/GPU_vs_Tenstorrent.md new file mode 100644 index 000000000..58a35fdd1 --- /dev/null +++ b/docs/tenstorrent/GPU_vs_Tenstorrent.md @@ -0,0 +1,43 @@ +# GPU (CUDA‑style) vs. Tenstorrent (TT‑Metalium) +**Architecture & Programming Model Comparison** + +This document contrasts mainstream **GPU** execution (CUDA‑style) with **Tenstorrent**’s **TT‑Metalium** programming model. It focuses on how work is assigned, how kernels are executed, how memory/staging works, and what primitives you use to build high‑performance pipelines. + +> Reference for TT concepts: the public **Metalium Guide** — https://github.com/tenstorrent/tt-metal/blob/main/METALIUM_GUIDE.md + +--- + +## At‑a‑Glance Comparison + +| Aspect | GPU (CUDA‑style) | Tenstorrent (TT‑Metalium) | +|---|---|---| +| **Execution unit** | **Streaming Multiprocessor (SM)** executes many thread blocks; each block has warps/threads. | **Tensix core** executes a **persistent kernel**; per‑core RISC‑V controllers orchestrate **reader / compute / writer** roles. | +| **Work assignment / scheduling** | Launch many blocks; a **hardware scheduler** dynamically assigns blocks to SMs as resources free up (oversubscription is common). | **Static** partitioning on host: each core receives a fixed subset of **tile IDs** and loops over them (**no dynamic scheduler / oversubscription**). | +| **Kernel lifetime on a core** | Short‑lived blocks; SMs pick up new blocks dynamically. | **Long‑lived (“persistent”)** kernels per core; explicit outer loop over assigned work (`start_id`, `count`). | +| **Indexing / grid model** | Kernel body written against `blockIdx.{x,y,z}`, `threadIdx.{x,y,z}`; hardware provides block/thread coordinates. | Keep a **grid‑style body** (e.g., `bx, by`). For TT, a **generated outer loop** recovers `(bx, by)` from linear tile IDs assigned to that core. | +| **Granularity of compute data** | Software‑chosen tiles (e.g., 16×16, 32×32) for blocking; no mandatory DRAM tile format. | **Native 32×32 tiles**; tensors are **tilized** in DRAM to that format; cores operate tile‑by‑tile. | +| **On‑chip scratchpad** | **Shared memory** (per‑SM, user‑managed, banked). | **L1 SRAM** per core, exposed via **Circular Buffers (CBs)** created/configured by the program. | +| **Global↔on‑chip staging** | In‑kernel copies (global→shared), often with `cp.async` to overlap copy/compute. | Separate **reader kernel** streams **DRAM tiles → L1 CBs**; **compute kernel** consumes/produces tiles; **writer kernel** drains **L1 → DRAM**. | +| **Software pipelining** | `cp.async` groups, double/multi‑buffering in shared memory within a thread block. | **CB depth** (double/multi‑buffer) + split reader/compute/writer gives pipeline overlap across DRAM/L1/compute. | +| **Synchronization (intra‑block/core)** | `__syncthreads()`, warp sync primitives. | CB read/write indices, semaphores, and per‑kernel roles coordinate producer/consumer within a core. | +| **Work distribution helpers** | Streams, graphs, cooperative groups; no API to pin blocks to specific SMs. | **CoreRange / CoreRangeSet** select participating cores; helpers (e.g., “split work to cores”) compute `(start_id, count)` per core. | +| **Oversubscription / load balance** | Yes: typically launch more blocks than SMs; hardware balances dynamically. | **No oversubscription**; balance is achieved by the **static partition** of tile ranges across cores. | +| **Core‑to‑core data movement** | Usually via global/L2; limited cross‑SM features (vendor‑specific). | **On‑chip NoC** with **multicast** (e.g., write once, deliver to a rectangle of cores) for reuse patterns. | +| **Memory model (where data lives)** | Global (device DRAM), L2, shared memory, registers. | **Tiles can live in DRAM** (tilized); **L1 CBs** hold working tiles; results written back to DRAM. | +| **Typical kernel structure** | Single kernel with load→shared, compute, store; block‑local cooperation for tiling. | **Program** with **three kernels per core** (reader / compute / writer) + persistent outer loop over assigned tiles. | +| **Performance knobs** | `num_warps`, `num_stages`, shared memory size, tile shapes, vectorization, occupancy. | **CB count/depth**, tile chunk sizes (e.g., K/V chunks), **per‑core work partition** (contiguous/strided/rectangles), multicast/reuse topology. | +| **“Grid‑style” portability** | Natural, since hardware schedules blocks dynamically. | Supported by **codegen**: keep `bx,by` in the body; TT backend wraps it with a generated **static scheduler loop** per core. | + +--- + +## Practical Notes + +- **Persistent kernels on TT** mean you’ll decide **ahead of time** which set of tiles each core owns and then implement a **top‑level loop** inside the compute kernel to walk those tiles. This replaces the GPU’s hardware block scheduler. +- **Circular Buffers (CBs)** in **L1** are the central mechanism for **software pipelining**: the reader fills, compute consumes, writer drains, with depths chosen for double or multi‑buffering. +- **Multicast** lets you feed identical tiles (e.g., shared activations/weights) to a **rectangle of cores**—useful for reuse and bandwidth efficiency. +- To port GPU‑style grid code, keep the kernel body expressed against `bx, by`; then generate a **per‑core outer loop** that maps a linear tile range `(start_id..start_id+count)` back to `(by, bx)`. + +--- + +### Reference +- Tenstorrent Metalium Guide: https://github.com/tenstorrent/tt-metal/blob/main/METALIUM_GUIDE.md diff --git a/docs/tenstorrent/kernel_authoring_comparison.md b/docs/tenstorrent/kernel_authoring_comparison.md new file mode 100644 index 000000000..93f4417d6 --- /dev/null +++ b/docs/tenstorrent/kernel_authoring_comparison.md @@ -0,0 +1,79 @@ +# GPU vs. Tenstorrent — Kernel Authoring Patterns + +This note complements the architecture comparison by focusing on **how you author kernels** for GPUs (CUDA‑style) vs. **Tenstorrent (TT‑Metalium)**, including how TileLang maps a grid‑style kernel body to TT’s persistent execution. + +--- + +## Side‑by‑Side (authoring perspective) + +| Topic | GPU (CUDA‑style / Triton) | Tenstorrent (TT‑Metalium) / TileLang→TT | +|---|---|---| +| **Launch unit** | `<<>>`: many **thread blocks**, each scheduled dynamically to an SM. | Host builds a **Program**, selects **CoreRange/CoreRangeSet**, and launches **persistent kernels** (reader / compute / writer) on each participating core. | +| **Kernel body indexers** | Use `blockIdx.{x,y,z}`, `threadIdx.{x,y,z}`; or high‑level schedules (Triton’s `tl.arange`, `num_warps/stages`). | **Author in TileLang with `bx,by`** (grid blocks). Backend generates a **per‑core outer loop** that recovers `(bx,by)` from a static tile list `(start_id, count)`. | +| **Thread‑level parallelism** | Warps/threads cooperate; shared memory tiling; `__syncthreads()`. | No SIMT threads. Compute engines operate on **tiles**; parallelism comes from **cores** and **pipelined CBs**. | +| **Persistent execution** | Optional (persistent‑threads pattern) but not required. | **Default**: per‑core **persistent loop** over assigned tiles. | +| **Data staging** | Global→**shared memory** via `cp.async`/TMA; compute from shared; write back to global. | **DRAM tiles → L1 Circular Buffers (CBs)** via **reader**; compute consumes from CBs; writer drains CBs → DRAM. | +| **SW pipelining** | `cp.async` stages; double/triple buffering within a block. | **CB depth** and **separate kernels** (reader/compute/writer) implement double/multi‑buffering. | +| **Block / tile size** | Chosen by developer; may align with tensor core fragment sizes. | **Native 32×32 tiles** (dtype‑dependent bytes). Pad edges or handle tails explicitly. | +| **Work distribution** | Implicit via hardware scheduler; limited control. | **Explicit**: choose **contiguous / strided / rectangular** carves; pass per‑core `(start_id, count)`; or generate rectangle loops. | +| **Multicast / reuse** | Generally via global/L2; SM‑to‑SM multicast is limited/specialized. | **Explicit multicast** over on‑chip NoC to **rectangles** of cores for A/B or Q/K reuse. | +| **Synchronization** | Barrier in a block (`__syncthreads()`); memory scopes for cp.async groups. | Synchronization implicit in **CB protocols** and kernel staging; reader/compute/writer coordination. | +| **Tuning knobs** | `num_warps`, `num_stages`, block size, shared‑mem footprint, occupancy. | **Core ranges**, **CB counts/depths**, **tile carve policy**, **chunk sizes** (e.g., `Kt`, `Sk`), multicast topology. | +| **Annotations (TileLang)** | Layout hints; vectorization, etc. | `annotate_tt_schedule(...)` for static schedule; `annotate_tt_sharding(...)` for tilization/sharding/placement. | +| **Fallbacks / libraries** | cuBLAS/cuDNN or Triton kernels. | TTNN ops or Metalium templates; untilize/tilize helpers. | +| **Debugging** | Nsight Systems/Compute; kernel printf. | Host logs, **plan dumps** (e.g., `tt.plan.json` in our backend), device traces. | + +--- + +## Code skeletons + +### GPU‑style (TileLang body written against `bx,by`) + +```python +import tilelang.language as T +BLOCK = 32 + +@T.prim_func +def gemm(A: T.Buffer((M, K), "bf16"), + B: T.Buffer((K, N), "bf16"), + C: T.Buffer((M, N), "bf16")): + Mt, Nt, Kt = T.ceildiv(M, BLOCK), T.ceildiv(N, BLOCK), T.ceildiv(K, BLOCK) + with T.Kernel(grid_x=Nt, grid_y=Mt, threads=(32, 4)) as (bx, by): + i0, j0 = by * BLOCK, bx * BLOCK + Cacc = T.alloc_fragment((BLOCK, BLOCK), "bf16"); T.fill(Cacc, 0) + for kk in range(Kt): + Ablk = T.alloc_shared((BLOCK, BLOCK), "bf16") + Bblk = T.alloc_shared((BLOCK, BLOCK), "bf16") + T.copy(T.region(A[i0, kk*BLOCK], "r", BLOCK, BLOCK), Ablk) + T.copy(T.region(B[kk*BLOCK, j0], "r", BLOCK, BLOCK), Bblk) + T.gemm(Ablk, Bblk, Cacc) + T.copy(Cacc, T.region(C[i0, j0], "w", BLOCK, BLOCK)) +``` + +### Tenstorrent compute stub (generated outer scheduler loop; pseudo‑C++) + +```cpp +// per-core runtime args: start_id, count, grid_x (Nt), Kt, etc. +for (uint32_t i = 0; i < count; ++i) { + uint32_t tid = start_id + i; + uint32_t by = tid / grid_x; + uint32_t bx = tid % grid_x; + + // Reader has already queued A(by,kk) and B(kk,bx) into L1 CBs. + for (uint32_t kk = 0; kk < Kt; ++kk) { + // tile GEMM primitive; indices derived from (by, bx, kk) + // ckernel::matmul_tiles(cb_a, cb_b, /*indices*/, dst_cb); + } + // Writer drains C tile from CB to DRAM. +} +``` + +--- + +## Adoption steps (TileLang → TT) + +1. **Keep your kernel body GPU‑style** (index by `bx,by`). +2. (Optional) Add `annotate_tt_schedule(...)` to pick **contiguous/strided/rectangular** carve or chunk sizes. +3. (Optional) Add `annotate_tt_sharding(...)` to specify DRAM tilization & sharding. +4. Let the backend generate the **per‑core outer loop** and **reader/compute/writer** pipeline; run. + diff --git a/docs/tenstorrent/project_1.md b/docs/tenstorrent/project_1.md new file mode 100644 index 000000000..50419db02 --- /dev/null +++ b/docs/tenstorrent/project_1.md @@ -0,0 +1,130 @@ +# Project Plan — TT Backend MVP (Matrix Multiplication Dry Run) + +## Scope & Goals +- Deliver the minimal Tenstorrent backend path that lowers a TileLang GEMM into Metalium-ready host/kernels without executing on hardware. +- Lock the MVP operating point to: contiguous per-core schedule, interleaved DRAM tilization via TensorAccessor, bf16 tensors, and no user-authored annotations. +- Provide a compile-only "dry run" that emits reader/compute/writer kernels, a host program stub, and scheduling metadata (`tt.plan.json`). +- Keep the existing CUDA/HIP/CPU backends untouched; the TT path activates only when the `tenstorrent` target is requested. + +## Default Operating Point Assumptions +- **Sharding/layout:** Use **DRAM interleaved tensors** backed by the TT-Metalium `TensorAccessor` (see `tt_metal/tt_metalium/runtime/tensor_accessor.*`). Each tile is 32×32 bf16; interleaving handles per-core striping without manual address swizzling. +- **Schedule:** Static, contiguous tile ranges per core (`policy="contiguous"`, `order="row_major"`). No K-panel chunking or multicast. +- **Kernels:** One reader, one compute, one writer kernel per active Tensix core using a depth-2 circular buffer pipeline. +- **Host runtime:** Generates runtime args `[start_id, count, grid_x, grid_y, kt_tiles]` for the compute kernel; reader/writer derive DRAM strides solely from interleaved descriptors. + +## Workstream 1 — Frontend Integration & Target Selection +**Outcome:** TileLang recognizes `target="tenstorrent"`, synthesizes default TT annotations, and routes the module through a TT-specific lowering stack. + +**Implementation** +- Extend `tilelang/utils/target.py` to register `"tenstorrent"` and skip auto-detection logic (explicit opt-in only). +- Add a target adapter in `tilelang/engine/lower.py` and `tilelang/engine/__init__.py` that dispatches to a `tilelang.engine.tt.lower` helper when the TT target is active; reuse existing CUDA/HIP branches otherwise. +- Introduce `python/tilelang_tt/target.py` with a small helper that stamps default TT schedule/sharding attrs (contiguous/interleaved) when user code omits them. +- Wire the helper into the standard lowering entry point (`tilelang/engine/lower.lower`) right after target determination. + - **Justification:** The TT default synthesis is backend-specific and would clutter the generic TileLang frontend if inlined; isolating it in `tilelang_tt` keeps other targets pristine. + +**Dependency Graph** +- `ws1_target_registration.md` → foundational; must land before any TT-specific branching. +- `ws1_engine_adapter.md` depends on target registration; unblocks TT-specific lowering orchestration. +- `ws1_default_annotation_helper.md` can start once the engine adapter skeleton exists; helper implementation can proceed in parallel with wiring work but requires the adapter before integration. +- `ws1_lower_hook.md` comes after the helper so defaults are callable; also gated on the adapter being in place. +- `ws1_target_registration_test.md` exercises the full chain and therefore runs last; it may be authored in parallel but only passes once prior tickets integrate. + +**Testing** +- New Python unit test `tests/python/tt/test_target_registration.py` that checks: + - `determine_target("tenstorrent", return_object=True)` returns a `Target` named `tenstorrent`. + - Lowering a toy PrimFunc with `target="tenstorrent"` injects the default TT attrs in the resulting IRModule. + +## Workstream 2 — Schedule & Sharding Metadata +**Outcome:** Inject TT schedule/shard metadata describing contiguous per-core ranges and DRAM interleaved tilization. + +**Implementation** +- Add `src/tt/transform/infer_tt_schedule.cc` implementing `InferDefaultTTSchedule`. + - Reads `T.Kernel` metadata, enumerates tiles (`grid_x * grid_y`), partitions them by core count, and stores `tt.schedule` plus runtime-arg schemas. + - **Justification:** No existing transform computes TT runtime metadata; adding a dedicated pass avoids overloading GPU-centric passes such as `LowerL2Persistent`. +- Add `src/tt/transform/infer_tt_shard.cc` providing `InferDefaultTTShard`. + - Generates interleaved DRAM descriptors referencing TensorAccessor stride rules; for new files we depend on TT-metal headers during codegen but keep the pass pure metadata. + - Marks non-multiple-of-32 axes for later padding. +- Update `python/tilelang_tt/__init__.py` to expose these passes so they can be invoked from Python. +- Register both passes in the lowering sequence applied by Workstream 3. + +**Testing** +- C++ unit tests under `tests/cpp/tt/test_infer_tt_schedule.cc` and `tests/cpp/tt/test_infer_tt_shard.cc` using `TVM_REGISTER_GLOBAL` to invoke the passes on synthetic PrimFuncs; assert emitted attrs match expected contiguous ranges and interleaved descriptors. +- Python regression `tests/python/tt/test_inferred_metadata.py` to ensure a TileLang matmul lowered with TT target carries the metadata expected by later passes. + +## Workstream 3 — TIR Transform Pipeline +**Outcome:** Convert the annotated PrimFunc into TT-ready IR with persistent loops, interleaved addressing, and TT-specific intrinsics. + +**Implementation** +1. **`GridToPersistentTT` (`src/tt/transform/grid_to_persistent.cc`):** + - Wrap the kernel body with `for (i = 0; i < count; ++i)` recovering `(bx, by)` from `start_id + i`. + - Replace symbolic `bx/by` bindings, annotate `tt.runtime_args`. + - **Justification:** Existing `persist_threadblock` (GPU) assumes CUDA thread semantics and cannot express TT runtime arg wiring. +2. **`TTShardToCoreMap` (`src/tt/transform/shard_to_core_map.cc`):** + - Use schedule metadata to pick a rectangular `CoreRangeSet`; attach `tt.core_ranges` and per-core `(start_id, count)` arrays for host emission. + - **Justification:** TileLang has no notion of Tensix topology today; dedicating a pass keeps the knowledge localized. +3. **`TilePadTT` (`src/tt/transform/tile_pad.cc`):** + - Insert pad/unpad ops where shapes are not tile-multiples; prefer reader-side zero fill when possible. +4. **`MemorySpaceLowerTT` (`src/tt/transform/memory_space_lower.cc`):** + - Lower shared/fragment buffers to circular buffers with depth=2; tag them with `tt.cb` attributes. + - Convert `T.copy` to TT enqueue/dequeue intrinsics. +5. **`TensorizeTT` (`src/tt/transform/tensorize_matmul.cc`):** + - Match matmul loop nests and replace with `tt.matmul_tiles(cb_a, cb_b, cb_c)`. +6. **`VerifyTTIR` (`src/tt/transform/verify.cc`):** + - Ensure required attrs, runtime args, and CB invariants are present before codegen. +- Update `python/tilelang_tt/pipeline.py` to define an ordered pass list: inference (WS2) → transforms (steps 1–6) → host/codegen stubs. + +**Testing** +- For each transform, add focused C++ tests in `tests/cpp/tt/` that apply the pass and compare the transformed IR against expected snippets (use `tvm::support::AsText`). +- Python-level test `tests/python/tt/test_tir_pipeline.py` runs the full pipeline on an MVP GEMM and asserts the final IR has the persistent loop, CB attrs, and Tensorize call. + +## Workstream 4 — Code Generation & Runtime Glue +**Outcome:** Emit Metalium-compatible reader/compute/writer kernels and a host program stub capable of constructing an interleaved TensorAccessor view. + +**Implementation** +- Introduce `src/tt/codegen/emit_kernels.cc` to walk TT-annotated PrimFuncs and produce C++ text for compute kernels; include headers from TT-metal for `TensorAccessor` and `CircularBuffer` definitions. +- Add `src/tt/codegen/emit_reader_writer.cc` to generate DRAM reader/writer kernels that program TensorAccessor iterators using interleaved layout metadata. +- Create `src/tt/codegen/emit_program.cc` building the host Program: allocate CBs, set runtime args, instantiate kernels on the `CoreRangeSet`, and dump `tt.plan.json`. + - **Justification:** Existing CUDA/HIP codegen paths rely on NVCC/HIPCC FFI; TT requires a distinct BYOC module that integrates with Metalium headers and the dry-run artifact flow. +- Provide Python glue in `python/tilelang_tt/codegen.py` registering `target.build.tilelang_tt` and `target.build.tilelang_tt_without_compile` with TVM. + +**Testing** +- Golden-file comparisons in `tests/python/tt/test_codegen_artifacts.py` that run the pipeline, inspect generated `compute.cpp`, `reader.cpp`, `writer.cpp`, and `tt.plan.json`, and diff against checked-in templates under `tests/python/tt/golden/`. +- Unit test `tests/python/tt/test_tensor_accessor_interleaving.py` verifying emitted reader/writer indices align with expected interleaved offsets for small matrices (compare against handcrafted TensorAccessor calculations). + +## Workstream 5 — Tooling, Testing, and Validation +**Outcome:** Establish reproducible dry-run validation, including the TileLang GEMM MVP acceptance test. + +**Implementation** +- Add command-line hook `python/tilelang_tt/cli.py` (optional) to dump artifacts for ad-hoc inspection during development. +- Integrate clang-format/clang-tidy checks for emitted sources in CI (reuse `format.sh`). +- Extend CI configuration to add a `TT_MVP_DRYRUN` job executing the tests below and archiving artifacts. +- Ensure the pipeline returns structured metadata so downstream tooling can inspect per-core work splits. + +**Testing** +- **TileLang MVP GEMM test:** Implement `tests/python/tt/test_matmul_mvp.py` that builds the canonical TileLang GEMM (from README Phase 0), lowers it with `target="tenstorrent"`, asserts all passes succeed, and validates that the generated `tt.plan.json` assigns the expected interleaved tiles. This test is the final acceptance gate for the MVP. +- Additional smoke test `tests/python/tt/test_dry_run_cli.py` invoking the optional CLI to confirm artifact emission. + +## Workstream 6 — Documentation & Follow-Up +**Outcome:** Users and contributors understand the TT pipeline, defaults, and next steps. + +**Implementation** +- Update `README.md` Phase 0 section to reference the new interleaved defaults, TensorAccessor dependency, and dry-run instructions. +- Add a HOWTO in `docs/tenstorrent/` (e.g., `docs/tenstorrent/dry_run_walkthrough.md`) detailing the CLI/output layout. +- Document API changes (`tenstorrent` target flag, new Python helpers) in `docs/api_reference.md` (or the appropriate API doc). + +**Testing** +- Documentation lint job verifying new Markdown (spelling, links) via existing docs tooling. +- Manual review checklist ensuring instructions match the behavior validated in Workstream 5. + +## Milestones & Sequencing +1. Land Workstreams 1–2 (target detection + metadata inference) with unit tests. +2. Implement Workstream 3 transforms sequentially, gating each with its dedicated C++ tests; land once `VerifyTTIR` passes on the MVP matmul. +3. Add Workstream 4 codegen and ensure golden artifacts stabilize; update CI. +4. Finalize Workstream 5 acceptance tests (including `test_matmul_mvp.py`) and enable the dry-run CI job. +5. Publish documentation updates (Workstream 6) concurrent with enabling the TT target for early adopters. + +## Acceptance Criteria +- Running `pytest tests/python/tt/test_matmul_mvp.py` succeeds and produces serialized kernels/host artifacts using interleaved TensorAccessor layout. +- All per-workstream unit tests (C++ and Python) pass in CI, and golden artifacts remain stable across runs. +- Documentation clearly states defaults, limitations, and links to the TensorAccessor reference for contributors. + diff --git a/docs/tenstorrent/project_1_prompt.md b/docs/tenstorrent/project_1_prompt.md new file mode 100644 index 000000000..d10ace48b --- /dev/null +++ b/docs/tenstorrent/project_1_prompt.md @@ -0,0 +1,9 @@ +# Project 1 Prompt — TT Backend MVP + +- based on the high level technical plan in README.md and docs/tenstorrent, make a detailed plan to implement MVP for matrix multiplication, but default sharding, default schedule, default bfloat16, DRAM tensors. place the project plan in docs/tenstorrent/project_1.md (make a detailed markdown file). the plan should include all the steps to modify transforms, passes, etc to allow dry test of basic matmul, ie we can genereate metalium host code and kernels (reader, compute, writer). +- feedback: the default sharding should be "interleaved tensors", the new TensorAccessor in TT-Metalium supports this. +- feedback: the default sharding should be "interleaved tensors", the new TensorAccessor in TT-Metalium supports this, search TT-Metalium repo if you're not familiar. +- feedback: add the TileLang MVP Gemm test, there needs to be a python test that needs to pass after all workstreams are done. +- feedback: for each workstream , there needs be more details on which transforms/files need to be modified, if we're adding new transform/files, justify why can't modify/augment existing ones. +- feedback: for each workstream, there need to be dedicated unit tests. + diff --git a/docs/tenstorrent/workstream1/ws1_default_annotation_helper.md b/docs/tenstorrent/workstream1/ws1_default_annotation_helper.md new file mode 100644 index 000000000..d9a90b51e --- /dev/null +++ b/docs/tenstorrent/workstream1/ws1_default_annotation_helper.md @@ -0,0 +1,22 @@ +# Ticket: Default TT Annotation Helper + +## Goal +Provide a Python helper that stamps default Tenstorrent schedule/sharding metadata (contiguous schedule + interleaved TensorAccessor layout) when user code omits TT annotations. + +## Context +- Workstream 1 specifies introducing `python/tilelang_tt/target.py` to centralize default policy synthesis. +- Helper should be reusable by future workstreams when additional defaults are needed. + +## Key Tasks +- Create `python/tilelang_tt/target.py` exporting a function (e.g., `apply_tt_defaults(mod)`) that injects attrs on each PrimFunc. +- Hook into the lowering pipeline before TT-specific transforms run. +- Document the default choices (contiguous schedule, interleaved DRAM tensors) within the helper for clarity. + +## Dependencies +- Requires the engine adapter ticket so the TT path has a hook to call the helper. +- Metadata produced here must align with Workstream 2 inference passes to avoid duplication. + +## Validation +- Verified indirectly by `tests/python/tt/test_target_registration.py`, asserting the lowered IR contains TT default attrs. + +Status: TODO diff --git a/docs/tenstorrent/workstream1/ws1_engine_adapter.md b/docs/tenstorrent/workstream1/ws1_engine_adapter.md new file mode 100644 index 000000000..0796c8dd2 --- /dev/null +++ b/docs/tenstorrent/workstream1/ws1_engine_adapter.md @@ -0,0 +1,22 @@ +# Ticket: Add Tenstorrent Engine Adapter + +## Goal +Teach the TileLang lowering entry point to delegate device codegen to a dedicated Tenstorrent helper when the `tenstorrent` target is active. + +## Context +- Workstream 1 requires a TT-specific `lower` helper in `tilelang.engine` to split host/device responsibilities. +- Must preserve existing CUDA/HIP/CPU behavior. + +## Key Tasks +- Introduce `tilelang/engine/tt/lower.py` (or similar) to encapsulate TT-specific lowering orchestration. +- Update `tilelang/engine/__init__.py` and `tilelang/engine/lower.py` to branch on `target.kind.name == "tenstorrent"` and invoke the TT helper. +- Ensure TT path integrates with `CompileArtifact` data structures without affecting current consumers. + +## Dependencies +- Depends on `ws1_target_registration.md` so the target can be selected. +- Precedes annotation helper wiring that will rely on the new entry point. + +## Validation +- Smoke test through `tests/python/tt/test_target_registration.py` to confirm the TT branch executes without raising. + +Status: In Review (changes pending on branch `ws1-engine-adapter`) diff --git a/docs/tenstorrent/workstream1/ws1_lower_hook.md b/docs/tenstorrent/workstream1/ws1_lower_hook.md new file mode 100644 index 000000000..09c697589 --- /dev/null +++ b/docs/tenstorrent/workstream1/ws1_lower_hook.md @@ -0,0 +1,21 @@ +# Ticket: Wire TT Defaults into Lowering Entry Point + +## Goal +Invoke the Tenstorrent default annotation helper during lowering so TT metadata is always present before TT-specific passes execute. + +## Context +- Complements the default helper by ensuring it runs automatically when the TT target is selected. +- Should be inserted immediately after target resolution but before pass pipelines. + +## Key Tasks +- Modify `tilelang/engine/lower.lower` (or the TT helper introduced in `ws1_engine_adapter.md`) to call the default annotation routine on the IRModule. +- Guarantee idempotency so repeated lowering passes do not duplicate attrs. +- Add logging or debug hooks (if appropriate) to confirm defaults were applied. + +## Dependencies +- Depends on `ws1_engine_adapter.md` and `ws1_default_annotation_helper.md`. + +## Validation +- Covered by the Workstream 1 test (`tests/python/tt/test_target_registration.py`) inspecting the transformed IR. + +Status: TODO diff --git a/docs/tenstorrent/workstream1/ws1_target_registration.md b/docs/tenstorrent/workstream1/ws1_target_registration.md new file mode 100644 index 000000000..9a16470c2 --- /dev/null +++ b/docs/tenstorrent/workstream1/ws1_target_registration.md @@ -0,0 +1,25 @@ +# Ticket: Register `tenstorrent` Target + +## Goal +Enable explicit opt-in for the Tenstorrent backend by adding `"tenstorrent"` to TileLang's target resolution so downstream lowering can select the TT pass pipeline. + +## Context +- Driven by Workstream 1 of `project_1.md`. +- Must avoid impacting existing auto-detection logic for CUDA/HIP backends. + +## Key Tasks +- Update `tilelang/utils/target.py`: + - Append `"tenstorrent"` to `AVALIABLE_TARGETS` and document that auto detection remains CUDA/HIP only. + - In `determine_target`, add an explicit branch handling the string/Target case where `target == "tenstorrent"` and return `Target("tenstorrent")` when `return_object=True`. + - Guard the `auto` path from ever choosing TT by ensuring the CUDA/HIP checks remain first and that TT raises if requested but not compiled with TT support. +- Define an informative error/warning path (e.g., `raise ValueError("Tenstorrent backend requires TL_TT_BACKEND build flag")`) for configurations built without TT support; place the check adjacent to the new branch so failure is immediate. +- Ensure the returned `Target` exposes `kind.name == "tenstorrent"` so later code can branch on it. +- Add inline comments noting that TT auto-detection is intentionally disabled until the backend can probe hardware. + +## Dependencies +- None; pure frontend change but should land before other TT-specific wiring. + +## Validation +- Covered by `tests/python/tt/test_target_registration.py` (see Workstream 1 testing ticket). + +Status: In Review (changes pending on branch `tt-matmul-mvp-plan`) diff --git a/docs/tenstorrent/workstream1/ws1_target_registration_test.md b/docs/tenstorrent/workstream1/ws1_target_registration_test.md new file mode 100644 index 000000000..968ef74d1 --- /dev/null +++ b/docs/tenstorrent/workstream1/ws1_target_registration_test.md @@ -0,0 +1,23 @@ +# Ticket: Target Registration Test Coverage + +## Goal +Add Python regression coverage ensuring the Tenstorrent target can be selected and that default metadata injection occurs during lowering. + +## Context +- Testing requirement from Workstream 1 to guard new frontend wiring. +- Relies on pytest infrastructure in `tests/python`. + +## Key Tasks +- Create `tests/python/tt/test_target_registration.py` that: + - Requests `target="tenstorrent"` through TileLang lowering APIs. + - Verifies the resulting IRModule carries `tt.schedule` and `tt.shard` attrs seeded by the default helper. + - Ensures no CUDA/HIP-specific passes execute when the TT path is chosen. +- Add fixtures/utilities if needed to inspect PrimFunc attrs. + +## Dependencies +- Depends on prior tickets that add the target registration and default helper wiring. + +## Validation +- Test passes under `pytest tests/python/tt/test_target_registration.py` and is added to CI once Workstream 1 lands. + +Status: In Review (tests added on branch `tt-matmul-mvp-plan`) diff --git a/requirements-mypy.txt b/requirements-mypy.txt new file mode 100644 index 000000000..7b18a177d --- /dev/null +++ b/requirements-mypy.txt @@ -0,0 +1 @@ +mypy>=1.0.0 diff --git a/setup.py b/setup.py index 9baa2868d..9b88092e4 100644 --- a/setup.py +++ b/setup.py @@ -65,15 +65,18 @@ def load_module_from_path(module_name, path): raise ValueError( "ROCM support is enabled (USE_ROCM=True) but ROCM_HOME is not set or detected.") -if not USE_ROCM and not CUDA_HOME: - raise ValueError( - "CUDA support is enabled by default (USE_ROCM=False) but CUDA_HOME is not set or detected.") +# For LLVM-only builds, skip CUDA/ROCM validation +if not USE_LLVM: + if not USE_ROCM and not CUDA_HOME: + raise ValueError( + "CUDA support is enabled by default (USE_ROCM=False) but CUDA_HOME is not set or detected." + ) -# Ensure one of CUDA or ROCM is available -if not (CUDA_HOME or ROCM_HOME): - raise ValueError( - "Failed to automatically detect CUDA or ROCM installation. Please set the CUDA_HOME or ROCM_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda or export ROCM_HOME=/opt/rocm)." - ) + # Ensure one of CUDA or ROCM is available + if not (CUDA_HOME or ROCM_HOME): + raise ValueError( + "Failed to automatically detect CUDA or ROCM installation. Please set the CUDA_HOME or ROCM_HOME environment variable manually (e.g., export CUDA_HOME=/usr/local/cuda or export ROCM_HOME=/opt/rocm)." + ) # TileLang only supports Linux platform assert sys.platform.startswith("linux"), "TileLang only supports Linux platform (including WSL)." @@ -158,9 +161,11 @@ def get_tilelang_version(with_cuda=True, with_system_info=True, with_commit_id=F local_version_parts.append(f"rocm{rocm_version_str}") else: if CUDA_HOME: - cuda_version = str(get_nvcc_cuda_version()) - cuda_version_str = cuda_version.replace(".", "")[:3] - local_version_parts.append(f"cu{cuda_version_str}") + nvcc_path = os.path.join(CUDA_HOME, "bin", "nvcc") + if os.path.exists(nvcc_path): + cuda_version = str(get_nvcc_cuda_version()) + cuda_version_str = cuda_version.replace(".", "")[:3] + local_version_parts.append(f"cu{cuda_version_str}") if local_version_parts: version += f"+{'.'.join(local_version_parts)}" @@ -358,8 +363,19 @@ def is_git_repo(): def setup_llvm_for_tvm(): - """Downloads and extracts LLVM, then configures TVM to use it.""" - # Assume the download_and_extract_llvm function and its dependencies are defined elsewhere in this script + """Downloads and extracts LLVM, then configures TVM to use it. + + If system llvm-config is available, uses that instead of downloading LLVM. + """ + # Check if system llvm-config is available + system_llvm_config = shutil.which("llvm-config") + if system_llvm_config: + logger.info(f"Using system llvm-config at {system_llvm_config}") + extract_path = os.path.dirname(os.path.dirname(system_llvm_config)) # Go up to LLVM root + return extract_path, system_llvm_config + + # Otherwise download LLVM + logger.info(f"System llvm-config not found, downloading LLVM {LLVM_VERSION}") extract_path = download_and_extract_llvm(LLVM_VERSION, IS_AARCH64, EXTRACT_PATH) llvm_config_path = os.path.join(extract_path, "bin", "llvm-config") return extract_path, llvm_config_path diff --git a/testing/python/tt/test_target_registration.py b/testing/python/tt/test_target_registration.py new file mode 100644 index 000000000..594739258 --- /dev/null +++ b/testing/python/tt/test_target_registration.py @@ -0,0 +1,76 @@ +import importlib + +import pytest + +try: + import tvm + from tvm.target import Target +except ModuleNotFoundError as exc: + pytest.skip(f"TVM not available: {exc}", allow_module_level=True) + +_target_mod = importlib.import_module("tilelang.utils.target") +_tt_lower = importlib.import_module("tilelang.engine.tt.lower") +CompiledArtifact = importlib.import_module("tilelang.engine.param").CompiledArtifact + + +@pytest.fixture +def toggle_tt_backend(monkeypatch): + original = getattr(_target_mod, "_HAS_TENSTORRENT_BACKEND", False) + + def setter(value: bool): + monkeypatch.setattr(_target_mod, "_HAS_TENSTORRENT_BACKEND", value, raising=False) + + setter(original) + try: + yield setter + finally: + setter(original) + + +def test_available_targets_contains_tt(): + assert _target_mod.TENSTORRENT_TARGET in _target_mod.AVALIABLE_TARGETS + + +def test_determine_target_returns_target_when_backend_enabled(toggle_tt_backend): + toggle_tt_backend(True) + scope_name = _target_mod.determine_target(_target_mod.TENSTORRENT_TARGET) + assert scope_name == _target_mod.TENSTORRENT_TARGET + + target_obj = _target_mod.determine_target(_target_mod.TENSTORRENT_TARGET, return_object=True) + assert isinstance(target_obj, Target) + assert target_obj.kind.name == _target_mod.TENSTORRENT_TARGET + + +def test_determine_target_raises_when_backend_disabled(toggle_tt_backend): + toggle_tt_backend(False) + with pytest.raises(ValueError, match="Tenstorrent backend requires"): + _target_mod.determine_target(_target_mod.TENSTORRENT_TARGET) + + +def test_tenstorrent_engine_lower_raises_not_implemented(toggle_tt_backend): + toggle_tt_backend(True) + with pytest.raises( + NotImplementedError, match="Tenstorrent backend lowering is not yet implemented"): + _tt_lower.lower( + tvm.IRModule(), + params=None, + target=_target_mod.TENSTORRENT_TARGET, + target_host=None, + runtime_only=False, + enable_host_codegen=False, + enable_device_compile=False, + ) + + +def test_tenstorrent_engine_lower_validates_target(toggle_tt_backend): + toggle_tt_backend(True) + with pytest.raises(ValueError, match="Tenstorrent lowering called with invalid target"): + _tt_lower.lower( + tvm.IRModule(), + params=None, + target="cuda", + target_host=None, + runtime_only=False, + enable_host_codegen=False, + enable_device_compile=False, + ) diff --git a/tilelang/engine/__init__.py b/tilelang/engine/__init__.py index 476b40a35..503a2c3f2 100644 --- a/tilelang/engine/__init__.py +++ b/tilelang/engine/__init__.py @@ -1,3 +1,4 @@ from .lower import lower, is_device_call # noqa: F401 from .param import KernelParam # noqa: F401 from .callback import register_cuda_postproc, register_hip_postproc # noqa: F401 +from .tt import lower_tenstorrent # noqa: F401 diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 65a14e6e6..7fced713b 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -15,12 +15,26 @@ LowerAndLegalize, OptimizeForTarget, ) +from tilelang.engine.tt import lower_tenstorrent +from tilelang.utils.target import TENSTORRENT_TARGET def is_cpu_device_backend(target: Target): return target.kind.name == "c" +def get_target_kind(target: Union[str, Target]) -> str: + """Extract the target kind name from a target object or string. + + Args: + target: Either a string target name or a Target object + + Returns: + The target kind name as a string + """ + return target.kind.name if isinstance(target, Target) else target + + def has_device_kernel_launch(attrs) -> bool: """Check if the attributes indicate a device kernel launch.""" return bool(attrs and "calling_conv" in attrs and @@ -213,8 +227,19 @@ def lower( target = determine_target(target) target_host = canon_target_host(target, target_host) - target_host = tvm.target.Target.canon_target(target_host) + + if get_target_kind(target) == TENSTORRENT_TARGET: + return lower_tenstorrent( + mod, + params, + target, + target_host, + runtime_only=runtime_only, + enable_host_codegen=enable_host_codegen, + enable_device_compile=enable_device_compile, + ) + target = tvm.target.Target(target, target_host) _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target)) diff --git a/tilelang/engine/tt/__init__.py b/tilelang/engine/tt/__init__.py new file mode 100644 index 000000000..6a3d2e837 --- /dev/null +++ b/tilelang/engine/tt/__init__.py @@ -0,0 +1,5 @@ +"""Tenstorrent engine helpers.""" + +from .lower import lower as lower_tenstorrent + +__all__ = ["lower_tenstorrent"] diff --git a/tilelang/engine/tt/lower.py b/tilelang/engine/tt/lower.py new file mode 100644 index 000000000..b7fdb5b97 --- /dev/null +++ b/tilelang/engine/tt/lower.py @@ -0,0 +1,67 @@ +"""Tenstorrent lowering entry point. + +This module provides a stub implementation that wires the Tenstorrent target +into TileLang's lowering flow. The real lowering pipeline will be added in +subsequent tickets. +""" + +from __future__ import annotations + +from typing import List, Optional, Union + +from tvm.target import Target + +from tilelang import tvm as tvm +from tilelang.engine.param import CompiledArtifact, KernelParam + + +def lower( + mod: tvm.IRModule, + params: Optional[List[KernelParam]], + target: Union[str, Target], + target_host: Optional[Union[str, Target]], + *, + runtime_only: bool, + enable_host_codegen: bool, + enable_device_compile: bool, +) -> CompiledArtifact: + """Lower the given module for the Tenstorrent backend. + + This is a stub implementation. It validates the target and then raises + NotImplementedError, since the actual lowering pipeline is not yet implemented. + The concrete lowering pipeline will be implemented in future workstreams. + + Args: + mod: The TVM IRModule to lower (unused in stub) + params: Optional list of kernel parameters (unused in stub) + target: The target (should be Tenstorrent target) + target_host: Optional host target (unused in stub) + runtime_only: Whether to generate runtime-only code (unused in stub) + enable_host_codegen: Whether to enable host code generation (unused in stub) + enable_device_compile: Whether to enable device compilation (unused in stub) + + Raises: + ValueError: If the target is not a Tenstorrent target + NotImplementedError: This stub implementation always raises this exception + instead of returning a CompiledArtifact + """ + from tilelang.engine.lower import get_target_kind + from tilelang.utils.target import TENSTORRENT_TARGET + + # Unused parameters in this stub implementation - will be used in full implementation + _ = mod + _ = params + _ = target_host + _ = runtime_only + _ = enable_host_codegen + _ = enable_device_compile + + # Validate that we're actually targeting Tenstorrent + target_kind = get_target_kind(target) + if target_kind != TENSTORRENT_TARGET: + raise ValueError(f"Tenstorrent lowering called with invalid target: {target_kind}. " + f"Expected: {TENSTORRENT_TARGET}") + + raise NotImplementedError("Tenstorrent backend lowering is not yet implemented. " + "This is a stub implementation. The lowering pipeline will be " + "added in future workstreams.") diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py index 7d712d3ae..8d0aed055 100644 --- a/tilelang/utils/target.py +++ b/tilelang/utils/target.py @@ -5,6 +5,14 @@ from tvm.contrib import rocm from tilelang.contrib import nvcc +TENSTORRENT_TARGET = "tenstorrent" + + +def _is_tenstorrent_backend_enabled() -> bool: + """Detect whether the Tenstorrent backend has been registered with TVM.""" + return bool(tvm.get_global_func("target.build.tilelang_tt", allow_missing=True)) + + AVALIABLE_TARGETS = { "auto", "cuda", @@ -12,8 +20,11 @@ "webgpu", "c", # represent c source backend "llvm", + TENSTORRENT_TARGET, } +_HAS_TENSTORRENT_BACKEND = _is_tenstorrent_backend_enabled() + def check_cuda_availability() -> bool: """ @@ -78,8 +89,12 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto", raise ValueError("No CUDA or HIP available on this system.") else: # Validate the target if it's not "auto" + target_kind = target.kind.name if isinstance(target, Target) else target assert isinstance( - target, Target) or target in AVALIABLE_TARGETS, f"Target {target} is not supported" + target, Target) or target_kind in AVALIABLE_TARGETS, f"Target {target} is not supported" + if target_kind == TENSTORRENT_TARGET and not _HAS_TENSTORRENT_BACKEND: + raise ValueError( + "Tenstorrent backend requires TileLang to be built with TL_TT_BACKEND enabled.") return_var = target if return_object: