Merged

Commits (43):
- 2cc6e7e: poking around a bit (ptillet, Aug 27, 2023)
- a0b3e61: very very basic POC (ptillet, Aug 28, 2023)
- 692bf30: cleanup (ptillet, Aug 28, 2023)
- 9c9cf0b: more work (ptillet, Aug 28, 2023)
- 24e60b5: . (ptillet, Aug 28, 2023)
- 2f6fe65: progress (ptillet, Sep 4, 2023)
- c405f28: more cleaning (ptillet, Sep 4, 2023)
- e4257a1: more cleaning (ptillet, Sep 4, 2023)
- 7457e00: remove interpret flag (ptillet, Sep 4, 2023)
- b7cf36f: . (ptillet, Sep 4, 2023)
- 7595ad4: . (ptillet, Sep 5, 2023)
- d1636a8: flash attention runs but produces incorrect result (ptillet, Sep 11, 2023)
- 145d70f: bugfix (ptillet, Sep 11, 2023)
- d304a20: progress (ptillet, Sep 11, 2023)
- 5882703: flash attention fwd pass working (ptillet, Sep 11, 2023)
- 2ef140e: flash bwd also works (ptillet, Sep 11, 2023)
- 5d3d916: cleanup (ptillet, Sep 11, 2023)
- 2f1cab7: . (ptillet, Sep 15, 2023)
- 59075e4: Merge remote-tracking branch 'origin/main' into phil/new-interpreter (ptillet, Sep 17, 2023)
- 1af9397: . (ptillet, Sep 17, 2023)
- 84f7a0c: . (ptillet, Sep 17, 2023)
- 9d2858a: . (ptillet, Sep 17, 2023)
- 9d60d6a: more fixes (ptillet, Sep 17, 2023)
- dfa4d22: . (ptillet, Sep 17, 2023)
- 3d95cc7: . (ptillet, Sep 17, 2023)
- 9d679f9: . (ptillet, Sep 17, 2023)
- 86b49d6: . (ptillet, Sep 17, 2023)
- 6bbd9fb: . (ptillet, Sep 17, 2023)
- 56888b4: . (ptillet, Sep 17, 2023)
- a977b51: . (ptillet, Sep 17, 2023)
- 4cdfadd: . (ptillet, Sep 17, 2023)
- 41d4c79: interpreter test workflow (ptillet, Sep 17, 2023)
- 27f3c14: . (ptillet, Sep 17, 2023)
- fbdb451: . (ptillet, Sep 17, 2023)
- 8bec7e0: . (ptillet, Sep 17, 2023)
- 4301e46: . (ptillet, Sep 17, 2023)
- 1853b8c: . (ptillet, Sep 17, 2023)
- 78325e9: . (ptillet, Sep 17, 2023)
- 602597e: trying to clean (ptillet, Sep 17, 2023)
- 213815e: . (ptillet, Sep 17, 2023)
- 2a091ec: . (ptillet, Sep 17, 2023)
- 9a23a1b: . (ptillet, Sep 17, 2023)
- a6b0b06: Merge branch 'main' into phil/new-interpreter (ptillet, Sep 17, 2023)
9 changes: 9 additions & 0 deletions .github/workflows/integration-tests.yml
@@ -33,6 +33,7 @@ jobs:
          echo '::set-output name=matrix-optional::["ubuntu-latest"]'
        fi

+
  Integration-Tests-Nvidia:
    needs: Runner-Preparation

@@ -118,6 +119,14 @@ jobs:
      run: |
        rm -rf ~/.triton

+    - name: Run interpreter tests
+      env:
+        TRITON_INTERPRET: "1"
+        CUDA_VISIBLE_DEVICES: ""
+      run: |
+        cd python/test/unit
+        python3 -m pytest -vs operators/test_flash_attention.py
+
    - name: Run partial tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1
      if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}}
      run: |
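The new step above runs the flash-attention tests through the interpreter with every GPU hidden, so nothing can fall back to the CUDA backend. For local debugging, a minimal sketch of the same invocation, assuming it is launched from the repository root (the `pytest.main` call and paths mirror the workflow step; they are not part of this PR):

```python
# Local reproduction of the "Run interpreter tests" CI step (illustrative).
import os
import pytest

# Route Triton kernels through the CPU interpreter instead of compiling them.
os.environ["TRITON_INTERPRET"] = "1"
# Hide all GPUs so nothing accidentally runs through the CUDA backend.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Same test file the workflow targets, relative to the repository root.
pytest.main(["-vs", "python/test/unit/operators/test_flash_attention.py"])
```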
5 changes: 2 additions & 3 deletions python/setup.py
@@ -59,8 +59,8 @@ class Package(NamedTuple):


def get_pybind11_package_info():
-    name = "pybind11-2.10.0"
-    url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz"
+    name = "pybind11-2.11.1"
+    url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz"
    return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")

# llvm
@@ -296,7 +296,6 @@ def build_extension(self, ext):
        "triton/_C",
        "triton/common",
        "triton/compiler",
-        "triton/interpreter",
        "triton/language",
        "triton/language/extra",
        "triton/ops",
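For context, `get_pybind11_package_info` returns one of the `Package` records that `setup.py` fetches at build time; this PR bumps the pinned pybind11 from 2.10.0 to 2.11.1 and drops the now-deleted `triton/interpreter` package. A sketch of the record's shape, with field names inferred from the call site above rather than copied from the real file:

```python
# Assumed structure of the Package NamedTuple used above (field names are
# inferred from the six-argument call site, not from the actual setup.py).
from typing import NamedTuple

class Package(NamedTuple):
    package: str           # logical name, e.g. "pybind11"
    name: str              # versioned directory name, e.g. "pybind11-2.11.1"
    url: str               # tarball downloaded at build time
    include_flag: str      # env var overriding the include directory
    lib_flag: str          # env var overriding the library directory (unused here)
    syspath_var_name: str  # env var pointing at an existing local checkout

pkg = Package("pybind11", "pybind11-2.11.1",
              "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz",
              "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
print(pkg.url)
```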
42 changes: 42 additions & 0 deletions python/src/triton.cc
@@ -64,6 +64,7 @@
#include <stdexcept>
#include <string>

+#include <pybind11/numpy.h>
namespace py = pybind11;

PYBIND11_MAKE_OPAQUE(mlir::triton::gpu::TMAMetadataTy);
@@ -1961,11 +1962,52 @@ void init_triton_translation(py::module &m) {
  });
}

+void init_triton_interpreter(py::module &&m) {
+  using ret = py::return_value_policy;
+
+  m.def("load",
+        [](py::array_t<uint64_t> ptrs, py::array_t<bool> masks, py::array other,
+           py::dtype ret_dtype) -> py::array {
+          int numel = ptrs.size();
+          auto shape =
+              std::vector<ptrdiff_t>(ptrs.shape(), ptrs.shape() + ptrs.ndim());
+          py::array ret(ret_dtype, py::array::ShapeContainer{numel});
+          py::array_t<uint64_t> reshaped_ptrs = ptrs.reshape({numel});
+          py::array_t<bool> reshaped_masks = masks.reshape({numel});
+          py::array reshaped_others = other.reshape({numel});
+          for (size_t i = 0; i < ptrs.size(); ++i) {
+            if (reshaped_masks.at(i))
+              memcpy(ret.mutable_data(i),
+                     reinterpret_cast<void *>(reshaped_ptrs.at(i)),
+                     ret_dtype.itemsize());
+            else
+              memcpy(ret.mutable_data(i), reshaped_others.data(i),
+                     ret_dtype.itemsize());
+          }
+          return ret.reshape(shape);
+        });
+
+  m.def("store", [](py::array_t<uint64_t> ptrs, py::array values,
+                    py::array_t<bool> mask) {
+    int numel = ptrs.size();
+    py::array_t<uint64_t> reshaped_ptrs = ptrs.reshape({numel});
+    py::array_t<int8_t> reshaped_masks = mask.reshape({numel});
+    py::array reshaped_values = values.reshape({numel});
+    for (size_t i = 0; i < ptrs.size(); ++i) {
+      if (reshaped_masks.at(i)) {
+        memcpy(reinterpret_cast<void *>(reshaped_ptrs.mutable_at(i)),
+               reshaped_values.data(i), values.dtype().itemsize());
+      }
+    }
+  });
+}
+
void init_triton(py::module &m) {
  py::module subm = m.def_submodule("triton");
  init_triton_env_vars(subm);
  // init_triton_codegen(subm.def_submodule("code_gen"));
  init_triton_runtime(subm.def_submodule("runtime"));
  init_triton_ir(subm.def_submodule("ir"));
+  init_triton_interpreter(subm.def_submodule("interpreter"));
  init_triton_translation(subm);
}
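In plain terms, the new `load` binding is a masked gather (dereference each raw pointer where the mask is set, otherwise take the matching element of `other`) and `store` is a masked scatter (write through a pointer only where the mask is set). A minimal NumPy model of these semantics, representing memory as a flat array and pointers as integer offsets; the helper names here are illustrative, not the binding names:

```python
# NumPy model of the interpreter's masked load/store semantics above.
import numpy as np

def masked_load(memory, offsets, mask, other):
    # Gather memory[offsets] where mask is True, else fall back to `other`.
    return np.where(mask, memory[offsets], other)

def masked_store(memory, offsets, values, mask):
    # Scatter values through offsets, but only for lanes where mask is True;
    # masked-off lanes leave memory untouched.
    memory[offsets[mask]] = values[mask]

mem = np.zeros(8, dtype=np.float32)
offs = np.array([0, 2, 4])
masked_store(mem, offs, np.array([1., 2., 3.], dtype=np.float32),
             np.array([True, False, True]))
print(mem)  # [1. 0. 0. 0. 3. 0. 0. 0.]
print(masked_load(mem, offs, np.array([True, True, False]), -1.0))
# [ 1.  0. -1.]
```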
69 changes: 0 additions & 69 deletions python/test/unit/interpreter/test_interpreter.py

This file was deleted.

1 change: 0 additions & 1 deletion python/test/unit/language/test_core.py
@@ -2421,7 +2421,6 @@ def kernel(X, stride_xm, stride_xk,
    if epilogue == 'chain-dot':
        z_ref = np.matmul(z_ref, w)
    # compare
-    # print(z_ref[:,0], z_tri[:,0])
    if in_dtype == 'float32':
        # XXX: Somehow there's a larger difference when we use float32
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
11 changes: 6 additions & 5 deletions python/test/unit/operators/test_flash_attention.py
@@ -5,10 +5,10 @@
import triton.ops


-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
-                                                 (4, 48, 1024, 32),
-                                                 (4, 48, 1024, 64),
-                                                 (4, 48, 1024, 128)])
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16),
+                                                 (2, 4, 512, 32),
+                                                 (2, 4, 512, 64),
+                                                 (2, 4, 512, 128)])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('seq_par', [True, False])
@@ -21,7 +21,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
        pytest.skip('Segmentation fault')

    capability = torch.cuda.get_device_capability()
-    if capability[0] < 8:
+    interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]
+    if not interpreter and capability[0] < 8:
        pytest.skip("Flash attention only supported for compute capability >= 80")
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
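The gate added above lets the test proceed on machines without an Ampere-class GPU whenever `TRITON_INTERPRET` is set to a truthy value; the compute-capability check only applies when kernels are actually compiled. A standalone sketch of the same check (it assumes `import os` at the top of the test file, which this hunk does not show):

```python
# Same interpreter gate as in test_op above, extracted for clarity.
import os

def running_under_interpreter() -> bool:
    # Any of "on", "true", "1" enables interpreter mode; unset means disabled.
    return os.environ.get("TRITON_INTERPRET", "not found") in ["on", "true", "1"]

# The capability requirement only matters for real GPU compilation:
# skip the test when not interpreting and capability[0] < 8.
```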
Empty file.
9 changes: 0 additions & 9 deletions python/triton/interpreter/core.py

This file was deleted.

171 changes: 0 additions & 171 deletions python/triton/interpreter/interpreter.py

This file was deleted.
