diff --git a/aiter/jit/utils/cpp_extension.py b/aiter/jit/utils/cpp_extension.py index 531ce8755b..da5248132a 100644 --- a/aiter/jit/utils/cpp_extension.py +++ b/aiter/jit/utils/cpp_extension.py @@ -1667,6 +1667,7 @@ def sanitize_flags(flags): # Version 1.3 is required for the `deps` directive. config = ["ninja_required_version = 1.3"] config.append(f"cxx = {compiler}") + config.append(f"ar = {os.environ.get('AITER_AR_BIN', 'ar')}") if with_cuda or cuda_dlink_post_cflags: nvcc = _join_rocm_home("bin", "hipcc") config.append(f"nvcc = {nvcc}") @@ -1719,24 +1720,53 @@ def sanitize_flags(flags): else: devlink_rule, devlink = [], [] + aiter_build_static = os.environ.get("AITER_BUILD_STATIC", "0") + emit_static_archive = ( + library_target is not None and + aiter_build_static != "0" and + ( + aiter_build_static in ("1", "all") or + os.path.splitext(library_target)[0] in aiter_build_static.split(",") + ) + ) + if library_target is not None: link_rule = ["rule link"] link_rule.append( " command = $cxx @$out.rsp $ldflags -o $out\n rspfile = $out.rsp\n rspfile_content = $in" ) - link = [f'build {library_target}: link {" ".join(objects)}'] - default = [f"default {library_target}"] + if emit_static_archive: + archive_rule = ["rule archive"] + archive_rule.append( + " command = rm -f $out && $ar rcs $out @$out.rsp\n rspfile = $out.rsp\n rspfile_content = $in" + ) + + archive_target = os.path.splitext(library_target)[0] + ".a" + archive = [f'build {archive_target}: archive {" ".join(objects)}'] + default = [f"default {library_target} {archive_target}"] + else: + archive_rule, archive = [], [] + default = [f"default {library_target}"] else: - link_rule, link, default = [], [], [] + link_rule, archive_rule, link, archive, default = [], [], [], [], [] # 'Blocks' should be separated by newlines, for visual benefit. blocks = [config, flags, compile_rule] if with_cuda: blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined] - blocks += [devlink_rule, link_rule, build, devlink, link, default] + blocks += [ + devlink_rule, + link_rule, + archive_rule, + build, + devlink, + link, + archive, + default, + ] content = "\n\n".join("\n".join(b) for b in blocks) # Ninja requires a new lines at the end of the .ninja file content += "\n"