
Add ttnn::ones() op #1476

Merged: svuckovicTT merged 4 commits into main from svuckovic/ones-op-2 on Dec 11, 2024
Conversation

@svuckovicTT (Contributor) commented Dec 3, 2024

This PR adds support for the ttnn::ones() op.

  • A ttnn-to-emitc utils file has been added
  • In MLIRToFlatbuffer.h, two toFlatbufferOptional methods have been added to make the TTNNToFlatbuffer.cpp code cleaner, as all the optional handling is abstracted away
  • The conversion of the op in TTNNToEmitC.cpp is more streamlined; I'd like to apply this to other ops, and hopefully arrive at a generic solution that works for most ops without any special-case handling
  • In a similar manner, the run() method in runtime/lib/ttnn/operations/creation/ones.cpp has been written so that it "mechanically" unpacks a flatbuffer and calls the appropriate ttnn method (a rough sketch of this shape follows below) - it would be great if we could make this generic so that we don't need to handle each op manually
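
For illustration, the "mechanical" unpack-and-call shape looks roughly like this - a sketch only, where the OnesOp field accessors and the unpack* helpers are assumptions, not the actual runtime API:

    // Sketch only: OnesOp field accessors and the unpack* helpers are
    // hypothetical; the real runtime code lives in ones.cpp.
    void run(const ::tt::target::ttnn::OnesOp *op, ProgramContext &context) {
      ProgramTensorPool &tensorPool = context.getTensorPool();

      // Required field: the output shape.
      ::ttnn::Shape shape = unpackShape(op->shape());

      // Optional fields: unset flatbuffer entries map to std::nullopt.
      std::optional<::ttnn::DataType> dtype =
          op->dtype().has_value()
              ? std::make_optional(unpackDataType(op->dtype().value()))
              : std::nullopt;
      std::optional<::ttnn::Layout> layout =
          op->layout().has_value()
              ? std::make_optional(unpackLayout(op->layout().value()))
              : std::nullopt;

      // Forward to the matching ttnn API (remaining optional parameters,
      // e.g. device and memory config, left defaulted) and store the result.
      ::ttnn::Tensor out = ::ttnn::ones(shape, dtype, layout);
      tensorPool.insert_or_assign(op->out()->global_id(), out);
    }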

~~Note: there's a piece of code `lib/Conversion/TTNNToEmitC/Utils.cpp` that I changed just to make the tests run. I'll mark it with a comment, and it will be removed before this PR is merged with main - @mtopalovicTT has a fix in the works.~~
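
For reference, the toFlatbufferOptional helpers mentioned above presumably reduce to thin wrappers over the existing toFlatbuffer overloads - a minimal sketch, assuming such an overload exists for ttnn::Layout:

    // Sketch only - assumes an existing toFlatbuffer(cache, layout) overload.
    inline ::flatbuffers::Optional<::tt::target::TensorLayout>
    toFlatbufferOptional(FlatbufferObjectCache &cache,
                         std::optional<mlir::tt::ttnn::Layout> layout) {
      return layout.has_value()
                 ? ::flatbuffers::Optional<::tt::target::TensorLayout>(
                       toFlatbuffer(cache, *layout))
                 : ::flatbuffers::nullopt;
    }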

@github-actions bot left a comment

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

//
// SPDX-License-Identifier: Apache-2.0

#include "ones.h"
@jnie-TT (Contributor) commented Dec 3, 2024

I feel like we should just use the full op for this. Under the hood, ttnn implements ones as a full op with fill value 1. That way we can reuse the full op's code across the different ops that target specific fill values (zeros, ones, twos, etc.). It feels cumbersome to copy-paste the full op's code just to target a specific fill value.

I'm not sure what the best way to do this on the compiler/flatbuffer side is, though. We could always use full with specific values, or we could add a ones op in the ttnn dialect and then lower it to full with specific values when translating to flatbuffer, so that the flatbuffer only has a schema for full.
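
The second option could be as small as this at translation time - a sketch only; flatbuffers codegen does produce a CreateFullOp function for a FullOp table, but the exact parameter list here is assumed:

    // Hypothetical: lower ttnn.ones by emitting the existing FullOp table
    // with the fill value hard-coded to 1.0, so only full needs a schema.
    ::flatbuffers::Offset<::tt::target::ttnn::FullOp>
    onesToFullOp(::flatbuffers::FlatBufferBuilder &fbb,
                 ::flatbuffers::Offset<::tt::target::DeviceRef> device,
                 ::flatbuffers::Offset<::tt::target::TensorRef> out) {
      return ::tt::target::ttnn::CreateFullOp(fbb, device,
                                              /*fill_value=*/1.0f, out);
    }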

@svuckovicTT (Contributor, Author) replied

I somewhat agree, but...

The full op is not fully implemented today (pun intended) - the only supported arguments are:

let arguments = (ins TT_Device:$device, F32Attr:$fillValue);

while the full set of operands for ttnn::full is:

    static ttnn::Tensor invoke(
        const ttnn::Shape& shape,
        const float fill_value,
        const std::optional<DataType>& dtype = std::nullopt,
        const std::optional<Layout>& layout = std::nullopt,
        detail::OptionalAnyDevice device = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt)

So I wouldn't block this PR on that.

@jnie-TT (Contributor) commented Dec 5, 2024

Actually, I think a better way to do this is to use the current full op implementation as an API. We can add zeros, ones, etc. to full.h/full.cpp and just execute them with specific fill values. This way we get the implementation for free whenever we want to add a new variant of the full op (zeros, ones, etc.), and we can ensure consistency across all of them.

Currently we have:

void run(const ::tt::target::ttnn::FullOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  FullTensorConfig config(op);
  ::ttnn::Tensor out;
  const ::tt::target::DeviceRef *deviceRef =
      !utils::inSystemMemory(op->out()) ? op->device() : nullptr;

  if (config.numShards == 1) {
    out = createFullOnSingleDevice(context, config, deviceRef);
  } else if (config.numShards > 1) {
    out = createFullOnMultiDevice(context, config, deviceRef);
  } else {
    LOG_FATAL("Unsupported num shards");
  }
  tensorPool.insert_or_assign(op->out()->global_id(), out);
}

We can update FullTensorConfig to explicitly take in all the parameters instead of taking in the op and deriving the parameters in its constructor, and it'd be up to the run function of each op to derive these parameters and pass them into the FullTensorConfig constructor.

Since FullOp currently doesn't have layout, dtype, etc., it would need to derive them from the output tensor descriptor, but ones can just get them from the op and explicitly use 1 as the fill value.

Then every op would just have the following code in common once the FullTensorConfig is created (we could later wrap it in a function like executeFull or something):

  if (config.numShards == 1) {
    out = createFullOnSingleDevice(context, config, deviceRef);
  } else if (config.numShards > 1) {
    out = createFullOnMultiDevice(context, config, deviceRef);
  } else {
    LOG_FATAL("Unsupported num shards");
  }
  tensorPool.insert_or_assign(op->out()->global_id(), out);
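
Under that proposal, ones.cpp would presumably shrink to something like this sketch - the FullTensorConfig constructor parameters and the executeFull wrapper are assumptions drawn from the comment above, not existing code:

    // Hypothetical shape of ones.cpp if FullTensorConfig took explicit
    // parameters and the common dispatch lived in an executeFull() wrapper.
    void run(const ::tt::target::ttnn::OnesOp *op, ProgramContext &context) {
      ProgramTensorPool &tensorPool = context.getTensorPool();

      // Ones reads shape/dtype/layout straight off the op and pins the
      // fill value to 1; FullOp would instead derive these from its
      // output tensor descriptor.
      FullTensorConfig config(op->shape(), /*fillValue=*/1.0f, // assumed ctor
                              op->dtype(), op->layout(), op->memcfg());

      const ::tt::target::DeviceRef *deviceRef =
          !utils::inSystemMemory(op->out()) ? op->device() : nullptr;

      // executeFull() would wrap the single-/multi-device dispatch above.
      ::ttnn::Tensor out = executeFull(context, config, deviceRef);
      tensorPool.insert_or_assign(op->out()->global_id(), out);
    }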

@svuckovicTT (Contributor, Author) replied

Hey Jackson, appreciate the write-up!

I'm prepping a doc that will, among other things, cover how we run ops, and there I'll go over how I think we should approach these situations. I hope that happens in the next two weeks, targeting one of our Thursday tech syncs.

Until then, I'd prefer we treat ttnn::ones as an op of its own. Would that be okay with you?

A contributor replied

Jackson is OOO right now and I don't have a strong opinion, so I'll approve this for now to unblock; let's circle back and discuss this situation in the future like you said.

@svuckovicTT force-pushed the svuckovic/ones-op-2 branch 2 times, most recently from 8899587 to 41c0afb on December 6, 2024 at 13:43
        : ttnn::Layout::RowMajor;
    ttnn::LayoutAttr tensorLayoutAttr =
        ttnn::LayoutAttr::get(op.getContext(), ttnnLayoutEnum);
    ttnn::TensorMemoryLayoutAttr memLayout = layoutAttr.getMemLayout();
A contributor commented

Doesn't getMemLayout() return enum ttnn::TensorMemoryLayout?

@svuckovicTT (Contributor, Author) replied

Check 3578538

Comment on lines 132 to 135
    // Device only exists if ttnn::TensorMemoryLayout is None
    //
    auto device =
        memLayout ? nullptr : ::ttnn::utils::getOrInsertDevice(rewriter, op);
A contributor commented

If memLayout is an enum, then check memLayout != None.

@svuckovicTT (Contributor, Author) replied

Check 3578538

@kmabeeTT (Contributor) left a comment

Thanks Sasha - there's an unresolved discussion about using the full op to handle this, and about this type of situation in general. Let's loop back to that sometime; approving for now to unblock.

@svuckovicTT (Contributor, Author) commented

Thanks @kmabeeTT!

@nsmithtt @tapspatel can I get eyes on this please? :)

@svuckovicTT svuckovicTT linked an issue Dec 9, 2024 that may be closed by this pull request
      return ::tt::target::TensorLayout::RowMajor;
    case ttnn::Layout::Tile:
      return ::tt::target::TensorLayout::Tile;
    case ttnn::Layout::Invalid:
A contributor commented

Do you know if this enum value is ever used?

@svuckovicTT (Contributor, Author) replied

I don't really know, but I decided to include it for the sake of completeness and future-proofing.
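
For context, the hunk above sits inside a layout-conversion switch that presumably looks like this in full - a sketch, where the Invalid value on the flatbuffer side is assumed to exist:

    // Sketch of the full conversion switch implied by the hunk above.
    inline ::tt::target::TensorLayout toFlatbuffer(ttnn::Layout layout) {
      switch (layout) {
      case ttnn::Layout::RowMajor:
        return ::tt::target::TensorLayout::RowMajor;
      case ttnn::Layout::Tile:
        return ::tt::target::TensorLayout::Tile;
      case ttnn::Layout::Invalid:
        return ::tt::target::TensorLayout::Invalid; // assumed enum value
      }
      llvm_unreachable("Unknown ttnn::Layout"); // tt-mlir builds against LLVM
    }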


    inline ::flatbuffers::Optional<::tt::target::TensorLayout>
    toFlatbufferOptional(FlatbufferObjectCache &cache,
                         ::std::optional<mlir::tt::ttnn::Layout> layout) {
A contributor commented

nit: You can omit fully qualifying the namespace here, I think, since it's enclosed in mlir::tt, so ttnn::Layout should be enough.

    ::flatbuffers::Optional<::tt::target::TensorLayout> layout =
        toFlatbufferOptional(cache, op.getLayout());

    flatbuffers::Offset<::tt::target::DeviceRef> device =
A contributor commented

How does serialization work when there is no device? I see you added 0 as a fallback here. Can we also use flatbuffers::nullopt like you did for layout?

@svuckovicTT (Contributor, Author) replied

The type here is flatbuffers::Offset, which means that the vtable entry of device would be 0 - in flatbuffer speak, that means the value isn't set. Enums were trickier; I can't remember the details, but I think that when I tried without setting them to null in program.fbs, they were treated such that 0 meant the first value of the enum instead of unset (as opposed to offset=0 for non-enum types). Setting the enums to null is what provides the Optional wrapper.
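
To illustrate the distinction (a sketch in flatbuffers schema syntax; the field names are assumptions, not the actual program.fbs contents):

    // Hypothetical excerpt from a flatbuffer schema:
    table OnesOp {
      shape: [i64];
      dtype: DataType = null;     // '= null' makes the scalar enum optional, so
                                  // codegen yields flatbuffers::Optional<DataType>
      layout: TensorLayout = null;
      device: DeviceRef;          // table field: a 0 vtable entry already means unset
      out: TensorRef;
    }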

@svuckovicTT (Contributor, Author) commented

Hey @nsmithtt @tapspatel, I need your review for the program.fbs file; the rest has been reviewed. It shouldn't take more than 30 seconds, thank you!

@tapspatel (Collaborator) left a comment

looks good!

@svuckovicTT svuckovicTT merged commit 31e5518 into main Dec 11, 2024
21 checks passed
azecevicTT pushed a commit that referenced this pull request Dec 17, 2024
Successfully merging this pull request may close these issues.

[TTNN Dialect] Add ttnn.onesOp