diff --git a/xlstm/blocks/slstm/src/cuda_init.py b/xlstm/blocks/slstm/src/cuda_init.py index bc02065..de5366d 100644 --- a/xlstm/blocks/slstm/src/cuda_init.py +++ b/xlstm/blocks/slstm/src/cuda_init.py @@ -58,6 +58,9 @@ def defines_to_cflags(defines=Union[dict[str, Union[int, str]], Sequence[tuple[s def load(*, name, sources, extra_cflags=(), extra_cuda_cflags=(), **kwargs): + import platform + system_name = platform.system() + suffix = "" for flag in extra_cflags: pref = [st[0] for st in flag[2:].split("=")[0].split("_")] @@ -73,6 +76,21 @@ def load(*, name, sources, extra_cflags=(), extra_cuda_cflags=(), **kwargs): suffix = "_" + suffix suffix = suffix[:64] + # Determine platform-specific linker flags + extra_ldflags = ( + [f"/LIBPATH:{os.environ['CUDA_LIB']}", "cublas.lib"] + if system_name == "Windows" + else [f"-L{os.environ['CUDA_LIB']}", "-lcublas"] + ) + + # Handle -Xptxas flags platform-specifically + if system_name == "Windows": + xptxas_verbose = ["-Xptxas", "-v"] # must be split for nvcc on Windows + xptxas_opt = ["-Xptxas", "-O3"] + else: + xptxas_verbose = ['-Xptxas="-v"'] # works fine quoted on Linux + xptxas_opt = ['-Xptxas -O3'] + extra_cflags = list(extra_cflags) + [ "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", @@ -86,26 +104,23 @@ def load(*, name, sources, extra_cflags=(), extra_cuda_cflags=(), **kwargs): extra_cflags.append("-isystem") extra_cflags.append(eip) + # Compose full set of CUDA flags + extra_cuda_flags = [ + *xptxas_verbose, + "-gencode", "arch=compute_80,code=compute_80", + "-res-usage", + "--use_fast_math", + "-O3", + *xptxas_opt, + "--extra-device-vectorization", + ] + myargs = { "verbose": True, "with_cuda": True, - "extra_ldflags": [f"-L{os.environ['CUDA_LIB']}", "-lcublas"], + "extra_ldflags": extra_ldflags, "extra_cflags": [*extra_cflags], - "extra_cuda_cflags": [ - # "-gencode", - # "arch=compute_70,code=compute_70", - # "-dbg=1", - '-Xptxas="-v"', - "-gencode", - "arch=compute_80,code=compute_80", - "-res-usage", - "--use_fast_math", - "-O3", - "-Xptxas -O3", - "--extra-device-vectorization", - *extra_cflags, - *extra_cuda_cflags, - ], + "extra_cuda_cflags": [*extra_cuda_flags, *extra_cflags, *extra_cuda_cflags], } print(myargs) myargs.update(**kwargs)