
Runtime stitching APIs and sanity tests, ttnn runtime submit refactor #1301

Merged Dec 4, 2024 (1 commit)

Conversation

@jnie-TT (Contributor) commented Nov 17, 2024

Closes #671, #103, #366.

Adds runtime support for APIs that enable stitching separate runtime submit calls together.

APIs
std::vector<Tensor> submit(Device, Binary, programIndex, inputTensors)

  • Executes the binary.
  • Implicitly executes asynchronously in ttnn unless blocking APIs such as ttnn::from_device are executed within the Binary graph.
  • Returns a vector of tensors.
  • inputTensors can be on host or on device. We can save perf cycles when, for example, an input tensor is already on device: to_device ops can then return immediately.
  • Tensors returned may be on host or on device, depending on the flatbuffer execution graph.
  • Any further blocking API calls such as toHost will automatically wait for all required tensors to be populated.
  • Automatically converts input tensors to the desired layout (see the sketch below).
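For illustration, a minimal sketch of calling submit. Only the API names come from this PR; the header path appears elsewhere in this review, while the parameter-passing conventions and the way Device/Binary handles are obtained are assumptions:

```cpp
#include <vector>

#include "tt/runtime/runtime.h"

using namespace tt::runtime;

// Run program 0 of a binary on a mix of host/device input tensors.
// How Device and Binary are obtained is frontend-specific, so this
// sketch takes them as parameters.
std::vector<Tensor> runProgram(Device device, Binary binary,
                               std::vector<Tensor> const &inputs) {
  // Executes asynchronously in ttnn unless the graph itself contains
  // blocking ops such as ttnn::from_device; the returned handles may
  // refer to tensors that are still being computed.
  return submit(device, binary, /*programIndex=*/0, inputs);
}
```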

Layout getLayout(Binary, ProgramIndex, InputIndex)

  • Returns the layout of the input at the given index, as defined in the binary.

Tensor toLayout(Tensor, Device, Layout)

  • Returns a tensor converted to the required layout (see the sketch below).
  • TODO: some paths that require host fallback for device-tensor layout conversions are currently unchecked. We can add checks for that once we hit a use case; leaving unchecked for now.
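A sketch of using getLayout and toLayout together to pre-convert an input before submission, continuing the sketch above. The indices are illustrative and the exact parameter types are assumptions:

```cpp
// Query the layout the binary expects for input 0 of program 0, then
// convert an existing tensor into that layout. Depending on what the
// layout encodes, this may move the tensor between host and device.
Tensor prepareInput(Tensor input, Device device, Binary binary) {
  Layout desired = getLayout(binary, /*programIndex=*/0, /*inputIndex=*/0);
  return toLayout(input, device, desired);
}
```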

Tensor createTensor(Device device, Layout layout,
                    std::vector<std::uint32_t> const &shape,
                    std::vector<std::uint32_t> const &stride,
                    std::uint32_t itemsize)

  • Creates an empty tensor on host or device (location derived from the layout) with the desired layout (see the sketch below).
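A sketch of creating an empty tensor, continuing the sketches above. The 32x32 shape, row-major stride, and 4-byte itemsize are illustrative values, not defaults from this PR:

```cpp
#include <cstdint>
#include <vector>

// Creates an empty 32x32 tensor of 4-byte elements. Shape and stride
// are in elements; whether the tensor lands on host or device is
// derived from `layout`.
Tensor makeEmptyTensor(Device device, Layout layout) {
  std::vector<std::uint32_t> shape = {32, 32};
  std::vector<std::uint32_t> stride = {32, 1};
  return createTensor(device, layout, shape, stride, /*itemsize=*/4);
}
```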

void deallocateTensor(Tensor, bool force)

  • Deallocates tensor. Set force to true to force deallocation.

void memcpy(Tensor dst, Tensor src)

  • Copies src to dst. Assumes both tensors are already allocated and asserts dst size == src size.
  • Works for host-to-host, host-to-device, and device-to-host copies (see the staging sketch below).
  • TODO: add support for multichip tensors.
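A sketch of staging data with memcpy and freeing the staging tensor afterwards; the host/device roles of the two tensors are an assumption for the example:

```cpp
// Copy a populated host tensor into a device tensor of the same size
// (the runtime asserts if the sizes differ), then deallocate the
// device copy once nothing else consumes it.
void stageAndRelease(Tensor deviceTensor, Tensor hostTensor) {
  memcpy(/*dst=*/deviceTensor, /*src=*/hostTensor);
  // ... submit programs that consume deviceTensor ...
  deallocateTensor(deviceTensor, /*force=*/true);
}
```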

Tensor toHost(Tensor, bool untilize)

  • Waits for tensor operations to complete and copies a tensor to the host.
  • The original tensor remains allocated in device memory.
  • Set untilize to true to untilize the tensor (a full stitching sketch follows below).
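Putting the APIs together, a sketch of the stitching flow this PR targets: chaining two programs without a host round trip and blocking only at the end. The program indices and the untilize flag are illustrative, and this assumes the binary accepts device-resident inputs, which the compiler does not yet guarantee:

```cpp
// Feed program 0's (potentially device-resident) outputs straight into
// program 1, then block until the final result is ready on host.
Tensor runStitched(Device device, Binary binary,
                   std::vector<Tensor> const &inputs) {
  std::vector<Tensor> mid = submit(device, binary, /*programIndex=*/0, inputs);
  std::vector<Tensor> out = submit(device, binary, /*programIndex=*/1, mid);
  // toHost waits for all producing ops to complete; the device-side
  // tensor remains allocated afterwards.
  return toHost(out[0], /*untilize=*/true);
}
```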

Testing

  • Updated ttrt and test_subtract to use the new submit API.
  • Added a runtime Python test folder containing unit tests for runtime APIs. Testing for the runtime stitching APIs lives in runtime/test/python/ttnn/test_runtime_api.py.
    • Currently this requires ttrt wheels to be compiled and installed.
  • Added a utility library for runtime testing. It currently contains helper functions that generate specific layouts, since we cannot yet get these layouts from the flatbuffers with getLayout until compiler support is added.
  • Added a runtime stitching test that stitches multiple eltwise programs together. Currently the MLIR for this test is purely in the ttnn dialect, and I manually removed to_device ops to avoid hitting asserts that the tensor is already on device, since the compiler currently hardcodes input layouts to host and generates ops based on that assumption.
    • On this note, we want to add an MLIR directory specifically for runtime testing. Currently I added a folder under test/ttmlir/Runtime that contains the hacked ttnn-dialect MLIR for the runtime stitching test. Adding it under Silicon didn't work because ttrt will run all flatbuffers generated there, and the runtime stitching MLIR needs a different flow. I'm open to suggestions for a better location, but it would be great to have a dedicated place for MLIR files that the compiler may not yet support and that are used specifically for runtime testing.

Backwards Compatibility

  • Kept the old submit API so that FEs have a grace period, but it would be desirable for FEs to migrate soon.

Misc

  • Fixed a bug in the isfinite tests where the input and output data formats differ. Currently the output of this test is incorrect on main.
  • Marked simple_broadcast.mlir as unsupported, since the generated ttnn graph is currently invalid: simple_broadcast.mlir is invalid #1314.

@github-actions (bot) left a comment:


⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

@nsmithtt (Contributor) commented:

@pilkicTT / @rpavlovicTT, would really like your input on these runtime API changes.

@rpavlovicTT (Contributor) commented:

> @pilkicTT / @rpavlovicTT, would really like your input on these runtime API changes.

At the moment it looks good overall.

A few general questions:

  1. What about error handling? If something crashes/asserts in runtime (for example, we run out of memory), how does TPRT handle it?
  2. Does it make sense to have an AllocateTensor API as well? From the TPRT perspective, is there a case where you want to allocate memory but not copy data into it immediately?
  3. Do we want callback support? As in, invoke a provided callback when an event completes?
  4. Should checkpointing be part of this API discussion, or should it be separate?

@pilkicTT (Contributor) commented:

> @pilkicTT / @rpavlovicTT, would really like your input on these runtime API changes.

@nsmithtt Sorry, I haven't had a chance to take a look yet. Hopefully later today, or tomorrow in the worst case...

@jnie-TT I've noticed there are multiple unrelated changes in this PR; this should be avoided, especially with larger changes, for at least two reasons:

  1. It makes review harder.
  2. It makes the change harder to revert; i.e., if runtime stitching introduces a bug, you will need to revert manually (in order to keep the other changes), or vice versa.

I'm not blocking on this, I just think it's good to avoid these situations in the future. Thanks!

@jnie-TT (Contributor, Author) commented Nov 19, 2024

> @pilkicTT / @rpavlovicTT, would really like your input on these runtime API changes.
>
> At the moment it looks good overall.
>
> A few general questions:
>
>   1. What about error handling? If something crashes/asserts in runtime (for example, we run out of memory), how does TPRT handle it?
>   2. Does it make sense to have an AllocateTensor API as well? From the TPRT perspective, is there a case where you want to allocate memory but not copy data into it immediately?
>   3. Do we want callback support? As in, invoke a provided callback when an event completes?
>   4. Should checkpointing be part of this API discussion, or should it be separate?

1. Currently there's no memory checker in runtime; tt-metal will assert if there's not enough memory left on the device. We can discuss what makes the most sense for TPRT. Currently TPRT would need to catch these errors manually.

2. We can add this API if we need it. I'm just wondering whether it's sufficient to create the tensor once you have the data, or whether we have a use case where pre-allocating the tensor is required.

3. There won't be a notion of events in ttnn runtime anymore. TTNN implicitly handles async execution and automatically blocks on certain APIs; if we want to query info from a tensor, those APIs will block implicitly. Callback infra goes up after this change: see #1190 and #1218 (Added runtime support for doing golden comparison for flatbuffers in ttrt).

4. This should probably be a separate effort.

@rpavlovicTT (Contributor) commented Nov 20, 2024

Thanks.
I missed the callback support, kudos to @tapspatel.

When you say

> If we want to query info from the tensor those APIs will be blocking implicitly.

are you saying we'll be able to poll whether a tensor is ready or not?

> We can add this API if we need it. I'm just wondering whether it's sufficient to create the tensor once you have the data, or whether we have a use case where pre-allocating the tensor is required.

I asked this question but don't have a concrete answer; it's more of a question for the TPRT folks. cc @mrakitaTT @AleksKnezevic

@pilkicTT (Contributor) left a comment:

From what I see in the code, it doesn't look like the inputs to the program can be either on host or on device; i.e., they must always be on host if the binary expects them that way. I couldn't find the place where inputs to the program are specially handled if they're already on the device.

It would be good if we add tests for these kinds of basic API scenarios. Not necessarily with this PR...

@jnie-TT (Contributor, Author) commented Nov 20, 2024

> From what I see in the code, it doesn't look like the inputs to the program can be either on host or on device; i.e., they must always be on host if the binary expects them that way. I couldn't find the place where inputs to the program are specially handled if they're already on the device.
>
> It would be good if we add tests for these kinds of basic API scenarios. Not necessarily with this PR...

I'll add some runtime stitching tests that pre-move tensors to device and run with tensors already on device, though right now the flatbuffer binaries have all input tensors hardcoded to host. We're currently looking at how we can canonicalize the input layout; it would be great to get your feedback there from an FE perspective as well.

@jnie-TT (Contributor, Author) commented Nov 20, 2024

@rpavlovicTT

> Are you saying we'll be able to poll if tensor is ready/not?

Yes, the APIs will block and poll for tensor readiness, and return the info once the tensor is ready.

@pilkicTT (Contributor) commented:

> > From what I see in the code, it doesn't look like the inputs to the program can be either on host or on device; i.e., they must always be on host if the binary expects them that way. I couldn't find the place where inputs to the program are specially handled if they're already on the device.
> > It would be good if we add tests for these kinds of basic API scenarios. Not necessarily with this PR...
>
> I'll add some runtime stitching tests that pre-move tensors to device and run with tensors already on device, though right now the flatbuffer binaries have all input tensors hardcoded to host. We're currently looking at how we can canonicalize the input layout; it would be great to get your feedback there from an FE perspective as well.

@jnie-TT Thanks for the confirmation. Can you please update the commit & PR description (for future reference)? Just to clarify that, at this point, the APIs are prepared but cannot yet be used effectively to stitch multiple programs by keeping the inputs on the device (since we always assume the inputs are on host).

As for how we can implement this, I'm really not sure (I'll throw out some ideas)... I was thinking that for v0 we would just handle these cases in the runtime (without reflecting it in the binary or the IR), as sketched below. Obviously not ideal, but the simplest approach.
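To make that v0 idea concrete, a sketch of how the runtime's input handling could branch on a tensor's actual location instead of trusting the binary's hardcoded host layout; isOnDevice is a hypothetical predicate, not an API from this PR:

```cpp
// Hypothetical per-input handling inside the runtime's submit path.
Tensor resolveInput(Tensor input, Device device, Layout desired) {
  if (isOnDevice(input)) { // hypothetical location check
    return input;          // already resident; skip the host->device copy
  }
  return toLayout(input, device, desired);
}
```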

For a proper solution, I was thinking about whether there is an option to extend the dialects (both TTIR and TTNNIR) in the following way:

  1. TTIR (or maybe just TTNNIR): represent tensors (all input tensors) that live either on system or on the device; i.e., extend the system/device attributes (or however they are called) with an any option.
  2. TTNNIR:
    2.1. Extend ttnn.to_device() to handle the transformation any -> system|device (whatever is desired by the compiler).
    2.2. Add support for branching/optional execution depending on a predicate in the IR, and represent this logic directly in the IR.

Not sure if this makes sense in the context of MLIR, but it looks to me like we'd have complete transparency about what happens in the program (at least from the standpoint of moving tensors between the host and the device).

@nsmithtt (Contributor) left a comment:

Awesome, thank you for adding tests! Just as a sanity check before we land, can we kick off a tt-forge-fe CI run rebased on this change?

Let's try to pre-empt any fallout.

@jnie-TT (Contributor, Author) commented Dec 3, 2024

> Awesome, thank you for adding tests! Just as a sanity check before we land, can we kick off a tt-forge-fe CI run rebased on this change?
>
> Let's try to pre-empt any fallout.

Forge CI green: https://github.com/tenstorrent/tt-forge-fe/actions/runs/12149172667
