@LPanosTT LPanosTT commented Mar 10, 2025

Use a torch.export.ExportedProgram to generate the initial MLIR module.

This requires us to create an ExportedProgram from the initial GraphModule.

Benefits:

  • We can use torch-mlir's official entrypoint
    • This handles in-place ops for us
  • We can run decompositions and keep location data
    • This location data will stick around throughout the compile process

Issues:

  • aten.clamp is decomposed by torch-mlir into maximum(minimum(input, max), min). ttnn.maximum requires that the operand which needs to be broadcast is on the RHS, but currently the PartiallyBroadcastable op trait in tt-mlir only enforces that the broadcast operand is on the LHS.
  • Graph parameters are inlined as constants in the graph. To have the FxImporter treat them as graph inputs, we need to edit the ExportedProgram's ExportedGraphSignature and force all "parameter" inputs to "user input".
    • This is a hack as the ExportedGraphSignature is meant to be a private member of ExportedProgram
    • Ideally we can configure the FxImporter to not inline the parameters if we pass a flag of some sort. Perhaps a future contribution to torch-mlir.
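
The operand-ordering issue in the first bullet can be illustrated with a minimal pure-Python sketch. Elementwise helpers on lists stand in for the tensor ops here; the actual fix is a custom `aten.clamp` decomposition (see the commit list below), not this code:

```python
# Sketch of the operand-ordering workaround for aten.clamp:
# clamp(x, lo, hi) == maximum(minimum(x, hi), lo), with the broadcast
# (scalar) operand deliberately kept on the RHS of each binary op,
# matching the ttnn.maximum requirement described above.

def minimum(xs, scalar):
    return [min(x, scalar) for x in xs]  # scalar bound on the RHS

def maximum(xs, scalar):
    return [max(x, scalar) for x in xs]  # scalar bound on the RHS

def clamp(xs, lo, hi):
    return maximum(minimum(xs, hi), lo)
```

For example, `clamp([-2, 0, 5], lo=0, hi=3)` yields `[0, 0, 3]`.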
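
The signature hack in the second bullet can be sketched with a small stand-in for the export signature. The names below are modeled on `torch.export.graph_signature` (`InputKind`, `InputSpec`), but this is a schematic mock, not the real torch API:

```python
from dataclasses import dataclass
from enum import Enum, auto

class InputKind(Enum):   # mock of torch.export.graph_signature.InputKind
    USER_INPUT = auto()
    PARAMETER = auto()
    BUFFER = auto()

@dataclass
class InputSpec:         # mock of torch.export.graph_signature.InputSpec
    kind: InputKind
    name: str

def outline_parameters(input_specs):
    # Force every PARAMETER spec to USER_INPUT so the importer treats
    # graph parameters as graph inputs instead of inlining them as constants.
    for spec in input_specs:
        if spec.kind == InputKind.PARAMETER:
            spec.kind = InputKind.USER_INPUT
    return input_specs
```

Because the real signature is meant to be private to ExportedProgram, this is exactly the kind of mutation a future FxImporter flag would make unnecessary.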

Other Info:

  • We need to upgrade to PyTorch 2.6.0 as it contains crucial changes which allow us to use custom decompositions (necessary to support interpolation)
    • AdaptiveAvgPool2d is lowered to AvgPool2d and eventually to stablehlo.reduce_window **even in the case where the op is equivalent to a global average**. Since we do not have support for lowering a sum pool in StablehloToTTIRPatterns.cpp (sum because the division happens afterward), I've temporarily added a custom decomposition of `aten.avg_pool2d` which converts to a mean over the spatial dimensions when the `avg_pool2d` is equivalent to one.
    • aten.split is no longer lowered to a series of narrow ops. Instead it is now lowered to a series of as_strided ops.
      • narrow is lowered to slice, which can be lowered to stablehlo.slice. as_strided cannot be lowered from Torch Backend IR to Stablehlo. I've temporarily added back the old decomposition from PyTorch 2.5.0 which uses narrow as a custom decomposition.
      • I've made a PR which adds a lowering of AtenAsStridedOp to stablehlo::SliceOp in our fork of torch-mlir: Lower AtenAsStridedOp to stablehlo::SliceOp when possible. llvm-torch-mlir#4
    • The tracer which generates the GraphModule passed to the backend does not account for control flow. I believe that in PyTorch 2.5.0 a graph break would be triggered during .generate methods in transformers LLMs; this no longer happens, so .generate will run until the max length is reached.
      • This means the entire generation becomes one program.
      • Once the first EOS token is generated, the rest of the length is filled with padding. We cannot compare the golden output to the result from the GraphModule as the output shapes are different.
        • Since the outputs of .generate graphs are integer token IDs, PCC/atol verification is not very useful, though it does return True when the outputs are identical.
        • The tokenizer can decode the outputs and strip padding.
        • I've added a flag to ModelTester that informs it that it is testing a .generate call. It will decode the output tokens, and we compare the resulting strings.
      • PyTorch has an experimental torch.cond which they seem to intend to use to trace data-dependent control flow. There's a note in the transformers source saying they intend to use it once it is no longer experimental.
  • When the graph is compiled, the user inputs are placed at the end of the arguments passed to the program rather than the front. That is, graph constants come first, then user inputs.
  • I needed to implement an FxImporter hook for importing literals into the graph. By default it makes all non-scalars DenseElementsResourceAttrs; however, this causes the process to hang on cleanup whether or not the test fails. So the hook just uses DenseElementsAttr for all literals.
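
The condition for the avg_pool2d special case mentioned above can be sketched as follows. This is a hypothetical helper for illustration, not the actual decomposition code in tt_torch/dynamo/decompositions.py:

```python
def is_global_avg_pool(in_hw, kernel_size, padding):
    # avg_pool2d is equivalent to a mean over the spatial dimensions
    # exactly when the unpadded kernel covers the whole spatial extent.
    (h, w), (kh, kw), (ph, pw) = in_hw, kernel_size, padding
    return ph == 0 and pw == 0 and kh == h and kw == w
```

When this predicate holds, the decomposition can emit a spatial mean instead of a reduce_window followed by a division.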
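
The narrow-based split decomposition restored from PyTorch 2.5.0 amounts to taking contiguous slices along a dimension; a minimal 1-D sketch on Python lists:

```python
def split_with_sizes(seq, sizes):
    # Each output chunk is a contiguous slice of the input -- the same
    # computation aten.narrow expresses, which lowers cleanly through
    # slice to stablehlo.slice (unlike as_strided).
    out, start = [], 0
    for size in sizes:
        out.append(seq[start:start + size])
        start += size
    return out
```

For example, splitting a range of 10 elements with sizes [3, 3, 4] returns three contiguous chunks.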
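
The decode-and-compare verification for .generate described above can be sketched as follows. The `decode` callable here is a stand-in for the model's tokenizer, and `pad_id` is an assumed parameter name:

```python
def compare_generated_text(golden_ids, result_ids, decode, pad_id=0):
    # Output shapes can differ because everything after the first EOS is
    # padding, so strip padding and compare the decoded strings instead
    # of running PCC/atol on the raw integer tensors.
    def strip(ids):
        return [t for t in ids if t != pad_id]
    return decode(strip(golden_ids)) == decode(strip(result_ids))
```

With a real tokenizer, `decode` would be something like `tokenizer.decode` with special tokens skipped.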

@LPanosTT LPanosTT marked this pull request as draft March 10, 2025 21:53
@github-actions

TT-Torch Tests: 435 ran, 331 passed ✅, 104 skipped ⚠️, 0 failed

codecov-commenter commented Mar 10, 2025

Codecov Report

Attention: Patch coverage is 55.62130% with 75 lines in your changes missing coverage. Please review.

Project coverage is 70.81%. Comparing base (c79248a) to head (b1df435).

✅ All tests successful. No failed tests found.

Files with missing lines            Patch %   Lines missing
tt_torch/dynamo/decompositions.py   28.37%    53 ⚠️
tt_torch/dynamo/torch_backend.py    74.07%    14 ⚠️
tt_torch/dynamo/backend.py          70.00%     3 ⚠️
tt_torch/dynamo/shlo_backend.py     50.00%     2 ⚠️
tt_torch/dynamo/executor.py         85.71%     1 ⚠️
tt_torch/dynamo/passes.py           93.75%     1 ⚠️
tt_torch/tools/utils.py             75.00%     1 ⚠️

❗ There is a different number of reports uploaded between BASE (c79248a) and HEAD (b1df435): HEAD has 18 fewer uploads than BASE (19 vs 1).
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #416      +/-   ##
==========================================
- Coverage   76.16%   70.81%   -5.36%     
==========================================
  Files          11       11              
  Lines        1586     1583       -3     
==========================================
- Hits         1208     1121      -87     
- Misses        378      462      +84     


@LPanosTT LPanosTT force-pushed the lpanos/torch_export_experimental branch from bf98a24 to 4dd7dd5 Compare March 11, 2025 16:18
@LPanosTT LPanosTT marked this pull request as ready for review March 11, 2025 16:18

@LPanosTT LPanosTT force-pushed the lpanos/torch_export_experimental branch from 4dd7dd5 to b56d00f Compare March 11, 2025 19:28

@github-actions

TT-Torch Tests: 0 ran, 0 passed, 0 skipped, 0 failed ❌

@LPanosTT LPanosTT force-pushed the lpanos/torch_export_experimental branch from 2b760e8 to ce0aeb0 Compare March 11, 2025 22:06
@github-actions

TT-Torch Tests: 465 ran, 331 passed ✅, 134 skipped ⚠️, 0 failed

@LPanosTT LPanosTT force-pushed the lpanos/torch_export_experimental branch from 98013e4 to 9ea2257 Compare March 12, 2025 13:30

@github-actions

TT-Torch Tests: 465 ran, 327 passed ☑️, 134 skipped ⚠️, 4 failed ❌
Failed (pytest):
  • test_basic.test_multiple_ops
  • test_basic.test_unused_output
  • test_basic.test_multiple_users
  • test_maxpool2d

@LPanosTT LPanosTT force-pushed the lpanos/torch_export_experimental branch from 3ac4272 to 72495ea Compare March 12, 2025 17:34
@github-actions

TT-Torch Tests: 465 ran, 331 passed ✅, 134 skipped ⚠️, 0 failed


@LPanosTT LPanosTT force-pushed the lpanos/torch_export_experimental branch from f847985 to 48ad60a Compare March 13, 2025 02:39

@github-actions

TT-Torch Tests: 464 ran, 331 passed ✅, 133 skipped ⚠️, 0 failed

2 similar comments
@github-actions
Copy link

TestsPassed ✅Skipped ⚠️Failed
TT-Torch Tests464 ran331 passed133 skipped0 failed
TestResult
No test annotations available

@github-actions
Copy link

TestsPassed ✅Skipped ⚠️Failed
TT-Torch Tests464 ran331 passed133 skipped0 failed
TestResult
No test annotations available

@github-actions

TT-Torch Tests: 464 ran, 330 passed ☑️, 133 skipped ⚠️, 1 failed ❌
Failed (pytest):
  • test_constant_fold.test_interp

@LPanosTT LPanosTT force-pushed the lpanos/torch_export_experimental branch from 9d12ec5 to 9f4db69 Compare March 14, 2025 21:52
@github-actions

TT-Torch Tests: 464 ran, 331 passed ✅, 133 skipped ⚠️, 0 failed

@LPanosTT LPanosTT force-pushed the lpanos/torch_export_experimental branch from 8476700 to cecc609 Compare March 16, 2025 16:31

@LPanosTT LPanosTT force-pushed the lpanos/torch_export_experimental branch from 7789cf1 to 89582ab Compare March 16, 2025 21:30

mode,
compiler_config=cc,
record_property_handle=record_property,
is_transformers_generation=True,

Reviewer:

Maybe is_string_output is a better name, but that's pretty nit-picky.


Author:

Yeah, I'm thinking is_token_output might be better since it's an integer tensor. I'll also verify that self.tokenizer is set in that case.

batch_t = torch.unsqueeze(img_t, 0)
batch_t = batch_t.to(torch.bfloat16)
return batch_t
return (batch_t,)

Reviewer:

Why is this needed? Do all models need to return a tuple?


Author:

It's not needed anymore; it's a remnant from earlier commits when it was necessary.

inH = 5
inW = 5
inC = 1
scale_factor = 3

Reviewer:

Was this intended?


Author:

Yes this variable isn't used.

# to generate a weight matrix that allows the matmul to be identical to
# torch's upsample_bilinear2d when align_corners=True.
# This logic was derived from @brentyi's implementation in:
# https://github.com/jax-ml/jax/issues/11206#issuecomment-1423140760

Reviewer:

Is this still the right source?


Author:

Yes, but I just added a new special case; the normal implementation is inspired by that source.
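
The idea in the quoted snippet, building a weight matrix so that a matmul reproduces align_corners=True linear interpolation, can be sketched in 1-D. This is a schematic of the technique, not the PR's actual implementation:

```python
def linear_resize_matrix(n_in, n_out):
    # Build W so that applying W to x reproduces 1-D linear interpolation
    # with align_corners=True: output i samples source coordinate
    # i * (n_in - 1) / (n_out - 1).
    W = [[0.0] * n_in for _ in range(n_out)]
    for i in range(n_out):
        src = i * (n_in - 1) / (n_out - 1) if n_out > 1 else 0.0
        lo = int(src)
        hi = min(lo + 1, n_in - 1)
        frac = src - lo
        W[i][lo] += 1.0 - frac
        W[i][hi] += frac
    return W
```

For 2-D bilinear upsampling, one such matrix per spatial axis is applied, which is what lets the decomposition express upsample_bilinear2d as matmuls.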

@LPanosTT LPanosTT force-pushed the lpanos/torch_export_experimental branch 2 times, most recently from ec85815 to 1596c08 Compare March 17, 2025 16:24
Custom resolve_literal hook

outline model parameters by default, added option to leave them inlined

Fix input order

Add command line option for inline_parameters. Remove remove_embedded_constants CompilerConfig option as it does nothing now

Upsample decomp requires pytorch 2.6

Add decomps for linear/nearest upsample for 1d, 2d, 3d

Add pytorch 2.6.0 to wheel build

Force adaptive_avg_pool2d decomp and add special case for reduce_mean equivalent avg_pool

Use default torch decompositions

Use program in executor, fix input ordering in __call__

Disable mobilenetv3 full tests

Pass buffers in pre-execution depth GraphModule execution. Remove bug which causes recursive decomposition for avg pool

Fix test yaml

Remove skip on test_interp. Remove redundant test in test_basic.py

Use split_with_sizes decomposition from pytorch 2.5.0

run_shape_prop after exporting the program for the last time

Patches for Stablehlo backend

Fix test name for stable diffusion e2e test

outline everything or nothing

Skip flan_t5 test. Increase atol threshold on a few tests

Add flag to ModelTester to inform it that the thing being tested is a huggingface .generate call. Verify outputs by comparing decoded output instead

Add aten.clamp decomposition which orders inputs correctly so mobilenetv3 can execute. This should be only temporary as the fix belongs in tt-mlir

Fix edge case when linearly sampling a single datum. Skip RMGB

.
@LPanosTT LPanosTT force-pushed the lpanos/torch_export_experimental branch from 1596c08 to b1df435 Compare March 17, 2025 16:26
@LPanosTT LPanosTT enabled auto-merge (rebase) March 17, 2025 16:29
@LPanosTT LPanosTT disabled auto-merge March 17, 2025 16:29
@LPanosTT LPanosTT enabled auto-merge (squash) March 17, 2025 16:29

@LPanosTT LPanosTT merged commit 0decda6 into main Mar 17, 2025
12 checks passed
@LPanosTT LPanosTT deleted the lpanos/torch_export_experimental branch March 17, 2025 16:54
kmabeeTT pushed a commit that referenced this pull request May 8, 2025