Use Torch ExportedModule to import initial MLIR module #416
Codecov Report
Attention: Patch coverage is
✅ All tests successful. No failed tests found.

@@ Coverage Diff @@
##             main     #416      +/-   ##
==========================================
- Coverage   76.16%   70.81%   -5.36%
==========================================
  Files          11       11
  Lines        1586     1583       -3
==========================================
- Hits         1208     1121      -87
- Misses        378      462      +84

View full report in Codecov by Sentry.
tests/models/mamba/test_mamba.py (outdated):

    mode,
    compiler_config=cc,
    record_property_handle=record_property,
    is_transformers_generation=True,
Maybe is_string_output is a better name, but that's pretty nit-picky.
Yeah, I'm thinking is_token_output might be better since it's an integer tensor. Also I'll verify that self.tokenizer was added if this is the case.
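The decoded-string comparison discussed in this thread can be sketched minimally. This is a hypothetical illustration, not the PR's code: `decode` and `vocab` stand in for a real tokenizer's decode method, and the idea is to compare decoded text rather than applying PCC/atol to integer token IDs.

```python
# Hypothetical sketch: verify a .generate()-style output by decoded string
# instead of PCC/atol, since the outputs are integer token IDs.
def decode(token_ids, vocab):
    # 'vocab' maps token id -> string piece; a stand-in for a real tokenizer.
    return "".join(vocab[t] for t in token_ids)

vocab = {0: "he", 1: "llo", 2: " world"}
golden_tokens = [0, 1, 2]
device_tokens = [0, 1, 2]
print(decode(golden_tokens, vocab) == decode(device_tokens, vocab))  # True
```

Exact string equality is meaningful here because token IDs are discrete; any numeric tolerance would either pass trivially or fail on a single differing token anyway.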
    batch_t = torch.unsqueeze(img_t, 0)
    batch_t = batch_t.to(torch.bfloat16)
-   return batch_t
+   return (batch_t,)
Why is this needed? Do all models need to return a tuple?
It's not needed anymore; it's a remnant from earlier commits when it was necessary.
    inH = 5
    inW = 5
    inC = 1
    scale_factor = 3
Was this intended?
Yes, this variable isn't used.
    # to generate a weight matrix that allows the matmul to be identical to
    # torch's upsample_bilinear2d when align_corners=True.
    # This logic was derived from @brentyi's implementation in:
    # https://github.com/jax-ml/jax/issues/11206#issuecomment-1423140760
Is this still the right source?
Yes, but I just added a new special case; the normal implementation is inspired by that source.
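The weight-matrix trick referenced in this thread can be illustrated in 1-D. This is a hedged sketch, not the PR's actual code: `linear_upsample_matrix` is a hypothetical helper that builds a matrix `W` such that `W @ x` reproduces `torch.nn.functional.interpolate` with `mode="linear"` and `align_corners=True`.

```python
import torch

def linear_upsample_matrix(in_size, out_size):
    # Build W of shape (out_size, in_size) so that W @ x equals 1-D linear
    # interpolation with align_corners=True: output i samples input position
    # i * (in_size - 1) / (out_size - 1), blending the two nearest samples.
    w = torch.zeros(out_size, in_size)
    scale = (in_size - 1) / (out_size - 1)
    for i in range(out_size):
        pos = i * scale
        lo = int(pos)
        hi = min(lo + 1, in_size - 1)
        frac = pos - lo
        w[i, lo] += 1 - frac
        w[i, hi] += frac
    return w

x = torch.arange(5, dtype=torch.float32)
W = linear_upsample_matrix(5, 9)
ref = torch.nn.functional.interpolate(
    x.view(1, 1, -1), size=9, mode="linear", align_corners=True
).view(-1)
print(torch.allclose(W @ x, ref))  # True
```

The bilinear 2-D case factors into two such matrices, one per spatial axis, which is what makes the matmul formulation usable as a decomposition.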
Commits in this push:

- Custom resolve_literal hook
- Outline model parameters by default; added option to leave them inlined
- Fix input order
- Add command line option for inline_parameters
- Remove remove_embedded_constants CompilerConfig option as it does nothing now
- Upsample decomp requires pytorch 2.6
- Add decomps for linear/nearest upsample for 1d, 2d, 3d
- Add pytorch 2.6.0 to wheel build
- Force adaptive_avg_pool2d decomp and add special case for reduce_mean equivalent avg_pool
- Use default torch decompositions
- Use program in executor; fix input ordering in __call__
- Disable mobilenetv3 full tests
- Pass buffers in pre-execution depth GraphModule execution
- Remove bug which causes recursive decomposition for avg pool
- Fix test yaml
- Remove skip on test_interp; remove redundant test in test_basic.py
- Use split_with_sizes decomposition from pytorch 2.5.0
- run_shape_prop after exporting the program for the last time
- Patches for Stablehlo backend
- Fix test name for stable diffusion e2e test
- Outline everything or nothing
- Skip flan_t5 test
- Increase atol threshold on a few tests
- Add flag to ModelTester to inform the ModelTester that the thing being tested is a huggingface .generate call; verify outputs by comparing decoded output instead
- Add aten.clamp decomposition which orders inputs correctly so mobilenetv3 can execute. This should be only temporary as the fix belongs in tt-mlir
- Fix edge case when linearly sampling a single datum
- Skip RMGB
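The `aten.clamp` ordering issue mentioned in the commit list can be demonstrated numerically. This is an illustrative sketch, not the repo's decomposition code: torch-mlir decomposes `aten.clamp` into `maximum(minimum(input, max), min)`, which matches `torch.clamp` elementwise, while the backend additionally constrains which operand of `maximum` may be broadcast.

```python
import torch

# The decomposition is numerically identical to clamp; the PR's concern is
# only the operand ordering once the scalars become broadcast tensors.
x = torch.randn(2, 3)
lo = torch.tensor(-0.5)
hi = torch.tensor(0.5)
ref = torch.clamp(x, min=-0.5, max=0.5)
decomp = torch.maximum(torch.minimum(x, hi), lo)
print(torch.equal(ref, decomp))  # True
```

Since the math is symmetric (`maximum(a, b) == maximum(b, a)`), swapping operands so the broadcast scalar lands on the required side is safe, which is why the custom decomposition can fix operand order without changing results.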
Use a `torch.export.ExportedProgram` to generate the initial MLIR module. This requires us to create an `ExportedProgram` from the initial `GraphModule`.

Benefits:

Issues:

- `aten.clamp` is decomposed by torch-mlir to `maximum(minimum(input, max), min)`. `ttnn.maximum` requires that the operand which needs to be broadcast is on the RHS; currently, in tt-mlir the `PartiallyBroadcastable` op trait only enforces that the broadcast operand is on the LHS.
- To have `FxImporter` treat parameters as graph inputs, we need to edit the `ExportedModule`'s `ExportedGraphSignature` and force all "parameter" types to "user inputs".
- `ExportedGraphSignature` is meant to be a private member of `ExportedProgram`.
- An alternative would be for `FxImporter` to not inline the parameters if we pass a flag of some sort. Perhaps a future contribution to torch-mlir.

Other Info:
- `aten.avg_pool2d` is lowered to `stablehlo.reduce_window` **even in the case where the op is equivalent to a global average**. Since we do not have support for lowering a sum_pool in `StablehloToTTIRPatterns.cpp` (sum because the division is afterward), I've temporarily added a custom decomposition of `aten.avg_pool2d` which will convert to a mean over the spatial dimensions when the `avg_pool2d` is equivalent to it.
- `aten.split` is no longer lowered to a series of `narrow` ops; instead it is now lowered to a series of `as_strided` ops. `narrow` is lowered to `slice`, which can be lowered to `stablehlo.slice`, while `as_strided` cannot be lowered from Torch Backend IR to Stablehlo. I've temporarily added back the old decomposition from PyTorch 2.5.0, which uses `narrow`, as a custom decomposition.
- There is a PR to lower `AtenAsStridedOp` to `stablehlo::SliceOp` in our fork of torch-mlir: Lower `AtenAsStridedOp` to `stablehlo::SliceOp` when possible (llvm-torch-mlir#4).
- The `GraphModule` which is passed to `backend` does not account for control flow. I believe in PyTorch 2.5.0 a graph break would be triggered during `.generate` methods in `transformers` LLMs; it does not anymore, and so `.generate` will run until the max length is reached.
- Outputs can no longer be verified directly against the `GraphModule`, as the output shapes are different.
- Since the outputs of `.generate` graphs are integers, PCC/atol verification is not quite useful, but does return `True` when the outputs are identical.
- I've added a flag to `ModelTester` that informs the `ModelTester` it is testing a `.generate` call. It will decode the output tokens and we compare the resulting strings.
- `torch.cond`: they seem to intend to use it to trace data-dependent control-flow. There's a note in the `transformers` source that says they intend to use it when it is no longer experimental.
- Added a custom `FxImporter` hook for importing literals to the graph. By default it will make all non-scalars `DenseElementsResourceAttr`s; however, this causes the process to hang upon cleanup whether the test fails or not, so the hook just uses `DenseElementsAttr` for all literals.
- Related upstream fix for the `DenseElementsResourceAttr` hang: [MLIR] Fix thread safety of the deleter in PyDenseResourceElementsAttribute (llvm/llvm-project#124832).