diff --git a/python/test/unit/intel/test_driver.py b/python/test/unit/intel/test_driver.py
index b43dae007e..bc77ed6dc2 100644
--- a/python/test/unit/intel/test_driver.py
+++ b/python/test/unit/intel/test_driver.py
@@ -99,3 +99,60 @@ def test_has_opencl_extension_error(device):
     # Pass an invalid device_id (out of range) to trigger error
     with pytest.raises(RuntimeError, match="Device is not found"):
         driver.active.utils.has_opencl_extension(device_count, b"cl_khr_fp16")
+
+
+@pytest.mark.parametrize("grf_mode, expect_retry, expect_fail",
+                         [("default", True, False),  # Should auto-retry with large GRF and succeed
+                          ("256", False, False),  # Explicit large GRF — compiles on first attempt
+                          ("128", False, True),  # Explicit small GRF — should fail, no retry
+                          ])
+@pytest.mark.parametrize("generate_native_code", [False, True], ids=["load_binary", "make_zebin"])
+def test_auto_grf_on_build_failure(device, monkeypatch, capfd, grf_mode, expect_retry, expect_fail,
+                                   generate_native_code):
+    """Test GRF mode behavior for register-heavy kernels on both compilation paths:
+    - load_binary (generate_native_code=False): L0 runtime compilation via zeModuleCreate
+    - make_zebin (generate_native_code=True): offline compilation via ocloc
+    """
+    monkeypatch.setenv("TRITON_DEBUG", "1")
+
+    @triton.jit
+    def _register_heavy_kernel(
+        output_ptr,
+        input_ptr,
+        q_ptr,
+        size,
+        BLOCK: tl.constexpr,
+    ):
+        off = tl.arange(0, BLOCK)
+        mask = off < size
+        x = tl.load(input_ptr + off, mask=mask, other=0.0)
+        q = tl.load(q_ptr + off, mask=mask, other=float("-inf"))
+        result = tl.argmax(x / q, axis=-1)
+        tl.store(output_ptr, result)
+
+    BLOCK = 131072  # Large enough to exceed PTSS with default/small GRF
+    size = 128000
+
+    x = torch.randn(size, dtype=torch.float32, device=device)
+    q = torch.rand(size, dtype=torch.float32, device=device)
+    out = torch.empty(1, dtype=torch.int32, device=device)
+
+    if expect_fail:
+        with pytest.raises(RuntimeError):
+            _register_heavy_kernel[(1, )](out, x, q, size, BLOCK=BLOCK, grf_mode=grf_mode,
+                                          generate_native_code=generate_native_code)
+    else:
+        _register_heavy_kernel[(1, )](out, x, q, size, BLOCK=BLOCK, grf_mode=grf_mode,
+                                      generate_native_code=generate_native_code)
+
+    outs = capfd.readouterr().out
+    if expect_retry and not generate_native_code:
+        # load_binary path prints a retry message to stdout.
+        assert "retrying with large GRF mode" in outs or "recompiling the kernel using large GRF mode" in outs
+    elif expect_retry and generate_native_code:
+        # make_zebin path retries silently via ocloc — no stdout message.
+        # Success without exception is sufficient verification.
+        pass
+    else:
+        assert "retrying with large GRF mode" not in outs
+        assert "Build failed" not in outs
diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index 0a95990f63..7cdcf50167 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -495,16 +495,39 @@ def make_zebin(cls, src, metadata, options):
                         ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
                         subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
                 except subprocess.CalledProcessError as e:
-                    if e.returncode == 255:
-                        error = 'Internal Triton ZEBIN codegen error'
-                    elif e.returncode == 128 + signal.SIGSEGV:
-                        error = '`ocloc` raised SIGSEGV'
-                    else:
-                        error = f'`ocloc` failed with error code {e.returncode}'
-
-                    raise RuntimeError(f'{error}\n'
-                                       f'`ocloc` stderr:\n{e.output}\n'
-                                       f'Repro command: {ocloc_cmd}\n') from e
+                    # If GRF mode was not explicitly set, retry with large GRF mode
+                    # before giving up. This handles cases where the default GRF mode
+                    # doesn't provide enough registers (e.g., scratch space exceeds
+                    # HW PTSS limit).
+                    large_grf_flag = "-cl-intel-256-GRF-per-thread"
+                    retry_succeeded = False
+                    if options.grf_mode == 'default' and large_grf_flag not in metadata["build_flags"]:
+                        # Build the retry command on the side: if the retry also fails,
+                        # the error raised below must still report the command that
+                        # produced `e`, and `metadata` must be left untouched.
+                        retry_flags = f'{metadata["build_flags"]} {large_grf_flag}'
+                        retry_cmd = ocloc_cmd[:-1] + [retry_flags + shader_dump_opt]
+                        try:
+                            subprocess.check_output(retry_cmd, stderr=subprocess.STDOUT, text=True)
+                        except subprocess.CalledProcessError:
+                            # Retry also failed — raise the original error below.
+                            pass
+                        else:
+                            metadata["build_flags"] = retry_flags
+                            ocloc_cmd[-1] = retry_cmd[-1]
+                            retry_succeeded = True
+
+                    if not retry_succeeded:
+                        if e.returncode == 255:
+                            error = 'Internal Triton ZEBIN codegen error'
+                        elif e.returncode == 128 + signal.SIGSEGV:
+                            error = '`ocloc` raised SIGSEGV'
+                        else:
+                            error = f'`ocloc` failed with error code {e.returncode}'
+
+                        raise RuntimeError(f'{error}\n'
+                                           f'`ocloc` stderr:\n{e.output}\n'
+                                           f'Repro command: {ocloc_cmd}\n') from e
 
                 with open(fbin, 'rb') as f:
                     zebin = f.read()
diff --git a/third_party/intel/backend/driver.c b/third_party/intel/backend/driver.c
index 0b6691e6f8..473d147f2f 100644
--- a/third_party/intel/backend/driver.c
+++ b/third_party/intel/backend/driver.c
@@ -272,12 +272,55 @@ extern "C" EXPORT_FUNC PyObject *load_binary(PyObject *args) {
   auto [l0_module, l0_kernel, n_spills] =
       compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, l0_device,
                               l0_context, build_flags(), is_spv);
-  if (PyErr_Occurred()) {
-    return NULL;
-  }
 
   const bool debugEnabled = getBoolEnv("TRITON_DEBUG");
 
+  // If the initial compilation failed entirely (e.g., scratch space exceeds
+  // HW limit), and GRF mode was not explicitly set, retry with large GRF mode.
+  // This handles cases where the default GRF mode doesn't provide enough
+  // registers, causing the backend compiler to fail.
+  if (PyErr_Occurred() && is_spv && !build_flags.hasGRFSizeFlag()) {
+    // Save the original error before clearing it for the retry attempt.
+    PyObject *orig_type, *orig_value, *orig_tb;
+    PyErr_Fetch(&orig_type, &orig_value, &orig_tb);
+
+    if (debugEnabled)
+      std::cout << "(I): Build failed for \"" << kernel_name
+                << "\", retrying with large GRF mode" << std::endl;
+
+    build_flags.addLargeGRFSizeFlag();
+
+    auto [l0_module_retry, l0_kernel_retry, n_spills_retry] =
+        compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, l0_device,
+                                l0_context, build_flags(), is_spv);
+    if (PyErr_Occurred()) {
+      // Retry also failed — propagate the original error.
+      PyErr_Restore(orig_type, orig_value, orig_tb);
+      return NULL;
+    }
+
+    // Retry succeeded — discard the saved original error.
+    Py_XDECREF(orig_type);
+    Py_XDECREF(orig_value);
+    Py_XDECREF(orig_tb);
+
+    l0_module = l0_module_retry;
+    l0_kernel = l0_kernel_retry;
+    n_spills = n_spills_retry;
+
+    // Always print recovery message to stderr to follow up on the
+    // "L0 build module failed" error that was already printed.
+    std::cerr << "(I): Build failure recovered by retrying with large GRF "
+                 "mode for \""
+              << kernel_name << "\"" << std::endl;
+
+    if (debugEnabled)
+      std::cout << "(I): Retry with large GRF succeeded, kernel has "
+                << n_spills << " spills" << std::endl;
+  } else if (PyErr_Occurred()) {
+    return NULL;
+  }
+
   if (is_spv) {
     constexpr int32_t max_reg_spill = 1000;
     const bool is_GRF_mode_specified = build_flags.hasGRFSizeFlag();