Add ttnn::ones() op #1476
Conversation
Clang-Tidy found issue(s) with the introduced code (1/1)

```cpp
//
// SPDX-License-Identifier: Apache-2.0

#include "ones.h"
```
I feel like we should just use the full op for this. Under the hood, ttnn implements it as a full op with fill value 1. This way we can reuse the full op's code for the ops that target specific fill values (zeros, ones, twos, etc.). It feels cumbersome to copy and paste the full op's code just to target a specific fill value.
I'm not sure what the best way to do this on the compiler/flatbuffer side is, though. We could always use full with specific values, or we could add a ones op in the ttnn dialect and lower it to full with specific values when translating to flatbuffer. That way the flatbuffer would only have a schema for full.
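As a sketch of that second option - the builder call, cache accessors, and field order below are assumptions for illustration, not the repo's actual translation API:

```cpp
// Hypothetical lowering-at-translation: serialize a ttnn.ones dialect op
// as the existing FullOp flatbuffer table, with the fill value pinned to 1.
::flatbuffers::Offset<::tt::target::ttnn::FullOp>
createOnesAsFullOp(FlatbufferObjectCache &cache, OnesOp op) {
  constexpr float kFillValue = 1.0f; // ones == full with fill value 1
  auto device = cache.at<::tt::target::DeviceRef>(op.getDevice());
  auto out = cache.at<::tt::target::TensorRef>(op.getResult());
  return ::tt::target::ttnn::CreateFullOp(*cache.fbb, device, kFillValue, out);
}
```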
I somewhat agree, but...
The full op is not implemented fully today (pun intended) - the only arguments that are supported are:

```td
let arguments = (ins TT_Device:$device, F32Attr:$fillValue);
```

while the full set of operands for `ttnn::full` is:

```cpp
static ttnn::Tensor invoke(
    const ttnn::Shape& shape,
    const float fill_value,
    const std::optional<DataType>& dtype = std::nullopt,
    const std::optional<Layout>& layout = std::nullopt,
    detail::OptionalAnyDevice device = std::nullopt,
    const std::optional<MemoryConfig>& memory_config = std::nullopt,
    std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt)
```
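For reference, once those operands are supported, the reuse described above would be a single delegating call - a minimal sketch against the signature shown, with a placeholder function name:

```cpp
// Sketch: a ones() variant implemented by delegating to full().
// The only difference from a general full() call is fill_value = 1.
::ttnn::Tensor onesViaFull(const ttnn::Shape &shape,
                           const std::optional<DataType> &dtype,
                           const std::optional<Layout> &layout,
                           detail::OptionalAnyDevice device,
                           const std::optional<MemoryConfig> &memory_config) {
  return ::ttnn::full(shape, /*fill_value=*/1.0f, dtype, layout, device,
                      memory_config);
}
```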
So I wouldn't block this PR on that.
Actually, I think a better way to do this is to use the current full op implementation as an API. We can add zeros, ones, etc. to full.h/full.cpp and just execute it with specific fill values. This way we get all the implementation for free when we want to add new variants of the full op (zeros, ones, etc.), and we can ensure consistency across all of them.
Currently we have:
```cpp
void run(const ::tt::target::ttnn::FullOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  FullTensorConfig config(op);
  ::ttnn::Tensor out;
  const ::tt::target::DeviceRef *deviceRef =
      !utils::inSystemMemory(op->out()) ? op->device() : nullptr;
  if (config.numShards == 1) {
    out = createFullOnSingleDevice(context, config, deviceRef);
  } else if (config.numShards > 1) {
    out = createFullOnMultiDevice(context, config, deviceRef);
  } else {
    LOG_FATAL("Unsupported num shards");
  }
  tensorPool.insert_or_assign(op->out()->global_id(), out);
}
```
We can update `FullTensorConfig` to explicitly take in all the parameters instead of taking in the op and deriving the parameters in its constructor, and it'd be up to the `run` function of each op to derive these parameters and pass them into the `FullTensorConfig` constructor.
Since `FullOp` currently doesn't have layout, dtype, etc., it would need to derive them from the output tensor descriptor, but `Ones` can just get them from the op and explicitly use `1` as the fill value.
Then every op would just have the following code in common once the `FullTensorConfig` is created (we can later wrap the following in a function like `executeFull` or something):
```cpp
if (config.numShards == 1) {
  out = createFullOnSingleDevice(context, config, deviceRef);
} else if (config.numShards > 1) {
  out = createFullOnMultiDevice(context, config, deviceRef);
} else {
  LOG_FATAL("Unsupported num shards");
}
tensorPool.insert_or_assign(op->out()->global_id(), out);
```
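Put together, the shared pieces might look roughly like this - just a sketch: `executeFull` is the placeholder name from above, and the explicit `FullTensorConfig` constructor and the `OnesOp` accessors (`shape()`, `dtype()`, `layout()`, `memcfg()`) are assumptions rather than the current API.

```cpp
// Sketch only: the shared helper owns the single/multi-device dispatch so
// every fill-value variant (zeros, ones, ...) can reuse it.
static ::ttnn::Tensor executeFull(ProgramContext &context,
                                  const FullTensorConfig &config,
                                  const ::tt::target::DeviceRef *deviceRef) {
  if (config.numShards == 1) {
    return createFullOnSingleDevice(context, config, deviceRef);
  }
  if (config.numShards > 1) {
    return createFullOnMultiDevice(context, config, deviceRef);
  }
  LOG_FATAL("Unsupported num shards");
}

void run(const ::tt::target::ttnn::OnesOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  // Ones derives its parameters from the op itself and pins the fill
  // value to 1 (hypothetical explicit-parameter constructor).
  FullTensorConfig config(op->shape(), /*fillValue=*/1.0f, op->dtype(),
                          op->layout(), op->memcfg());
  const ::tt::target::DeviceRef *deviceRef =
      !utils::inSystemMemory(op->out()) ? op->device() : nullptr;
  ::ttnn::Tensor out = executeFull(context, config, deviceRef);
  tensorPool.insert_or_assign(op->out()->global_id(), out);
}
```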
Hey Jackson, appreciate the write-up!
I'm prepping a doc on a topic that will, among other things, cover how we run ops, and there I'll go over how I think we should approach these situations - I hope that happens in the next 2 weeks, targeting one of our Thursday tech syncs.
Until then, I'd prefer we treat `ttnn::ones` as an op of its own. Would that be okay with you?
Jackson is OOO right now and I don't have a strong opinion, so I'll approve this for now to unblock; let's circle back and discuss this situation in the future like you said.
Force-pushed from 8899587 to 41c0afb
```cpp
        : ttnn::Layout::RowMajor;
ttnn::LayoutAttr tensorLayoutAttr =
    ttnn::LayoutAttr::get(op.getContext(), ttnnLayoutEnum);
ttnn::TensorMemoryLayoutAttr memLayout = layoutAttr.getMemLayout();
```
Doesn't `getMemLayout()` return the enum `ttnn::TensorMemoryLayout`?
Check 3578538
```cpp
// Device only exists if ttnn::TensorMemoryLayout is None
//
auto device =
    memLayout ? nullptr : ::ttnn::utils::getOrInsertDevice(rewriter, op);
```
If `memLayout` is an enum, then check `memLayout != None`.
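For illustration, assuming `memLayout` is the plain `ttnn::TensorMemoryLayout` enum, the guard would read roughly:

```cpp
// Sketch: an enum has no null state, so "unset" must be spelled out as an
// explicit comparison against None.
auto device = memLayout != ttnn::TensorMemoryLayout::None
                  ? nullptr
                  : ::ttnn::utils::getOrInsertDevice(rewriter, op);
```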
Check 3578538
Thanks Sasha - there's an unresolved discussion about using the full op to handle this, and this type of situation in general. Let's loop back to that sometime; approving for now to unblock.
Thanks @kmabeeTT! @nsmithtt @tapspatel, can I get eyes on this please? :)
```cpp
  return ::tt::target::TensorLayout::RowMajor;
case ttnn::Layout::Tile:
  return ::tt::target::TensorLayout::Tile;
case ttnn::Layout::Invalid:
```
Do you know if this enum value is ever used?
I don't really know, but for the sake of completeness and future-proofing, I decided to include it.
```cpp
inline ::flatbuffers::Optional<::tt::target::TensorLayout>
toFlatbufferOptional(FlatbufferObjectCache &cache,
                     ::std::optional<mlir::tt::ttnn::Layout> layout) {
```
nit: I think you can omit fully typing the namespace in this case, since it's enclosed in mlir::tt, so `ttnn::Layout` should be enough.
```cpp
::flatbuffers::Optional<::tt::target::TensorLayout> layout =
    toFlatbufferOptional(cache, op.getLayout());

flatbuffers::Offset<::tt::target::DeviceRef> device =
```
How does serialization work in case there is no device? I see you added 0 as a fallback here. Can we also use `flatbuffers::nullopt` like you did for layout?
The type here is `flatbuffers::Offset`, which means the vtable entry of device would be 0. In Flatbuffer speak, that means the value isn't set. Enums were trickier - I can't remember the details, but I think that when I tried without setting them to `null` in `program.fbs`, they were treated such that `0` would mean the first value of the enum instead of unset (or offset=0, as is for non-enum types). Setting the enums to `null` is what provides the Optional wrapper.
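As an illustration - a sketch with assumed field names, not the actual schema or generated code:

```cpp
// program.fbs (sketch):  layout: TensorLayout = null;
// With "= null", flatc generates an accessor that returns a
// flatbuffers::Optional, so an absent field is distinguishable from the
// enum's first value; without it, an unset field reads back as 0.
::flatbuffers::Optional<::tt::target::TensorLayout> layout = op->layout();
if (layout.has_value()) {
  // Field was present in the buffer; use layout.value().
} else {
  // Field was never set; fall back to a default.
}
```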
Hey @nsmithtt @tapspatel, need your review for the
looks good!
Force-pushed from acc73d0 to 788289b
This PR adds support for the `ttnn::ones()` op.

- In `MLIRToFlatbuffer.h`, two `toFlatbufferOptional` methods have been added to make the `TTNNToFlatbuffer.cpp` code cleaner, as all the optional stuff is abstracted away
- In `TTNNToEmitC.cpp`, would like to try to apply this to other ops, and hopefully create a generic solution that would work for most ops without any special-case handling
- The `run()` method in `runtime/lib/ttnn/operations/creation/ones.cpp` has been written so that it "mechanically" unpacks the flatbuffer and calls the appropriate ttnn method - it would be great if we could make this generic so that we don't need to manually handle each op (see the sketch below)

~~Note: there's a piece of code in `lib/Conversion/TTNNToEmitC/Utils.cpp` that I changed just to make the tests run. I'll mark it with a comment; it will be removed before this PR is merged with main - @mtopalovicTT has a fix in the works.~~
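As a rough sketch of that "mechanical" unpack - the `toTtnn*` conversion helpers and flatbuffer accessors below are illustrative placeholders, and `ttnn::ones` is assumed to mirror `ttnn::full` minus the fill value:

```cpp
// Hypothetical shape of run() in ones.cpp: pull each (possibly optional)
// field out of the flatbuffer op and forward it to ttnn::ones.
void run(const ::tt::target::ttnn::OnesOp *op, ProgramContext &context) {
  ::ttnn::Shape shape = toTtnnShape(op->shape()); // required field
  std::optional<::ttnn::DataType> dtype = toTtnnDataType(op->dtype());
  std::optional<::ttnn::Layout> layout = toTtnnLayout(op->layout());
  std::optional<::ttnn::MemoryConfig> memcfg = toTtnnMemcfg(op->memcfg());
  const ::tt::target::DeviceRef *deviceRef =
      !utils::inSystemMemory(op->out()) ? op->device() : nullptr;

  ::ttnn::Tensor out = ::ttnn::ones(
      shape, dtype, layout, toTtnnDevice(context, deviceRef), memcfg);
  context.getTensorPool().insert_or_assign(op->out()->global_id(), out);
}
```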