Enable multi-device computation in runtime #1716

Merged: 1 commit merged into main from wooseok/enable_multidevice_runtime on Jan 14, 2025

Conversation

wooseokTT (Contributor):

Enable multi-device computation in runtime

  • Allow ttnn runtime operations including reduce_scatter, mesh_shard, and all_gather (collective semantics are sketched after this list)
  • Force mesh_shard ops to use system memory because they are host-side operations
  • Use strongly-typed sharding attributes for mesh_shard ops
  • Add Silicon multi-device test cases
  • Fix a bug in determining the axis of all_reduce when converting from StableHLO to TTIR
  • Fix a typo in the ttnn workaround pass
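
As a rough illustration of the collective semantics these runtime ops implement (a standalone host-side simulation for this write-up, not the TTNN runtime API): all_gather leaves every device with the concatenation of all shards, while reduce_scatter reduces element-wise across devices and leaves each device with one chunk of the reduced tensor.

// Illustrative only: simulates all_gather and reduce_scatter on host vectors,
// where each "device" in the mesh holds one shard of a tensor.
#include <cstddef>
#include <iostream>
#include <vector>

using Shard = std::vector<float>;

// all_gather: every device ends up with the concatenation of all shards.
std::vector<Shard> allGather(const std::vector<Shard> &perDevice) {
  Shard gathered;
  for (const Shard &s : perDevice)
    gathered.insert(gathered.end(), s.begin(), s.end());
  return std::vector<Shard>(perDevice.size(), gathered);
}

// reduce_scatter: element-wise sum across devices, then each device keeps one
// contiguous chunk of the reduced tensor.
std::vector<Shard> reduceScatter(const std::vector<Shard> &perDevice) {
  const size_t numDevices = perDevice.size();
  const size_t len = perDevice.front().size();
  Shard reduced(len, 0.0f);
  for (const Shard &s : perDevice)
    for (size_t i = 0; i < len; ++i)
      reduced[i] += s[i];
  const size_t chunk = len / numDevices;
  std::vector<Shard> out(numDevices);
  for (size_t d = 0; d < numDevices; ++d)
    out[d].assign(reduced.begin() + d * chunk, reduced.begin() + (d + 1) * chunk);
  return out;
}

int main() {
  std::vector<Shard> mesh = {{1, 2, 3, 4}, {5, 6, 7, 8}}; // two devices
  auto gathered = allGather(mesh);      // every device holds {1..8}
  auto scattered = reduceScatter(mesh); // device 0: {6, 8}, device 1: {10, 12}
  std::cout << gathered[0].size() << " " << scattered[1][0] << "\n"; // prints "8 10"
  return 0;
}

For reference, all_reduce produces the fully reduced tensor on every device, which is equivalent to a reduce_scatter followed by an all_gather of the reduced chunks.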

@sdjordjevicTT (Contributor) left a comment:

Changes look good from a dialect perspective. Just a few minor comments inline.

@@ -275,7 +275,8 @@ class TTNNLayoutDPSOperandsRewriter
   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const final {
     // To layout op is a special case, we don't want to rewrite it
-    if (mlir::isa<ttir::ToLayoutOp>(op.getOperation())) {
+    if (mlir::isa<ttir::ToLayoutOp>(op.getOperation()) ||
+        mlir::isa<ttir::MeshShardOp>(op.getOperation())) {
Contributor commented on the diff:

Any comment on why MeshShardOp is a special one in this regard?

wooseokTT (Contributor, Author) replied:

TTNN mesh shard APIs are currently CPU-only operations. By enforcing that the tensors live in system memory, we ensure that (1) a tensor can be sharded into multi-device storage on the CPU side, and (2) the shards can later be tiled and transferred to the individual devices.
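
A minimal sketch of that two-step flow, using illustrative names only (HostTensor, shardOnHost, toDevices are not actual TTNN APIs): the tensor is split on the CPU while it lives in system memory, and the resulting shards are then tiled and transferred to the individual devices. This is also why the TTNNLayout rewriter above skips MeshShardOp instead of assigning its operands a device layout.

// Hypothetical sketch of the host-side sharding flow described above.
#include <cstddef>
#include <vector>

struct HostTensor { std::vector<float> data; };   // resides in system memory
struct DeviceShard { std::vector<float> data; };  // resides on one device

// (1) mesh_shard: split the host tensor into per-device shards on the CPU.
std::vector<HostTensor> shardOnHost(const HostTensor &t, size_t numDevices) {
  const size_t chunk = t.data.size() / numDevices;
  std::vector<HostTensor> shards(numDevices);
  for (size_t d = 0; d < numDevices; ++d)
    shards[d].data.assign(t.data.begin() + d * chunk,
                          t.data.begin() + (d + 1) * chunk);
  return shards;
}

// (2) Tile and transfer each shard to its device (tiling omitted here); from
// this point on, device-side collectives such as all_gather can consume them.
std::vector<DeviceShard> toDevices(const std::vector<HostTensor> &shards) {
  std::vector<DeviceShard> out;
  out.reserve(shards.size());
  for (const HostTensor &s : shards)
    out.push_back(DeviceShard{s.data});
  return out;
}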

lib/Dialect/TTNN/Transforms/TTNNLayout.cpp (outdated review comment, resolved)
@wooseokTT force-pushed the wooseok/enable_multidevice_runtime branch 2 times, most recently from 2fd39a1 to 992997d on January 10, 2025 at 14:34
@nsmithtt (Contributor) left a comment:

This looks great! Thanks

@wooseokTT force-pushed the wooseok/enable_multidevice_runtime branch 2 times, most recently from efddf68 to f5bec77 on January 13, 2025 at 18:57
* Allow ttnn runtime operations including reduce_scatter, mesh_shard, and all_gather
* Force mesh_shard ops to use system memory because they are host-side operations
* Use strongly-typed sharding options for mesh_shard ops
* Add Silicon multi-device test cases
* Fix bug in determining axis of all_reduce when converting from StableHLO
* Fix typo in ttnn workaround pass
@wooseokTT force-pushed the wooseok/enable_multidevice_runtime branch from f5bec77 to 3e57fd9 on January 14, 2025 at 19:35
@wooseokTT enabled auto-merge (squash) on January 14, 2025 at 19:37
@wooseokTT merged commit d1a5e78 into main on Jan 14, 2025
20 checks passed
@wooseokTT deleted the wooseok/enable_multidevice_runtime branch on January 14, 2025 at 20:48
Successfully merging this pull request may close these issues: Push Jax test through