[DRIVER][VLLM] Auto-retry kernel compilation with large GRF mode on build failure #6123

Merged
anmyachev merged 5 commits into main from egor/issue_3777
Feb 18, 2026

@Egor-Krivov
Contributor

Problem

When a Triton kernel requires too many registers on Intel XPU, the IGC backend
compiler fails with "total scratch space exceeds HW supported limit" (PTSS).
The existing retry logic in load_binary() only handles the case where
compilation succeeds but has a high spill count (>1000). When compilation
fails entirely, the error propagates immediately without attempting large
GRF mode.

Users must manually add grf_mode='256' to work around this. The workaround is
not discoverable, and it differs from the NVIDIA backend, where such issues
don't occur.

Fixes #3777

Solution

Extend the GRF retry logic in load_binary() (driver.c) to also cover
complete build failures:

  • When zeModuleCreate fails with ZE_RESULT_ERROR_MODULE_BUILD_FAILURE
    and no GRF mode was explicitly set, clear the error and retry compilation
    with -cl-intel-256-GRF-per-thread
  • If the retry succeeds, continue with the large-GRF kernel
  • If the retry also fails, propagate the original error
  • Print a recovery message to stderr so the user knows the initial error
    was handled
  • Apply the same retry-on-failure pattern to the ocloc offline compilation
    path in make_zebin() (compiler.py)

The existing spill-count-based retry (successful compilation but >1000 spills)
is unchanged and still applies when the build-failure retry is not triggered.
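
The control flow described above can be sketched in Python. This is a minimal illustration, not the code from driver.c or compiler.py: `BuildFailure`, `build_with_grf_retry`, and the `compile_fn` callable are hypothetical names introduced here, and only the `-cl-intel-256-GRF-per-thread` flag and the recovery message come from this PR.

```python
import sys

# Flag named in the PR description.
LARGE_GRF_FLAG = "-cl-intel-256-GRF-per-thread"


class BuildFailure(Exception):
    """Hypothetical stand-in for ZE_RESULT_ERROR_MODULE_BUILD_FAILURE."""


def build_with_grf_retry(compile_fn, kernel_name, build_flags, grf_mode="default"):
    """Build once; on a build failure, retry once with large GRF.

    compile_fn(flags) is a hypothetical callable that returns a module
    or raises BuildFailure.
    """
    try:
        return compile_fn(build_flags)
    except BuildFailure as original_error:
        # Respect an explicit user choice: retry only if no GRF mode was set.
        if grf_mode != "default":
            raise
        try:
            module = compile_fn(build_flags + [LARGE_GRF_FLAG])
        except BuildFailure:
            # The retry also failed: propagate the original error.
            raise original_error from None
        # The first failure was already logged, so follow up on stderr.
        print(
            f'(I): Build failure recovered by retrying with large GRF mode for "{kernel_name}"',
            file=sys.stderr,
        )
        return module
```

Re-raising the first exception keeps the error the user sees tied to the configuration they actually asked for, rather than to the fallback attempt.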

Testing

Added parametrized test test_auto_grf_on_build_failure covering:

  • grf_mode='default': build fails → auto-retries with large GRF → succeeds
  • grf_mode='256': explicit large GRF → compiles directly, no retry
  • grf_mode='128': explicit small GRF → fails, no retry (respects user choice)
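
The three cases can be written as a small expectation table. The sketch below is a toy model, not the actual test: `simulate` is a hypothetical function that assumes a kernel which builds only with large GRF, and the `"default"` row is inferred from the first bullet above, since the diff excerpt in this thread shows only the `"256"` and `"128"` rows.

```python
# Flag named in the PR description.
LARGE_GRF_FLAG = "-cl-intel-256-GRF-per-thread"


def simulate(grf_mode):
    """Return (retried, failed) for a kernel that builds only with large GRF.

    Toy model: the build succeeds exactly when the large-GRF flag is present.
    """
    flags = [LARGE_GRF_FLAG] if grf_mode == "256" else []
    if LARGE_GRF_FLAG in flags:
        return (False, False)  # builds on the first attempt
    if grf_mode != "default":
        return (False, True)   # explicit mode: fail without retrying
    retry_flags = flags + [LARGE_GRF_FLAG]
    return (True, LARGE_GRF_FLAG not in retry_flags)


# (grf_mode, expect_retry, expect_fail)
CASES = [
    ("default", True, False),  # fails, auto-retries with large GRF, succeeds
    ("256", False, False),     # explicit large GRF, compiles on first attempt
    ("128", False, True),      # explicit small GRF, fails with no retry
]

for grf_mode, expect_retry, expect_fail in CASES:
    assert simulate(grf_mode) == (expect_retry, expect_fail)
```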

When I run the original reproducer from #3777 I now get a pass:

(triton) (312) jovyan@jupyter-ekrivov:~/triton/intel-xpu-backend-for-triton/issues/3777/ut$ python test_old.py 
L0 build module failed. Log: 
warning: [RetryManager] Start recompilation of the kernel
in kernel: 'sample_recovered_tokens_kernel'

error: total scratch space exceeds HW supported limit for kernel sample_recovered_tokens_kernel: 270848 bytes (max permitted PTSS 262144 bytes)
error: backend compiler failed build.

(I): Build failure recovered by retrying with large GRF mode for "sample_recovered_tokens_kernel"
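
To read the IGC error in the log above, a short sketch can extract the kernel name and how far the kernel overshoots the PTSS limit. `ptss_overflow` is a hypothetical helper for inspecting logs only. As described in this PR, the retry itself is triggered by the build failure result, not by parsing this text.

```python
import re

# Matches the IGC error line shown in the log above.
PTSS_PATTERN = re.compile(
    r"total scratch space exceeds HW supported limit for kernel (\S+): "
    r"(\d+) bytes \(max permitted PTSS (\d+) bytes\)"
)


def ptss_overflow(log_text):
    """Return (kernel_name, bytes_over_limit), or None if no PTSS error is found."""
    match = PTSS_PATTERN.search(log_text)
    if match is None:
        return None
    used, limit = int(match.group(2)), int(match.group(3))
    return match.group(1), used - limit


log = (
    "error: total scratch space exceeds HW supported limit for kernel "
    "sample_recovered_tokens_kernel: 270848 bytes (max permitted PTSS 262144 bytes)"
)
print(ptss_overflow(log))  # ('sample_recovered_tokens_kernel', 8704)
```

Here the kernel asks for 8704 bytes more scratch space than the 262144-byte limit allows.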

@Egor-Krivov Egor-Krivov requested review from anmyachev, etiotto and whitneywhtsang and removed request for anmyachev February 17, 2026 15:36

// Always print recovery message to stderr to follow up on the
// "L0 build module failed" error that was already printed.
std::cerr << "(I): Build failure recovered by retrying with large GRF "
Contributor Author

@Egor-Krivov Egor-Krivov Feb 17, 2026

I write to stderr because, at this point, stderr already contains the error message from IGC, and we need to tell the user that the issue has probably been fixed.

Contributor

@anmyachev anmyachev left a comment

Makes sense to me. Since this is a change in the driver, let's update your test to cover more cases.

Comment thread python/test/unit/intel/test_driver.py Outdated
("256", False, False), # Explicit large GRF — compiles on first attempt
("128", False, True), # Explicit small GRF — should fail, no retry
])
def test_auto_grf_on_build_failure(device, monkeypatch, capfd, grf_mode, expect_retry, expect_fail):
Contributor

Let's test both make_zebin and load_binary. I guess TRITON_XPU_GEN_NATIVE_CODE should help with this.

Comment thread third_party/intel/backend/driver.c Outdated
compileLevelZeroObjects(binary_ptr, binary_size, kernel_name, l0_device,
l0_context, build_flags(), is_spv);
if (PyErr_Occurred()) {
// Retry also failed — propagate the error.
Contributor

Should we raise the initial exception here to align with make_zebin?

@Egor-Krivov
Contributor Author

@anmyachev I updated the PR based on your feedback, but GitHub doesn't show the changes for some reason:
main...egor/issue_3777

@Egor-Krivov
Contributor Author

OK, after 36ab67d the PR got updated.

Comment thread third_party/intel/backend/driver.c Outdated
Contributor

@anmyachev anmyachev left a comment

One small comment; everything else LGTM!

@Egor-Krivov Egor-Krivov enabled auto-merge (squash) February 18, 2026 16:13
@anmyachev anmyachev changed the title [VLLM] Auto-retry kernel compilation with large GRF mode on build failure [DRIVER][VLLM] Auto-retry kernel compilation with large GRF mode on build failure Feb 18, 2026
@anmyachev anmyachev disabled auto-merge February 18, 2026 17:32
@anmyachev anmyachev enabled auto-merge (squash) February 18, 2026 17:32
@anmyachev anmyachev merged commit 3b92b8f into main Feb 18, 2026
15 checks passed
@anmyachev anmyachev deleted the egor/issue_3777 branch February 18, 2026 19:42
@jikunshang

May I know which triton-xpu release will contain this fix?

wdziurdz pushed a commit that referenced this pull request Apr 7, 2026
* src/main:
  [DRIVER][VLLM] Auto-retry kernel compilation with large GRF mode on build failure (#6123)

Development

Successfully merging this pull request may close these issues.

[VLLM] Implement automatic usage of grf_mode="256" for kernels that use a lot of registers
