28 changes: 23 additions & 5 deletions docker/Dockerfile.rocm
@@ -9,25 +9,43 @@ ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential git wget \
libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \
rocm-dev rocm-libs hip-dev hipblas-dev rocblas-dev \
&& apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/*

ENV PATH="/opt/conda/bin:${PATH}"
ENV LIBGL_ALWAYS_INDIRECT=1
ENV USE_ROCM=1
ENV USE_CUDA=0
ENV ROCM_HOME=/opt/rocm
ENV HIP_PLATFORM=amd
ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942"


RUN conda run -n py_3.10 conda install pip cmake -y && \
conda run -n py_3.10 conda install -c conda-forge libstdcxx-ng=12 -y && \
conda clean --all

RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get update && apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev && \
apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/*

RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \
conda run -n py_3.10 bash -c "pip install 'numpy<2.0' --force-reinstall && cd tilelang && USE_ROCM=1 pip install -e . -v"
# Copy local tilelang directory instead of cloning from git
# Build from tilelang root: docker build -f docker/Dockerfile.rocm -t mi300:latest .
COPY . /root/tilelang

RUN mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \
conda run -n py_3.10 bash -c "export USE_ROCM=1 USE_CUDA=0 && pip install 'numpy<2.0' --force-reinstall" && \
conda run -n py_3.10 bash -c "cd /root/tilelang && \
# Backup and modify pyproject.toml to remove torch from dependencies \
cp pyproject.toml pyproject.toml.bak && \
sed -i '/^[[:space:]]*\"torch/d' pyproject.toml && \
# Install tilelang with all dependencies except torch \
USE_ROCM=1 USE_CUDA=0 pip install -e . -v && \
# Restore original pyproject.toml \
mv pyproject.toml.bak pyproject.toml"
Comment on lines +38 to +44
⚠️ Potential issue | 🟡 Minor

Consider more robust dependency management.

The sed-based modification of pyproject.toml to remove torch dependencies is fragile and could break if the file format or dependency structure changes. Consider using a Python-based TOML parser (e.g., tomli/tomli-w) for more reliable manipulation.

Example alternative approach:

python3 -c "
import tomli, tomli_w
with open('pyproject.toml', 'rb') as f:
    data = tomli.load(f)
# Remove torch from dependencies
deps = data.get('project', {}).get('dependencies', [])
data['project']['dependencies'] = [d for d in deps if not d.startswith('torch')]
with open('pyproject.toml', 'wb') as f:
    tomli_w.dump(data, f)
"

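A usage note on the suggestion above (an assumption, since neither package ships in the base image): tomli and tomli-w would need to be installed into the py_3.10 environment before this step, for example

conda run -n py_3.10 pip install tomli tomli-w

On Python 3.11+, the standard-library tomllib could replace tomli for reading, though a writer such as tomli-w would still be needed to dump the modified pyproject.toml.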

RUN conda init bash && \
echo "conda activate py_3.10" >> /root/.bashrc

SHELL ["/bin/bash", "-l", "-c"]

ENTRYPOINT ["/bin/bash", "--login", "-i"]
ENTRYPOINT ["/bin/bash", "--login", "-i"]
17 changes: 16 additions & 1 deletion examples/amd/example_amd_flash_attn_fwd.py
@@ -8,6 +8,21 @@
from functools import partial


# Custom supply function to ensure tensors are created on GPU
def supply_tensors_gpu(params):
    """Supply function that creates tensors on GPU for ROCm/HIP."""
    tensors = []
    for param in params:
        if hasattr(param, 'shape') and hasattr(param, 'dtype'):
            # Force creation on GPU device
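            # (On PyTorch ROCm builds the HIP GPU is exposed under the 'cuda'
            # device name, so device='cuda' here targets the AMD GPU.)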
            shape = [int(s) for s in param.shape]
            tensor = torch.randn(shape, dtype=param.dtype, device='cuda')
            tensors.append(tensor)
        else:
            tensors.append(param)
    return tensors


def ref_program(Q, K, V, is_causal, groups=1):
    assert Q.size(
        2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
@@ -63,7 +78,7 @@ def get_configs():
    return valid_configs


@tilelang.autotune(configs=get_configs(), cache_input_tensors=True)
@tilelang.autotune(configs=get_configs(), cache_input_tensors=True, supply_prog=supply_tensors_gpu)
@tilelang.jit(out_idx=[3])
def fast_flashattn(
    batch,
4 changes: 3 additions & 1 deletion src/op/math.cc
@@ -33,6 +33,7 @@ TVM_REGISTER_OP("tl.pow_of_int")
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure))
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", pow_of_int_op)
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);

PrimExpr infinity_op(PrimExpr args) {
@@ -59,7 +60,8 @@ TVM_REGISTER_OP("tl.infinity")
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure))
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity")
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op);
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op)
    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", infinity_op);

} // namespace tl
} // namespace tvm
4 changes: 2 additions & 2 deletions src/target/codegen_hip.cc
@@ -1190,9 +1190,9 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
      if (op->value < 0) {
        temp << "-";
      }
      temp << ((op->dtype.bits() == 32) ? "HIPRT_INF_F" : "HIPRT_INF");
      temp << ((op->dtype.bits() == 32) ? "HUGE_VALF" : "HUGE_VAL");
    } else if (std::isnan(op->value)) {
      temp << ((op->dtype.bits() == 32) ? "HIPRT_NAN_F" : "HIPRT_NAN");
      temp << ((op->dtype.bits() == 32) ? "NAN" : "NAN");
    } else {
      temp << std::scientific << op->value;
      if (op->dtype.bits() == 32)