Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows build final #7

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 148 additions & 18 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,25 @@ jobs:
matrix-required: ${{ steps.set-matrix.outputs.matrix-required }}
matrix-optional: ${{ steps.set-matrix.outputs.matrix-optional }}
steps:
- name: Prepare runner matrix
- name: Prepare matrix
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix-required::[["self-hosted", "A100"], ["self-hosted", "H100"]]'
echo '::set-output name=matrix-optional::[["self-hosted", "gfx908"], ["self-hosted", "arc770"]]'
echo 'matrix-required={"runner": [["self-hosted", "A100"], ["self-hosted", "H100"]], "python-version": ["3.11"], "cuda-version": ["12.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT"
echo 'matrix-optional={"runner": [["self-hosted", "gfx908"], ["self-hosted", "arc770"]], "python-version": ["3.11"], "cuda-version": ["12.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT"
else
echo '::set-output name=matrix-required::["ubuntu-latest"]'
echo '::set-output name=matrix-optional::["ubuntu-latest"]'
echo 'matrix-required={"runner":["ubuntu-latest", "windows-latest"], "python-version": ["3.10", "3.11"], "cuda-version": ["12.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT"
echo 'matrix-optional={"runner":["ubuntu-latest", "windows-latest"], "python-version": ["3.10", "3.11"], "cuda-version": ["12.1"], "cc": ["clang"]}' >> "$GITHUB_OUTPUT"
fi

Integration-Tests:
needs: Runner-Preparation

runs-on: ${{ matrix.runner }}
timeout-minutes: 20
timeout-minutes: 60

strategy:
matrix:
runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-required)}}
matrix: ${{fromJson(needs.Runner-Preparation.outputs.matrix-required)}}

steps:
- name: Checkout
Expand All @@ -56,11 +55,119 @@ jobs:
echo "ENABLE_TMA=0" >> "${GITHUB_ENV}"
echo "TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}"

- name: Set up Python ${{ matrix.python-version }}
if: matrix.runner[0] == 'self-hosted'
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}

- name: Set up MSVC
if: matrix.runner == 'windows-latest'
uses: ilammy/[email protected]
with:
arch: amd64

- name: Setup Mambaforge (Windows)
if: matrix.runner == 'windows-latest'
uses: conda-incubator/setup-miniconda@v3
with:
miniforge-variant: Mambaforge
miniforge-version: latest
activate-environment: triton-env
use-mamba: true

- uses: conda-incubator/setup-miniconda@v3
if: matrix.runner == 'windows-latest'
with:
activate-environment: triton-env
environment-file: environment.yml
auto-activate-base: true
python-version: ${{ matrix.python-version }}

- name: set Environment Variables (Windows)
if: matrix.runner == 'windows-latest'
shell: bash -el {0}
run: |
LLVM_SHORTHASH="$(cat cmake/llvm-hash.txt | cut -c1-8)"
# prepare LLVM prebuilt path. will be downloaded and extracted by setup.py step
echo "~/.triton/llvm/llvm-$LLVM_SHORTHASH-windows-x64/bin" >> "$GITHUB_PATH"
# compile with a selected matrix.cc
if [ "${{matrix.cc}}" = "cl" ]; then
echo "CC=cl" >> "${GITHUB_ENV}"
echo "CXX=cl" >> "${GITHUB_ENV}"
elif [ "${{matrix.cc}}" = "clang" ]; then
echo "CC=clang" >> "${GITHUB_ENV}"
echo "CXX=clang++" >> "${GITHUB_ENV}"
fi

- name: CUDA toolkit ${{ matrix.cuda-version }}
shell: bash -el {0}
if: matrix.runner[0] != 'self-hosted'
run: |
if [ "${{ matrix.runner }}" = "ubuntu-latest" ]; then
# prepare space for ubuntu
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
fi

addon=""
cuda_version=${{ matrix.cuda-version }}
[ "$cuda_version" = "12.1" ] && cuda_version="12.1.1" && addon="cuda-cudart-static cuda-nvrtc"
[ "$cuda_version" = "11.8" ] && cuda_version="11.8.0"

conda install cuda-libraries-dev cuda-nvcc cuda-nvtx cuda-cupti cuda-cudart cuda-cudart-dev cuda-runtime cuda-libraries $addon -c "nvidia/label/cuda-$cuda_version"

- name: Get Date (Windows)
if: matrix.runner == 'windows-latest'
id: get-date
run: echo "today=$(date -u '+%Y%m%d')" >> $GITHUB_OUTPUT
shell: bash

- name: Cache conda env (Windows)
if: matrix.runner == 'windows-latest'
id: cache
uses: actions/cache@v3
env:
# Increase this value to reset cache if environment.yml has not changed
CACHE_NUMBER: 0
with:
path: ${{ env.CONDA }}/envs
key:
${{ matrix.runner }}--${{ steps.get-date.outputs.today }}--conda-${{ env.CACHE_NUMBER }}-cp${{ matrix.python-version }}-${{ hashFiles('environment.yml') }}


- name: Update conda environment (Windows)
if: ${{(matrix.runner == 'windows-latest')}}
shell: bash -el {0}
run: |
if [ "${{ steps.cache.outputs.cache-hit }}" != "true" ]; then
mamba env update -n triton-env -f environment.yml
cat environment.yml
fi

- name: Update environment
if: matrix.runner[0] != 'self-hosted'
shell: bash
run: |
echo "BACKEND=CUDA" >> "${GITHUB_ENV}"
echo "ENABLE_TMA=0" >> "${GITHUB_ENV}"
echo "TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}"


- name: Set reusable strings
# Turn repeated input strings (such as the build output directory) into step outputs. These step outputs can be used throughout the workflow file.
id: strings
shell: bash
run: |
echo "build-output-dir=${{ github.workspace }}/build" >> "$GITHUB_OUTPUT"

- name: Clear cache
shell: bash
run: |
rm -rf ~/.triton

- name: Update PATH
if: matrix.runner[0] == 'self-hosted'
run: |
echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}"

Expand All @@ -70,17 +177,29 @@ jobs:
python3 -m pre_commit run --all-files --verbose

- name: Install Triton
if: ${{ env.BACKEND == 'CUDA'}}
if: matrix.runner != 'windows-latest'
run: |
cd python
python3 -m pip install --upgrade pip
python3 -m pip install cmake==3.24 ninja pytest-xdist
python3 -m pip install cmake==3.24 ninja pytest-xdist wheel
sudo apt-get update -y
sudo apt-get install -y ccache clang lld
TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true python3 -m pip install --no-build-isolation -vvv '.[tests]'
if [ "${{ matrix.runner }}" = 'ubuntu-latest' ]; then
python3 setup.py bdist_wheel
fi

- name: Install Triton (Windows)
if: matrix.runner == 'windows-latest'
run: |
cd python
python3 -m pip install --upgrade pip
python3 -m pip install cmake==3.24 ninja pytest-xdist wheel
python3 -m pip install --no-build-isolation -vvv .
python3 setup.py bdist_wheel

- name: Run lit tests
if: ${{ env.BACKEND == 'CUDA'}}
if: matrix.runner[0] == 'self-hosted' && env.BACKEND == 'CUDA'
run: |
python3 -m pip install lit
cd python
Expand All @@ -96,7 +215,7 @@ jobs:
echo "ENABLE_TMA=1" >> "${GITHUB_ENV}"

- name: Run python tests on CUDA with ENABLE_TMA=1
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}}
if: ${{(matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
Expand All @@ -109,7 +228,7 @@ jobs:
python3 -m pytest hopper/test_flashattention.py

- name: Run python tests on CUDA with ENABLE_TMA=0
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}}
if: ${{(matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 --ignore=runtime --ignore=hopper --ignore=operators --ignore=language/test_line_info.py
Expand All @@ -119,10 +238,12 @@ jobs:
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py

- name: Clear cache
shell: bash
run: |
rm -rf ~/.triton

- name: Run interpreter tests
if: matrix.runner[0] == 'self-hosted'
env:
# TRITON_INTERPRET: "1"
CUA_VISIBLE_DEVICES: ""
Expand All @@ -131,17 +252,25 @@ jobs:
python3 -m pytest -vs operators/test_flash_attention.py

- name: Run partial tests on CUDA with ENABLE_TMA=1
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}}
if: ${{(matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 operators

- name: Run partial tests on CUDA with ENABLE_TMA=0
if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}}
if: ${{(matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}}
run: |
cd python/test/unit
python3 -m pytest -n 8 operators

- name: Upload Build artifacts
if: matrix.runner[0] != 'self-hosted'
uses: actions/upload-artifact@v3
with:
name: triton-dist ${{ matrix.runner }}
path: |
${{ github.workspace }}/python/dist/

- name: Create artifacts archive
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
run: |
Expand All @@ -150,20 +279,20 @@ jobs:

- name: Upload artifacts archive
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: artifacts ${{ matrix.runner[1] }}
path: ~/.triton/artifacts.tar.gz

- name: Run CXX unittests
if: ${{ env.BACKEND == 'CUDA'}}
if: ${{(matrix.runner[0] == 'self-hosted') && env.BACKEND == 'CUDA'}}
run: |
cd python
cd "build/$(ls build | grep -i cmake)"
ctest

- name: Regression tests
if: ${{ contains(matrix.runner, 'A100') }}
if: ${{ (matrix.runner[0] == 'self-hosted') && contains(matrix.runner, 'A100') }}
run: |
python3 -m pip install pytest-rerunfailures
cd python/test/regression
Expand All @@ -173,6 +302,7 @@ jobs:
sudo nvidia-smi -i 0 -rgc

Compare-artifacts:
if: ${{(github.repository == 'openai/triton')}}
needs: Integration-Tests
timeout-minutes: 5

Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
python/triton/_C/triton.dll

# Python caches
__pycache__/
Expand Down
50 changes: 35 additions & 15 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,17 @@ set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
# used conditionally in this file and by lit tests

# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
if(NOT MSVC)
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
else()
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1")
set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_STATIC_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
endif()

# Default build type
if(NOT CMAKE_BUILD_TYPE)
Expand All @@ -47,7 +56,15 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
# Third-party
include_directories(${PYBIND11_INCLUDE_DIR})

set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
if(NOT MSVC)
if(NOT WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-deprecated -fvisibility=hidden -fvisibility-inlines-hidden")
endif()
else()
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530")
endif()

if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
Expand All @@ -59,7 +76,7 @@ endif()
if(NOT MLIR_DIR)
if(NOT LLVM_LIBRARY_DIR)
if(WIN32)
find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
find_package(LLVM 17 REQUIRED COMPONENTS nvptx amdgpu)

include_directories(${LLVM_INCLUDE_DIRS})
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
Expand Down Expand Up @@ -154,6 +171,8 @@ if(TRITON_BUILD_PYTHON_MODULE)

if(PYTHON_INCLUDE_DIRS)
include_directories(${PYTHON_INCLUDE_DIRS})
message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}")
link_directories(${PYTHON_LIB_DIRS})
else()
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
include_directories(${Python3_INCLUDE_DIRS})
Expand All @@ -163,16 +182,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
endif()
endif()

# # Triton
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
# set_target_properties(triton PROPERTIES PREFIX "lib")
# else()
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# endif()

# MLIR
find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR})

Expand All @@ -184,7 +193,11 @@ include(AddLLVM)
include(AddMLIR)

# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default")
if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX-")
endif()

include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
Expand Down Expand Up @@ -239,6 +252,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS}
${TRITON_LIBRARIES}
)
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
set_target_properties(triton PROPERTIES PREFIX "lib")
elseif(APPLE)
target_link_libraries(triton ${LLVM_LIBRARIES} z
${TRITON_LIBRARIES}
Expand Down Expand Up @@ -275,6 +290,11 @@ if (${CODEGEN_BACKENDS_LEN} GREATER 0)
endforeach()
endif()

if(WIN32)
option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON)
option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON)
endif()

add_subdirectory(test)

add_subdirectory(unittest)
1 change: 1 addition & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ llvm_update_compile_flags(triton-translate)
mlir_check_all_link_libraries(triton-translate)

add_llvm_executable(triton-llvm-opt
PARTIAL_SOURCES_INTENDED
triton-llvm-opt.cpp

DEPENDS
Expand Down
Loading
Loading