diff --git a/CMakeLists.txt b/CMakeLists.txt index 6b7eee31d5b8..888f53c8cece 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -260,7 +260,7 @@ endfunction() # Disable warnings that show up in external code (gtest;pybind11) if(NOT MSVC) set(TRITON_DISABLE_EH_RTTI_FLAGS "$<$:-fno-exceptions;-fno-rtti>") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default") else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4244 /wd4624 /wd4715 /wd4530") endif() @@ -409,6 +409,12 @@ if(TRITON_BUILD_PYTHON_MODULE) # Link triton with its dependencies target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES}) + + # Do not propagate libraries that libtriton depends on. This ensures that + # targets that link against libtriton do not accidentally link in their own + # copies of core Triton code and LLVM. + set_target_properties(triton PROPERTIES INTERFACE_LINK_LIBRARIES "") + if(WIN32) target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS}) set_target_properties(triton PROPERTIES SUFFIX ".pyd") @@ -431,10 +437,8 @@ if(TRITON_BUILD_PYTHON_MODULE) "${TRITON_WHEEL_DIR}/FileCheck" COPYONLY) -endif() - -if (UNIX AND NOT APPLE) - set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL") + # Only build plugins when building libtriton since they depend on libtriton. + add_subdirectory(examples/plugins) endif() if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) @@ -460,7 +464,6 @@ find_package(Threads REQUIRED) add_subdirectory(third_party/f2reduce) add_subdirectory(bin) add_subdirectory(test) -add_subdirectory(examples) if(TRITON_BUILD_UT) add_subdirectory(unittest) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt deleted file mode 100644 index 0e89371e07e6..000000000000 --- a/examples/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(plugins) diff --git a/examples/plugins/CMakeLists.txt b/examples/plugins/CMakeLists.txt index 6b4a14952b7b..e89bd44b41eb 100644 --- a/examples/plugins/CMakeLists.txt +++ b/examples/plugins/CMakeLists.txt @@ -24,7 +24,7 @@ foreach( plugin ${TRITON_PLUGIN_PASSES} ) TritonCanonicalizeIncGen TritonPluginsIncGen ) - target_link_libraries(${plugin} PRIVATE MLIRPass) + target_link_libraries(${plugin} PRIVATE triton) # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python # build. It is empty if building directly from the root diff --git a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt index 2e0271800053..aa1be10bb2db 100644 --- a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt +++ b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt @@ -20,10 +20,7 @@ add_mlir_dialect_library(MLIRDialectPlugin MLIRDialectPluginPassesIncGen LINK_LIBS PUBLIC - MLIRPass - LLVMSupport - MLIRSupport - TritonNVIDIAGPUToLLVM + triton "$<$:-undefined dynamic_lookup>" ) diff --git a/python/test/unit/plugins/test_dialect_plugin.py b/python/test/unit/plugins/test_dialect_plugin.py index 55ba6c36950f..279e3107130b 100644 --- a/python/test/unit/plugins/test_dialect_plugin.py +++ b/python/test/unit/plugins/test_dialect_plugin.py @@ -2,15 +2,13 @@ import subprocess import pathlib import pytest +import re -from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2 - -pytestmark = pytest.mark.skipif(is_hip_cdna2(), reason="old AMD GPUs are not supported") +from triton._internal_testing import is_cuda, is_hip +@pytest.mark.skipif(is_hip(), reason="plugin not supported/tested on AMD yet") def test_override(tmp_path: pathlib.Path): - if os.environ.get('LLVM_BUILD_SHARED_LIBS', '0') == '0': - return dir_path = os.path.dirname(os.path.realpath(__file__)) # Run once to get the file dumps @@ -40,9 +38,6 @@ def test_override(tmp_path: pathlib.Path): os.remove(ptx_files[0]) os.remove(cubin_files[0]) - if is_hip(): - pytest.skip("plugin not supported/tested on AMD yet") - filename = str(list(tmp_path.rglob("*.ttir"))[0]) with open(filename, "r") as infile: @@ -51,8 +46,10 @@ def test_override(tmp_path: pathlib.Path): # # Add ttgir instrumentation with open(filename, "w") as outfile: for line in file_str: - if "tt.get_program_id x" in line: - line = ' %pid_base = arith.constant 0 : i32\n %pid = plugin.magic %pid_base : i32\n' + match = re.search(r'(%\w+)\s*=\s*tt\.get_program_id\s+x', line) + if match: + ssa_name = match.group(1) + line = f' %pid_base = arith.constant 0 : i32\n {ssa_name} = plugin.magic %pid_base : i32\n' outfile.write(line) # # # Run again with kernel override diff --git a/python/test/unit/plugins/test_plugin.py b/python/test/unit/plugins/test_plugin.py index 9a895174b1b8..3687420a4a6e 100644 --- a/python/test/unit/plugins/test_plugin.py +++ b/python/test/unit/plugins/test_plugin.py @@ -1,11 +1,11 @@ import torch import pytest -import os import triton import triton.language as tl from triton import knobs +from triton._internal_testing import is_hip import custom_stages @@ -21,10 +21,8 @@ def kernel2(BLOCK_SIZE: tl.constexpr): return +@pytest.mark.skipif(is_hip(), reason="plugin not supported/tested on AMD yet") def test_op(capfd, device: str): - if os.environ.get('LLVM_BUILD_SHARED_LIBS', '0') == '0': - return - size = 98432 x = torch.rand(size, device=device) output = torch.empty_like(x)