Enable multi-device computation in runtime #1716
Conversation
Force-pushed from 90ce1bf to 2a566ed.
Changes look good from a dialect perspective. Just a few minor comments inline.
@@ -275,7 +275,8 @@ class TTNNLayoutDPSOperandsRewriter
   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const final {
     // To layout op is a special case, we don't want to rewrite it
-    if (mlir::isa<ttir::ToLayoutOp>(op.getOperation())) {
+    if (mlir::isa<ttir::ToLayoutOp>(op.getOperation()) ||
+        mlir::isa<ttir::MeshShardOp>(op.getOperation())) {
Any comment on why MeshShardOp is a special one in this regard?
TTNN mesh shard APIs are currently CPU-only operations. So, by enforcing tensors to be located in system memory, we can ensure that (1) a tensor can be sharded into multi-device storage on the CPU side, and (2) the shards can later be tiled and transferred to the individual devices.
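For context, a minimal sketch of what keeping mesh-shard tensors on host can look like in a layout rewrite pattern. The variadic isa check mirrors the diff above; the mustStayInSystemMemory helper and the usage comment are purely illustrative and not part of the actual tt-mlir code (the ttir op headers are assumed to be available).

```cpp
// Illustrative sketch only, not the actual tt-mlir implementation.
// MeshShardOp is executed by a CPU-only TTNN API, so (like ToLayoutOp)
// the layout rewriter should skip it and leave its tensors in system memory.
#include "mlir/IR/Operation.h"

// Hypothetical helper centralizing the "must stay on host" check.
static bool mustStayInSystemMemory(mlir::Operation *op) {
  // ToLayoutOp handles host<->device transfers itself; MeshShardOp shards
  // the tensor on the CPU side before the shards are tiled and moved to
  // the individual devices.
  return mlir::isa<ttir::ToLayoutOp, ttir::MeshShardOp>(op);
}

// Possible use inside TTNNLayoutDPSOperandsRewriter::matchAndRewrite:
//   if (mustStayInSystemMemory(op.getOperation()))
//     return failure();
```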
Force-pushed from 2fd39a1 to 992997d.
This looks great! Thanks
Force-pushed from efddf68 to f5bec77.
* Allow ttnn runtime operations including reduce_scatter, mesh_shard, and all_gather
* Force mesh_shard ops to use system memory because they are host-side operations
* Use strongly-typed sharding options for mesh_shard ops (see the sketch below)
* Add Silicon multi-device test cases
* Fix a bug in determining the axis of all_reduce when converting from StableHLO
* Fix a typo in the TTNN workaround pass
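To illustrate the "strongly-typed sharding options" bullet: instead of passing shard settings as free-form strings or untyped integers, the shard direction, shard type, and shard shape are carried as enums and a small struct, so invalid configurations are caught at compile time. All names below are hypothetical and do not correspond to the actual ttnn runtime API.

```cpp
// Hypothetical sketch; names are illustrative, not the ttnn runtime's API.
#include <array>
#include <cstdint>

enum class MeshShardDirection : std::uint8_t { FullToShard, ShardToFull };
enum class MeshShardType : std::uint8_t { Replicate, Devices };

struct MeshShardOptions {
  MeshShardDirection direction; // host -> devices or devices -> host
  MeshShardType type;           // replicate the tensor or split it
  // Split factors per mesh dimension, e.g. {2, 4} shards across a 2x4 mesh.
  std::array<std::int64_t, 2> shardShape;
};

// Example: shard a full host tensor across a 2x4 device mesh.
constexpr MeshShardOptions kShardAcross2x4{
    MeshShardDirection::FullToShard, MeshShardType::Devices, {2, 4}};
```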
Force-pushed from f5bec77 to 3e57fd9.
Enable multi-device computation in runtime