[TMP]: Disable layout rematerialization cache #1275
Conversation
```cpp
auto layoutIt = layout.find(v);
assert(layoutIt != layout.end());
// If we already have a remat value for this value, use it.
#if 0 // FIXME: Fails on LTS driver
```
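For context, the code gated by the `#if 0` implements a simple memoization pattern: reuse an already-rematerialized value for a given (value, layout) pair instead of recomputing it. Below is a minimal, hypothetical Python sketch of that idea, not the actual MLIR pass; all names and types are illustrative:

```python
# Hypothetical sketch of the remat-cache idea; the real pass operates on
# MLIR values and layout attributes in RemoveLayoutConversions.cpp.
remat_cache: dict[tuple[int, str], int] = {}


def rematerialize(value: int, layout: str) -> int:
    """Placeholder for the expensive rewrite that produces a new value."""
    return hash((value, layout)) & 0xFFFF


def get_remat_value(value: int, layout: str, use_cache: bool = True) -> int:
    # use_cache=False corresponds to the workaround in this PR: always
    # recompute instead of reusing a cached rematerialized value.
    key = (value, layout)
    if use_cache and key in remat_cache:
        return remat_cache[key]
    result = rematerialize(value, layout)
    remat_cache[key] = result
    return result
```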
As discussed offline, I think we should compile this code conditionally based on the target environment (e.g., driver version).
Hi @etiotto, I understand your concerns, but I don't think that is a good idea, for the following reasons:
- this change would affect many signatures in `RemoveLayoutConversions.cpp`, causing further divergence / merge conflicts with upstream
- this change is expected to be short-lived once we identify the regression and find a better workaround or get the LTS driver team to fix it
- none of our nightly CI uses the LTS driver, so the negative codepath would be essentially untested

If we object to having this change in llvm-target because of potential performance degradation, then I would suggest merging it directly to the PyTorch release branch. Of course, that precludes PyTorch upstream from using llvm-target until we find a better workaround or get the LTS driver fix in.
FWIW I actually did implement the change but can't test it because the only LTS machine is currently broken.
cc @vlad-penkin
There's a new development I am investigating... standby...
@etiotto I was able to recover the LTS machine, reproduce the issue, and add the LTS driver flag. Please take a look. Because of the regression in llvm-target, this branch is not up to date. After the PR is approved, I will update and merge, but I want to keep the branch where it is now locally so I can continue testing.
This phase does not affect the LTS driver.
```python
pm.enable_debug()
passes.ttir.add_convert_to_ttgpuir(pm, f"xpu:{device_arch}", opt.num_warps, opt.threads_per_warp, opt.num_ctas)

is_lts_driver = Version(metadata["target"].arch['driver_version']) == Version("1.3.27642")
```
NIT: Be consistent with string quote character within a file.
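For illustration, a minimal sketch of the version check with a consistent quote style, assuming `packaging.version` and the `metadata` structure from the diff above (the constant and helper names are hypothetical):

```python
from packaging.version import Version

# Known-problematic LTS driver release, taken from the diff above.
LTS_DRIVER_VERSION = Version("1.3.27642")


def is_lts_driver(metadata) -> bool:
    # Double quotes used consistently, per the review comment.
    return Version(metadata["target"].arch["driver_version"]) == LTS_DRIVER_VERSION
```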
Temporarily disables the layout rematerialization cache for backward slice rematerialization. The remat cache is exposing a bug on systems with the LTS driver, which results in an accuracy error with two HuggingFace models under the PyTorch Inductor benchmarks. Issue #1255 (cherry picked from commit eb51a81)
This reverts commit eb51a81.
The workarounds were added in #1275 and #1337. All HuggingFace training Float32 models pass with the LTS workaround removed: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/9865658421 Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
Temporarily disables the layout rematerialization cache for backward slice rematerialization. The remat cache is exposing a bug on systems with the LTS driver, which results in an accuracy error with two HuggingFace models under the PyTorch Inductor benchmarks. Tracking down the driver issue and providing a reproducer to the driver team will take some time, so this is a stopgap solution. I ran the same Float32 / training HF benchmarks compared with llvm-target, and there is a small regression on the T5Small model, but all other models are relatively similar to llvm-target. I am running the accuracy tests locally for Float32 / training HF, and so far the results are promising, so I am marking this ready for review. Issue #1255