Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ttnn runtime layout comparisons in debug mode, minor runtime build cleanup #2338

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,13 @@ class TTNNDecomposeLayouts
if (info.shouldUntilize() && !opsToCreate.createFromDeviceOp) {
// Force-create a FromDeviceOp
currentInput = this->createFromDeviceOpIfNeeded(
op, rewriter, currentInput, info, true /* forceCreate */);
op, rewriter, currentInput, info, /*forceCreate=*/true);
// untilize on host
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
// move back to device and convert memory config if needed
currentInput = createToDeviceOpIfNeeded(op, rewriter, currentInput, info,
true /* forceCreate */);
/*forceCreate=*/true);
op.getResult().replaceAllUsesWith(currentInput);
return;
}
Expand Down Expand Up @@ -694,13 +694,13 @@ class TTNNDecomposeLayouts
!opsToCreate.createFromDeviceOp) {
// Force-create a FromDeviceOp
currentInput = this->createFromDeviceOpIfNeeded(
op, rewriter, currentInput, info, true /* forceCreate */);
op, rewriter, currentInput, info, /*forceCreate=*/true);
// tilize on host
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
// move back to device and convert memory config if needed
currentInput = this->createToDeviceOpIfNeeded(
op, rewriter, currentInput, info, true /* forceCreate */);
currentInput = this->createToDeviceOpIfNeeded(op, rewriter, currentInput,
info, /*forceCreate=*/true);
currentInput = this->createToMemoryConfigOpIfNeeded(op, rewriter,
currentInput, info);
op.getResult().replaceAllUsesWith(currentInput);
Expand Down Expand Up @@ -745,15 +745,14 @@ class TTNNDecomposeLayouts
/* Device to device untilized typecast, need to move to host first */
if (!output.isTilized() && !opsToCreate.createFromDeviceOp) {
// Force-create a FromDeviceOp
currentInput =
this->createOp<ttnn::FromDeviceOp>(op, rewriter, currentInput);
currentInput = this->createFromDeviceOpIfNeeded(
op, rewriter, currentInput, info, /*forceCreate=*/true);
// typecast on host
currentInput = this->createDataTypeCastingOpIfNeeded(op, rewriter,
currentInput, info);
// move back to device and convert memory config if needed
currentInput = this->createOp<ttnn::ToDeviceOp>(
op, rewriter, currentInput, info.device,
/* optional MemConfigAttr */ nullptr);
currentInput = this->createToDeviceOpIfNeeded(op, rewriter, currentInput,
info, /*forceCreate=*/true);
currentInput = this->createToMemoryConfigOpIfNeeded(op, rewriter,
currentInput, info);
op.getResult().replaceAllUsesWith(currentInput);
Expand Down Expand Up @@ -790,13 +789,13 @@ class TTNNDecomposeLayouts
currentInput, info);
// Force-create a FromDeviceOp
currentInput = this->createFromDeviceOpIfNeeded(
op, rewriter, currentInput, info, true /* forceCreate */);
op, rewriter, currentInput, info, /*forceCreate=*/true);
// untilize on host
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
// move back to device and convert memory config if needed
currentInput = this->createToDeviceOpIfNeeded(
op, rewriter, currentInput, info, true /* forceCreate */);
currentInput = this->createToDeviceOpIfNeeded(op, rewriter, currentInput,
info, /*forceCreate=*/true);
op.getResult().replaceAllUsesWith(currentInput);
return;
}
Expand Down Expand Up @@ -836,15 +835,14 @@ class TTNNDecomposeLayouts
if (info.shouldTilize() && input.dataType != DataType::BFloat16 &&
!opsToCreate.createFromDeviceOp) {
// Force-create a FromDeviceOp
currentInput =
this->createOp<ttnn::FromDeviceOp>(op, rewriter, currentInput);
currentInput = this->createFromDeviceOpIfNeeded(
op, rewriter, currentInput, info, /*forceCreate=*/true);
// tilize on host
currentInput =
this->createToLayoutOpIfNeeded(op, rewriter, currentInput, info);
// move back to device and convert data type/memory config if needed
currentInput = this->createOp<ttnn::ToDeviceOp>(
op, rewriter, currentInput, info.device,
/* optional MemConfigAttr */ nullptr);
currentInput = this->createToDeviceOpIfNeeded(op, rewriter, currentInput,
info, /*forceCreate=*/true);
currentInput = this->createDataTypeCastingOpIfNeeded(op, rewriter,
currentInput, info);
currentInput = this->createToMemoryConfigOpIfNeeded(op, rewriter,
Expand Down
13 changes: 9 additions & 4 deletions lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ class TTNNLayoutDPSOperandsRewriter
Location newLoc =
appendInputSuffix(op.getLoc(), operand.getOperandNumber());

bool isTiled = shouldTilize(op, operand.getOperandNumber());
bool isTiled = shouldTilize(op, operand.getOperandNumber(), isDPSResult);

// Given the operand constraint, create the desired layout for the operand
std::optional<Value> desiredLayout = createToLayoutOp(
Expand All @@ -418,8 +418,8 @@ class TTNNLayoutDPSOperandsRewriter
}

private:
bool shouldTilize(DestinationStyleOpInterface dpsOp,
int64_t operandNumber) const {
bool shouldTilize(DestinationStyleOpInterface dpsOp, int64_t operandNumber,
bool isDPSResult) const {

Operation *operation = dpsOp.getOperation();

Expand All @@ -435,7 +435,12 @@ class TTNNLayoutDPSOperandsRewriter
}

// These ops constrain to ROW_MAJOR on their operands
if (mlir::isa<ttir::Conv2dOp>(operation)) {
//
// For conv2d, the result tensor is tilized by default in runtime
// unless we specify an override in the Conv2dConfig (which we don't
// currently). Therefore we don't force row major if the operand is a DPS
// result
if (mlir::isa<ttir::Conv2dOp>(operation) && !isDPSResult) {
return false;
}
return true;
Expand Down
32 changes: 5 additions & 27 deletions runtime/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,23 +1,9 @@
if (TTMLIR_ENABLE_RUNTIME)
if (TT_RUNTIME_ENABLE_TTNN)
add_subdirectory(ttnn)
else()
add_library(TTRuntimeTTNN INTERFACE)
endif()
if (TT_RUNTIME_ENABLE_TTMETAL)
add_subdirectory(ttmetal)
else()
add_library(TTRuntimeTTMetal INTERFACE)
endif()
else()
add_library(TTRuntimeTTNN INTERFACE)
add_library(TTRuntimeTTNNTypes INTERFACE)
add_library(TTRuntimeTTNNUtils INTERFACE)
add_library(TTRuntimeTTMetal INTERFACE)
endif()

message(STATUS "Runtimes Enabled: TTNN[${TT_RUNTIME_ENABLE_TTNN}] TTMETAL[${TT_RUNTIME_ENABLE_TTMETAL}]")

add_subdirectory(common)
add_subdirectory(ttnn)
add_subdirectory(ttmetal)

add_library(TTBinary STATIC binary.cpp)
set_property(TARGET TTBinary PROPERTY CXX_STANDARD 20)
target_include_directories(TTBinary
Expand All @@ -27,14 +13,6 @@ target_include_directories(TTBinary
)
add_dependencies(TTBinary FBS_GENERATION)

if (TTMLIR_ENABLE_RUNTIME AND (TT_RUNTIME_ENABLE_TTNN OR TT_RUNTIME_ENABLE_TTMETAL))
add_subdirectory(common)
else()
add_library(TTRuntimeSysDesc INTERFACE)
add_library(TTRuntimeDebug INTERFACE)
add_library(TTRuntimeWorkarounds INTERFACE)
endif()

add_library(TTMLIRRuntime SHARED runtime.cpp)
set_property(TARGET TTMLIRRuntime PROPERTY CXX_STANDARD 20)
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTNN)
Expand Down Expand Up @@ -72,7 +50,7 @@ target_link_libraries(TTMLIRRuntime
set_target_properties(TTMLIRRuntime PROPERTIES INSTALL_RPATH "$ORIGIN")
set_target_properties(TTMLIRRuntime PROPERTIES BUILD_WITH_INSTALL_RPATH TRUE)

add_dependencies(TTMLIRRuntime TTBinary TTRuntimeSysDesc TTRuntimeDebug TTRuntimeWorkarounds FBS_GENERATION)
add_dependencies(TTMLIRRuntime TTBinary TTRuntimeSysDesc TTRuntimeDebug TTRuntimeWorkarounds TTRuntimeTTNN TTRuntimeTTMetal FBS_GENERATION)

if (TTMLIR_ENABLE_RUNTIME)
set_target_properties(TTMLIRRuntime PROPERTIES PUBLIC_HEADER "../include/tt/runtime/runtime.h;../include/tt/runtime/types.h;../include/tt/runtime/utils.h")
Expand Down
13 changes: 13 additions & 0 deletions runtime/lib/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ message(STATUS "Target triple: ${triple}")

add_definitions(-DTARGET_TRIPLE="${triple}")

# ANY_RUNTIME_ENABLED is ON only when the runtime as a whole is enabled and at
# least one backend (TTNN or TTMetal) is selected.
if (TTMLIR_ENABLE_RUNTIME AND (TT_RUNTIME_ENABLE_TTNN OR TT_RUNTIME_ENABLE_TTMETAL))
set(ANY_RUNTIME_ENABLED ON)
else()
set(ANY_RUNTIME_ENABLED OFF)
endif()

# With no runtime backend enabled, define the three common targets as empty
# INTERFACE libraries (so the target names still exist) and stop processing
# this file before the real STATIC targets below are declared.
if (NOT ANY_RUNTIME_ENABLED)
add_library(TTRuntimeSysDesc INTERFACE)
add_library(TTRuntimeDebug INTERFACE)
add_library(TTRuntimeWorkarounds INTERFACE)
return()
endif()

add_library(TTRuntimeSysDesc STATIC system_desc.cpp)
set_property(TARGET TTRuntimeSysDesc PROPERTY CXX_STANDARD 20)
target_include_directories(TTRuntimeSysDesc
Expand Down
1 change: 0 additions & 1 deletion runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

#if defined(TT_RUNTIME_ENABLE_TTNN)
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/types.h"
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
Expand Down
11 changes: 11 additions & 0 deletions runtime/lib/ttmetal/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# TTMETAL_RUNTIME_ENABLED is ON only when both the runtime and its TTMetal
# backend are enabled.
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTMETAL)
set(TTMETAL_RUNTIME_ENABLED ON)
else()
set(TTMETAL_RUNTIME_ENABLED OFF)
endif()

# When the TTMetal runtime is disabled, define TTRuntimeTTMetal as an empty
# INTERFACE library (so the target name still exists) and stop processing this
# file before the real STATIC target below is declared.
if (NOT TTMETAL_RUNTIME_ENABLED)
add_library(TTRuntimeTTMetal INTERFACE)
return()
endif()

add_library(TTRuntimeTTMetal
STATIC
runtime.cpp
Expand Down
14 changes: 13 additions & 1 deletion runtime/lib/ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
add_subdirectory(types)
# TTNN_RUNTIME_ENABLED is ON only when both the runtime and its TTNN backend
# are enabled. It is set before the add_subdirectory() calls below so the
# subdirectories can read it.
if (TTMLIR_ENABLE_RUNTIME AND TT_RUNTIME_ENABLE_TTNN)
set(TTNN_RUNTIME_ENABLED ON)
else()
set(TTNN_RUNTIME_ENABLED OFF)
endif()

# Subdirectories are added unconditionally, ahead of the early return below.
# NOTE(review): presumably each one defines its own placeholder target when
# TTNN_RUNTIME_ENABLED is OFF (debug/ does) - confirm for utils/types/operations.
add_subdirectory(utils)
add_subdirectory(debug)
add_subdirectory(types)
add_subdirectory(operations)

# When the TTNN runtime is disabled, define TTRuntimeTTNN as an empty INTERFACE
# library (so the target name still exists) and stop processing this file
# before the real STATIC target below is declared.
if (NOT TTNN_RUNTIME_ENABLED)
add_library(TTRuntimeTTNN INTERFACE)
return()
endif()

add_library(TTRuntimeTTNN
STATIC
runtime.cpp
Expand Down
19 changes: 19 additions & 0 deletions runtime/lib/ttnn/debug/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# TTRuntimeTTNNDebug: debug-only checks for the TTNN runtime.
#
# If the TTNN runtime is disabled or TT_RUNTIME_DEBUG is off, define the target
# as an empty INTERFACE library (so the target name still exists) and stop
# processing this file; debug_apis.cpp is then never compiled.
if (NOT TTNN_RUNTIME_ENABLED OR NOT TT_RUNTIME_DEBUG)
add_library(TTRuntimeTTNNDebug INTERFACE)
return()
endif()

# Real target: a static library built from debug_apis.cpp as C++20.
add_library(TTRuntimeTTNNDebug
STATIC
debug_apis.cpp
)
set_property(TARGET TTRuntimeTTNNDebug PROPERTY CXX_STANDARD 20)
# PUBLIC: the AVX/AVX2 flags are applied to this target and propagated to
# targets that link against it.
target_compile_options(TTRuntimeTTNNDebug PUBLIC -mavx -mavx2)
target_include_directories(TTRuntimeTTNNDebug PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/include
${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common
)
# TTMetal headers are added as SYSTEM includes, for the build interface only.
target_include_directories(TTRuntimeTTNNDebug SYSTEM PUBLIC "$<BUILD_INTERFACE:${TTMETAL_INCLUDE_DIRS}>")
# TTRuntimeTTNNUtils is both a build-order dependency and a link dependency.
add_dependencies(TTRuntimeTTNNDebug TTRuntimeTTNNUtils)
target_link_libraries(TTRuntimeTTNNDebug PUBLIC TTRuntimeTTNNUtils)
90 changes: 90 additions & 0 deletions runtime/lib/ttnn/debug/debug_apis.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#if !defined(TT_RUNTIME_DEBUG) || !TT_RUNTIME_DEBUG
#error "TT_RUNTIME_DEBUG must be defined and set"
#endif

#include "tt/runtime/ttnn/debug_apis.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/ttnn/utils.h"

namespace tt::runtime::ttnn::debug {

/// Returns the enumerator name of `layout` as text, for use in the debug
/// assertion messages emitted from this file.
///
/// @param layout Layout value to describe.
/// @return "ROW_MAJOR", "TILE" or "INVALID" for the enumerators handled below;
///         "UNKNOWN" for any other value.
static std::string toString(::ttnn::Layout layout) {
  // No `default:` label on purpose: this keeps the compiler's unhandled
  // enumerator warning (-Wswitch) active if the enum gains a new value.
  switch (layout) {
  case ::ttnn::Layout::ROW_MAJOR:
    return "ROW_MAJOR";
  case ::ttnn::Layout::TILE:
    return "TILE";
  case ::ttnn::Layout::INVALID:
    return "INVALID";
  }
  // Reached only when `layout` matches none of the cases above. Without this
  // return, control would flow off the end of a non-void function, which is
  // undefined behaviour.
  return "UNKNOWN";
}

/// Returns the enumerator name of `dtype` as text, for use in the debug
/// assertion messages emitted from this file.
///
/// @param dtype Data type value to describe.
/// @return The enumerator's name for the values handled below; "UNKNOWN" for
///         any other value.
static std::string toString(::ttnn::DataType dtype) {
  // No `default:` label on purpose: this keeps the compiler's unhandled
  // enumerator warning (-Wswitch) active if the enum gains a new value.
  switch (dtype) {
  case ::ttnn::DataType::FLOAT32:
    return "FLOAT32";
  case ::ttnn::DataType::BFLOAT16:
    return "BFLOAT16";
  case ::ttnn::DataType::BFLOAT8_B:
    return "BFLOAT8_B";
  case ::ttnn::DataType::BFLOAT4_B:
    return "BFLOAT4_B";
  case ::ttnn::DataType::UINT32:
    return "UINT32";
  case ::ttnn::DataType::UINT16:
    return "UINT16";
  case ::ttnn::DataType::UINT8:
    return "UINT8";
  case ::ttnn::DataType::INT32:
    return "INT32";
  case ::ttnn::DataType::INVALID:
    return "INVALID";
  }
  // Reached only when `dtype` matches none of the cases above. Without this
  // return, control would flow off the end of a non-void function, which is
  // undefined behaviour.
  return "UNKNOWN";
}

/// Debug-only consistency check between a flatbuffer tensor description and
/// the TTNN tensor produced at runtime.
///
/// Asserts, in order, that the layout and data type described by `tensorRef`
/// equal those reported by `ttnnTensor`; for tensors that are not in system
/// memory it additionally asserts that the memory configs are equal.
///
/// @param tensorRef  Flatbuffer description of the expected tensor.
/// @param ttnnTensor Tensor whose actual properties are checked.
void checkTensorRefMatchesTTNNTensor(
    const ::tt::target::ttnn::TensorRef *tensorRef,
    const ::ttnn::Tensor &ttnnTensor) {
  namespace utils = ::tt::runtime::ttnn::utils;

  // NOTE(review): the identifiers used inside the DEBUG_ASSERT conditions are
  // kept as-is in case the macro stringifies its condition - confirm.

  // 1. Layout.
  ::ttnn::Layout expectedLayout = utils::inferLayoutFromTileShape(tensorRef);
  ::ttnn::Layout actualLayout = ttnnTensor.get_layout();
  DEBUG_ASSERT(expectedLayout == actualLayout, "Layout mismatch, expected ",
               toString(expectedLayout), ", got ", toString(actualLayout));

  // 2. Data type.
  const auto *memoryDesc = tensorRef->desc()->layout()->memory_desc();
  ::ttnn::DataType expectedDataType =
      utils::toTTNNDataType(memoryDesc->data_type());
  ::ttnn::DataType actualDataType = ttnnTensor.get_dtype();
  DEBUG_ASSERT(expectedDataType == actualDataType,
               "DataType mismatch, expected ", toString(expectedDataType),
               ", got ", toString(actualDataType));

  // TODO (jnie): Compare storage once we correctly determine it in the
  // flatbuffer. This requires compiler support which is missing.
  //
  // ::ttnn::StorageType expectedStorageType =
  // ::tt::runtime::ttnn::utils::toTTNNStorageType(
  // tensorRef->desc()->layout()->memory_desc()->storage_type());
  // ::ttnn::StorageType actualStorageType =
  // ttnnTensor.storage_type();
  // DEBUG_ASSERT(expectedStorageType == actualStorageType, "Storage type
  // mismatch, expected ", static_cast<int>(expectedStorageType), ", got ",
  // static_cast<int>(actualStorageType));

  // 3. Memory config: skipped for tensors that live in system memory.
  if (utils::inSystemMemory(tensorRef)) {
    return;
  }

  const ::tt::target::ttnn::MemoryConfig *memcfg =
      utils::getTensorRefMemoryConfig(tensorRef);
  DEBUG_ASSERT(memcfg, "Device tensor must have memory config");
  ::ttnn::MemoryConfig expectedMemoryConfig =
      utils::createMemoryConfigIfNeeded(memcfg).value();
  ::ttnn::MemoryConfig actualMemoryConfig = ttnnTensor.memory_config();
  DEBUG_ASSERT(expectedMemoryConfig == actualMemoryConfig,
               "Memory config mismatch, expected ", expectedMemoryConfig,
               ", got ", actualMemoryConfig);
}
} // namespace tt::runtime::ttnn::debug
34 changes: 34 additions & 0 deletions runtime/lib/ttnn/include/tt/runtime/ttnn/debug_apis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TT_RUNTIME_TTNN_DEBUG_APIS_H
#define TT_RUNTIME_TTNN_DEBUG_APIS_H

#include "tt/runtime/detail/ttnn.h"
#include "ttmlir/Target/TTNN/Target.h"

// RUNTIME_DEBUG_MAYBE_CONST_INLINE is the specifier placed in front of the
// debug API below:
//  - TT_RUNTIME_DEBUG == 1: expands to nothing, so the function is an ordinary
//    declaration whose definition lives in debug_apis.cpp.
//  - otherwise: expands to `inline __attribute__((always_inline, const))` for
//    the empty stub defined in this header.
// NOTE(review): the `const` attribute is applied to a function returning void;
// some compilers diagnose that combination - confirm it is intended.
// NOTE(review): this guard requires TT_RUNTIME_DEBUG == 1, while
// debug_apis.cpp only requires it to be non-zero - confirm only 0/1 are used.
#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
#define RUNTIME_DEBUG_MAYBE_CONST_INLINE
#else
#define RUNTIME_DEBUG_MAYBE_CONST_INLINE \
inline __attribute__((always_inline, const))
#endif

namespace tt::runtime::ttnn::debug {

// Checks that `ttnnTensor` matches the tensor described by `tensorRef`.
//
// With TT_RUNTIME_DEBUG == 1 this is only a declaration (note the lone `;`
// in the first preprocessor branch). Otherwise the second branch supplies an
// empty inline body, so calls do nothing.
RUNTIME_DEBUG_MAYBE_CONST_INLINE void
checkTensorRefMatchesTTNNTensor(const ::tt::target::ttnn::TensorRef *tensorRef,
const ::ttnn::Tensor &ttnnTensor)
#if defined(TT_RUNTIME_DEBUG) && TT_RUNTIME_DEBUG == 1
;
#else
{
}
#endif

// The helper macro is local to this header; undefine it so it does not leak
// into files that include it.
#undef RUNTIME_DEBUG_MAYBE_CONST_INLINE

} // namespace tt::runtime::ttnn::debug

#endif // TT_RUNTIME_TTNN_DEBUG_APIS_H
Loading
Loading