diff --git a/.gitignore b/.gitignore
index 7c5081621fc3..c89206f736e3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ python/build/
 python/triton.egg-info/
 python/triton/_C/libtriton.pyd
 python/triton/_C/libtriton.so
+python/triton/_C/triton.dll
 
 # Backends copied from submodules
 python/triton/backends/
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 086d67dafaa2..fa4d0fe28716 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -30,8 +30,17 @@ set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
 # used conditionally in this file and by lit tests
 
 # Customized release build type with assertions: TritonRelBuildWithAsserts
-set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
-set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
+if(NOT MSVC)
+  set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
+  set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
+else()
+  set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1")
+  set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1")
+  set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
+  set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
+  set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
+  set(CMAKE_STATIC_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
+endif()
 
 # Default build type
 if(NOT CMAKE_BUILD_TYPE)
@@ -45,7 +54,15 @@ endif()
 
 # Compiler flags
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
-set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
+if(NOT MSVC)
+  if(NOT WIN32)
+    set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
+  else()
+    set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-deprecated")
+  endif()
+else()
+  set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530")
+endif()
 
 # Third-party
 include_directories(${PYBIND11_INCLUDE_DIR})
@@ -103,7 +120,11 @@ endfunction()
 
 
 # Disable warnings that show up in external code (gtest;pybind11)
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
+if(NOT MSVC)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
+else()
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX-")
+endif()
 
 include_directories(".")
 include_directories(${MLIR_INCLUDE_DIRS})
@@ -139,6 +160,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
 
   if(PYTHON_INCLUDE_DIRS)
     include_directories(${PYTHON_INCLUDE_DIRS})
+    message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}")
+    link_directories(${PYTHON_LIB_DIRS})
   else()
     find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
     include_directories(${Python3_INCLUDE_DIRS})
@@ -238,6 +261,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
   target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES})
   if(WIN32)
     target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS})
+    set_target_properties(triton PROPERTIES SUFFIX ".pyd")
+    set_target_properties(triton PROPERTIES PREFIX "lib")
   else()
     target_link_libraries(triton PRIVATE z)
   endif()
@@ -255,6 +280,11 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
   target_link_libraries(triton PRIVATE ${PYTHON_LDFLAGS})
 endif()
 
+if(WIN32)
+  option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON)
+  option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON)
+endif()
+
 add_subdirectory(bin)
 add_subdirectory(test)
 add_subdirectory(unittest)
diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt
index 9acab3da1f9b..2f5880e0c5e7 100644
--- a/bin/CMakeLists.txt
+++ b/bin/CMakeLists.txt
@@ -71,6 +71,7 @@ mlir_check_all_link_libraries(triton-lsp)
 
 
 add_llvm_executable(triton-llvm-opt
+  PARTIAL_SOURCES_INTENDED
   triton-llvm-opt.cpp
 
   DEPENDS
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 67665b7b1483..c6752f2401ce 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1433,6 +1433,7 @@ MfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
     return {32, parentShapePerCTA[1]};
   } else {
     assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
+    return {};
   }
 }
 
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
index ea038eead67f..9a77c480b432 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -913,8 +913,8 @@ void mlir::triton::asyncLaunchDots(scf::ForOp forOp) {
       lastOp = op;
       op = op->getBlock()->getParentOp();
     }
-    return std::distance(lastOp->getBlock()->getParent()->begin(),
-                         lastOp->getBlock()->getIterator());
+    return (long)std::distance(lastOp->getBlock()->getParent()->begin(),
+                               lastOp->getBlock()->getIterator());
   };
   /// XXX(Keren): Clean up the following duplicate code with checkDotOp
   /// dots to be pipelined
diff --git a/python/setup.py b/python/setup.py
index bca5f258beb5..667c3729870b 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -184,7 +184,7 @@ def get_thirdparty_packages():
         if p.syspath_var_name in os.environ:
             package_dir = os.environ[p.syspath_var_name]
         version_file_path = os.path.join(package_dir, "version.txt")
-        if p.syspath_var_name not in os.environ and\
+        if p.syspath_var_name not in os.environ and p.url and\
           (not os.path.exists(version_file_path) or Path(version_file_path).read_text() != p.url):
             try:
                 shutil.rmtree(package_root_dir)
@@ -197,6 +197,9 @@ def get_thirdparty_packages():
             # write version url to package_dir
             with open(os.path.join(package_dir, "version.txt"), "w") as f:
                 f.write(p.url)
+        elif p.syspath_var_name not in os.environ and not p.url:
+            raise RuntimeError(
+                f'{p.syspath_var_name} not set ! Please install {p.package} manually and set {p.syspath_var_name}.')
         if p.include_flag:
             thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include")
         if p.lib_flag:
@@ -210,13 +213,17 @@ def download_and_copy(src_path, variable, version, url_func):
         return
     base_dir = os.path.dirname(__file__)
     system = platform.system()
-    arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
-    url = url_func(arch, version)
+    arch = {"x86_64": "64", "AMD64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
+    supported = {"Linux": "linux", "Windows": "win"}
+    is_supported = system in supported
+    if is_supported:
+        url = url_func(supported[system], arch, version)
     tmp_path = os.path.join(triton_cache_path, "nvidia")  # path to cache the download
     dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", src_path)  # final binary path
     src_path = os.path.join(tmp_path, src_path)
+    src_path += ".exe" if os.name == "nt" else ""
     download = not os.path.exists(src_path)
-    if os.path.exists(dst_path) and system == "Linux":
+    if os.path.exists(dst_path) and is_supported:
         curr_version = subprocess.check_output([dst_path, "--version"]).decode("utf-8").strip()
         curr_version = re.search(r"V([.|\d]+)", curr_version).group(1)
         download = download or curr_version != version
@@ -316,6 +323,10 @@ def build_extension(self, ext):
             "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
             "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
         ]
+        if platform.system() == "Windows":
+            installed_base = sysconfig.get_config_var('installed_base')
+            py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))
+            cmake_args.append("-DPYTHON_LIB_DIRS=" + py_lib_dirs)
         if lit_dir is not None:
             cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
         cmake_args.extend(thirdparty_cmake_args)
@@ -325,9 +336,8 @@ def build_extension(self, ext):
         build_args = ["--config", cfg]
 
         if platform.system() == "Windows":
+            cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
             cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
-            if sys.maxsize > 2**32:
-                cmake_args += ["-A", "x64"]
         else:
             cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
         max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count()))
@@ -369,29 +379,32 @@ def build_extension(self, ext):
             subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
 
 
-download_and_copy(
-    src_path="bin/ptxas",
-    variable="TRITON_PTXAS_PATH",
-    version="12.3.52",
-    url_func=lambda arch, version:
-    f"https://anaconda.org/nvidia/cuda-nvcc/12.3.52/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2",
-)
-download_and_copy(
-    src_path="bin/cuobjdump",
-    variable="TRITON_CUOBJDUMP_PATH",
-    version="12.3.52",
-    url_func=lambda arch, version:
-    f"https://anaconda.org/nvidia/cuda-cuobjdump/12.3.52/download/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
-)
-download_and_copy(
-    src_path="bin/nvdisasm",
-    variable="TRITON_NVDISASM_PATH",
-    version="12.3.52",
-    url_func=lambda arch, version:
-    f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
-)
-
-backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]
+if platform.system() in ["Linux", "Windows"]:
+    download_and_copy(
+        src_path="bin/ptxas",
+        variable="TRITON_PTXAS_PATH",
+        version="12.3.52",
+        url_func=lambda system, arch, version:
+        f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2",
+    )
+    download_and_copy(
+        src_path="bin/cuobjdump",
+        variable="TRITON_CUOBJDUMP_PATH",
+        version="12.3.52",
+        url_func=lambda system, arch, version:
+        f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/{system}-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
+    )
+    download_and_copy(
+        src_path="bin/nvdisasm",
+        variable="TRITON_NVDISASM_PATH",
+        version="12.3.52",
+        url_func=lambda system, arch, version:
+        f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/{system}-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
+    )
+backends = ["nvidia", "amd"]
+if os.name == "nt":
+    backends = ["nvidia"]
+backends = [*BackendInstaller.copy(backends), *BackendInstaller.copy_externals()]
 
 
 def add_link_to_backends():
diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py
index 3bebd3ac4e09..dc819a53ac79 100644
--- a/python/test/unit/runtime/test_cache.py
+++ b/python/test/unit/runtime/test_cache.py
@@ -70,7 +70,7 @@ def test_nested1_change():
 
 
 def write_and_load_module(code, num_extra_lines):
-    with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f:
+    with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as f:
        f.write(('# extra line\n' * num_extra_lines) + code)
        f.flush()
        spec = importlib.util.spec_from_file_location("module.name", f.name)
diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py
index b9b9c7d3cab8..71ffa6d253e7 100644
--- a/python/test/unit/runtime/test_subproc.py
+++ b/python/test/unit/runtime/test_subproc.py
@@ -34,11 +34,15 @@ def kernel_sub(a, b, o, N: tl.constexpr):
 
 
 def test_compile_in_subproc() -> None:
+    import os
     major, minor = torch.cuda.get_device_capability(0)
     cc = major * 10 + minor
     config = triton.compiler.AttrsDescriptor(tuple(range(4)), ())
 
-    multiprocessing.set_start_method('fork')
+    if os.name == "nt":
+        multiprocessing.set_start_method('spawn')
+    else:
+        multiprocessing.set_start_method('fork')
     proc = multiprocessing.Process(target=compile_fn, args=(config, cc))
     proc.start()
     proc.join()
@@ -64,7 +68,7 @@ def test_compile_in_forked_subproc() -> None:
     capability = major * 10 + minor
     config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
 
-    assert multiprocessing.get_start_method() == 'fork'
+    assert multiprocessing.get_start_method() in ['fork', 'spawn']
     proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability))
     proc.start()
     proc.join()
diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py
index 0655b3fa5ade..12a9ba0064d1 100644
--- a/python/triton/backends/compiler.py
+++ b/python/triton/backends/compiler.py
@@ -12,13 +12,17 @@ def __init__(self, target: tuple) -> None:
 
     @staticmethod
     def _path_to_binary(binary: str):
+        binary += ".exe" if os.name == "nt" else ""
         base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
         paths = [
             os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
             os.path.join(base_dir, "third_party", "cuda", "bin", binary),
         ]
         for p in paths:
-            bin = p.split(" ")[0]
+            if os.name != "nt":
+                bin = p.split(" ")[0]
+            else:
+                bin = p
             if os.path.exists(bin) and os.path.isfile(bin):
                 result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
                 if result is not None:
diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py
index c0965b17bb60..d179ef27058e 100644
--- a/python/triton/compiler/compiler.py
+++ b/python/triton/compiler/compiler.py
@@ -154,7 +154,8 @@ def triton_key():
         contents += [hashlib.sha1(f.read()).hexdigest()]
     # backend
     libtriton_hash = hashlib.sha1()
-    with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
+    ext = "so" if os.name != "nt" else "pyd"
+    with open(os.path.join(TRITON_PATH, "_C", "libtriton." + ext), "rb") as f:
         while True:
             chunk = f.read(1024**2)
             if not chunk:
diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py
index d7baeb2868b0..9726836a2939 100644
--- a/python/triton/runtime/build.py
+++ b/python/triton/runtime/build.py
@@ -18,6 +18,26 @@ def quiet():
         sys.stdout, sys.stderr = old_stdout, old_stderr
 
 
+def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
+    if cc in ["cl", "clang-cl"]:
+        cc_cmd = [cc, src, "/nologo", "/O2", "/LD"]
+        cc_cmd += [f"/I{dir}" for dir in include_dirs]
+        cc_cmd += ["/link"]
+        cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs]
+        cc_cmd += [f'{lib}.lib' for lib in libraries]
+        cc_cmd += [f"/OUT:{out}"]
+    else:
+        cc_cmd = [cc, src, "-O3", "-shared", "-fPIC"]
+        cc_cmd += [f'-l{lib}' for lib in libraries]
+        cc_cmd += [f"-L{dir}" for dir in library_dirs]
+        cc_cmd += [f"-I{dir}" for dir in include_dirs]
+        cc_cmd += ["-o", out]
+
+        if os.name == "nt": cc_cmd.pop(cc_cmd.index("-fPIC"))
+
+    return cc_cmd
+
+
 def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
     suffix = sysconfig.get_config_var('EXT_SUFFIX')
     so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
@@ -41,10 +61,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
         scheme = 'posix_prefix'
     py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
     include_dirs = include_dirs + [srcdir, py_include_dir]
-    cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
-    cc_cmd += [f'-l{lib}' for lib in libraries]
-    cc_cmd += [f"-L{dir}" for dir in library_dirs]
-    cc_cmd += [f"-I{dir}" for dir in include_dirs]
+    cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
     ret = subprocess.check_call(cc_cmd)
     if ret == 0:
         return so
@@ -58,7 +75,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
         language='c',
         sources=[src],
         include_dirs=include_dirs,
-        extra_compile_args=extra_compile_args + ['-O3'],
+        extra_compile_args=extra_compile_args + ['-O3' if "-O3" in cc_cmd else "/O2"],
        extra_link_args=extra_link_args,
        library_dirs=library_dirs,
        libraries=libraries,
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 5823c14af033..f3219ad6805c 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -14,13 +14,17 @@ from pathlib import Path
 
 
 def _path_to_binary(binary: str):
+    binary += ".exe" if os.name == "nt" else ""
     paths = [
         os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
         os.path.join(os.path.dirname(__file__), "bin", binary),
     ]
 
     for p in paths:
-        bin = p.split(" ")[0]
+        if os.name != "nt":
+            bin = p.split(" ")[0]
+        else:
+            bin = p
         if os.path.exists(bin) and os.path.isfile(bin):
             result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
             if result is not None:
@@ -226,16 +230,19 @@ def make_cubin(src, metadata, opt, capability):
             fsrc.flush()
             fbin = fsrc.name + '.o'
 
-            line_info = '' if os.environ.get('TRITON_DISABLE_LINE_INFO') else ' -lineinfo'
-            fmad = '' if opt.enable_fp_fusion else ' --fmad=false'
-            suffix = 'a ' if capability == 90 else ' '
+            line_info = '' if os.environ.get('TRITON_DISABLE_LINE_INFO') else '-lineinfo'
+            fmad = '' if opt.enable_fp_fusion else '--fmad=false'
+            suffix = 'a' if capability == 90 else ''
+            cmd = [ptxas]
+            cmd += [line_info] if line_info != '' else []
+            cmd += [fmad] if fmad != '' else []
+            cmd += ['-v']
             if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1":
-                cmd = f'{ptxas}{line_info}{fmad} -v --opt-level 0 --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}'
-            else:
-                cmd = f'{ptxas}{line_info}{fmad} -v --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}'
+                cmd += ["-opt-level", "0"]
+            cmd += [f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin]
 
             try:
-                subprocess.run(cmd, shell=True, check=True)
+                subprocess.run(cmd, check=True, stderr=flog)
             except subprocess.CalledProcessError as e:
                 with open(flog.name) as log_file:
                     log = log_file.read()
@@ -246,16 +253,17 @@ def make_cubin(src, metadata, opt, capability):
                         f'Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}')
                 else:
                     raise RuntimeError(f'`ptxas` failed with error code {e.returncode}: \n{log}')
-            finally:
-                if os.path.exists(fsrc.name):
-                    os.remove(fsrc.name)
-                if os.path.exists(flog.name):
-                    os.remove(flog.name)
 
             with open(fbin, 'rb') as f:
                 cubin = f.read()
             if os.path.exists(fbin):
                 os.remove(fbin)
+
+            if os.path.exists(fsrc.name):
+                os.remove(fsrc.name)
+            if os.path.exists(flog.name):
+                os.remove(flog.name)
+
             return cubin
 
     def add_stages(self, stages, options):
diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c
index 037cfca1d0e4..64eaf0771862 100644
--- a/third_party/nvidia/backend/driver.c
+++ b/third_party/nvidia/backend/driver.c
@@ -1,5 +1,10 @@
 #include "cuda.h"
+#ifndef _WIN32
 #include <dlfcn.h>
+#else
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#endif
 #include <stdbool.h>
 #define PY_SSIZE_T_CLEAN
 #include <Python.h>
@@ -135,6 +140,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
 typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
     int *numClusters, CUfunction func, const CUlaunchConfig *config);
 
+#ifndef _WIN32
 #define defineGetFunctionHandle(name, symbolName)                             \
   static symbolName##_t name() {                                              \
     /* Open the shared library */                                             \
@@ -156,6 +162,27 @@ typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
     }                                                                         \
     return funcHandle;                                                        \
   }
+#else
+#define defineGetFunctionHandle(name, symbolName)                             \
+  static symbolName##_t name() {                                              \
+    /* Open the shared library */                                             \
+    HMODULE handle = LoadLibraryA("nvcuda.dll");                              \
+    if (!handle) {                                                            \
+      PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");       \
+      return NULL;                                                            \
+    }                                                                         \
+    symbolName##_t funcHandle =                                               \
+        (symbolName##_t)GetProcAddress((HMODULE)handle, #symbolName);         \
+    /* Check for errors */                                                    \
+    long err = GetLastError();                                                \
+    if (err) {                                                                \
+      PyErr_SetString(PyExc_RuntimeError,                                     \
+                      "Failed to retrieve " #symbolName " from nvcuda.dll");  \
+      return NULL;                                                            \
+    }                                                                         \
+    return funcHandle;                                                        \
+  }
+#endif
 
 defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
                         cuOccupancyMaxActiveClusters);
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 5bd178f09028..e6bd6b5758cf 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -1,6 +1,7 @@
 import functools
 import os
 import hashlib
+import sysconfig
 import subprocess
 import tempfile
 from pathlib import Path
@@ -13,12 +14,20 @@
 libdevice_dir = os.path.join(dirname, "lib")
 libraries = ['cuda']
 
+if os.name == "nt":
+    include_dir += [os.path.join(os.environ.get("CUDA_PATH"), "include")]
+
 
 @functools.lru_cache()
 def libcuda_dirs():
     env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH")
     if env_libcuda_path:
         return [env_libcuda_path]
+    if os.name == "nt":
+        installed_base = sysconfig.get_config_var('installed_base')
+        dirs = [os.path.join(os.environ.get("CUDA_PATH"), "lib", "x64")]
+        dirs += [os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))]
+        return dirs
 
     libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
     # each line looks like the following:
@@ -47,7 +56,8 @@ def library_dirs():
 def compile_module_from_src(src, name):
     key = hashlib.md5(src.encode("utf-8")).hexdigest()
     cache = get_cache_manager(key)
-    cache_path = cache.get_file(f"{name}.so")
+    so_name = f'{name}.{"so" if os.name != "nt" else "pyd"}'
+    cache_path = cache.get_file(so_name)
     if cache_path is None:
         with tempfile.TemporaryDirectory() as tmpdir:
             src_path = os.path.join(tmpdir, "main.c")
@@ -55,7 +65,7 @@ def compile_module_from_src(src, name):
                 f.write(src)
             so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
             with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
+                cache_path = cache.put(f.read(), so_name, binary=True)
     import importlib.util
     spec = importlib.util.spec_from_file_location(name, cache_path)
     mod = importlib.util.module_from_spec(spec)
@@ -147,7 +157,12 @@ def format_of(ty):
 #include \"cuda.h\"
 #include <stdbool.h>
 #include <Python.h>
+#ifndef _WIN32
 #include <dlfcn.h>
+#else
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#endif
 
 static inline void gpuAssert(CUresult code, const char *file, int line)
 {{
@@ -170,6 +185,7 @@ def format_of(ty):
 typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f,
                                        void** kernelParams, void** extra);
 
+#ifndef _WIN32
 static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
   // Open the shared library
   void* handle = dlopen("libcuda.so", RTLD_LAZY);
@@ -188,6 +204,25 @@ def format_of(ty):
   }}
   return cuLaunchKernelExHandle;
 }}
+#else
+static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
+  // Open the shared library
+  HMODULE handle = LoadLibraryA("nvcuda.dll");
+  if (!handle) {{
+    PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll");
+    return NULL;
+  }}
+  cuLaunchKernelEx_t cuLaunchKernelExHandle =
+      (cuLaunchKernelEx_t)GetProcAddress((HMODULE)handle, "cuLaunchKernelEx");
+  // Check for errors
+  long error = GetLastError();
+  if (error) {{
+    PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from nvcuda.dll");
+    return NULL;
+  }}
+  return cuLaunchKernelExHandle;
+}}
+#endif
 
 static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY,
                     int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }};