Skip to content

Commit

Permalink
Add Python bindings for override mechanism. (#1526)
Browse files Browse the repository at this point in the history
* Add Python bindings for override mechanism.

* Add Python bindings for override mechanism.

* Add Python bindings for override mechanism.
  • Loading branch information
vcanicTT authored Dec 9, 2024
1 parent c1d9d27 commit 223d244
Show file tree
Hide file tree
Showing 10 changed files with 389 additions and 39 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Bindings/Python/TTMLIRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ void populateTTIRModule(py::module &m);
void populateTTKernelModule(py::module &m);
void populateTTNNModule(py::module &m);
void populateOverridesModule(py::module &m);
void populateOptimizerOverridesModule(py::module &m);
void populatePassesModule(py::module &m);
} // namespace mlir::ttmlir::python

Expand Down
19 changes: 10 additions & 9 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct TTIRToTTNNBackendPipelineOptions
// configuration for max performance. If this option is false, skip running
// Optimizer pass, thus leaving all ops on default configuration.
Option<bool> optimizerPassEnabled{
*this, "enable-optimizer",
*this, OptionNames::optimizerPassEnabled,
llvm::cl::desc("Determine and set max valid grid for Op execution."),
llvm::cl::init(false)};

Expand All @@ -38,7 +38,7 @@ struct TTIRToTTNNBackendPipelineOptions
//
Option<llvm::StringMap<InputLayoutOverrideParams>, InputLayoutOverrideParser>
overrideInputLayout{
*this, "insert-memreconfig",
*this, OptionNames::overrideInputLayout,
llvm::cl::desc(
"Manually insert memory reconfig op for specific op's operand."),
llvm::cl::init(llvm::StringMap<InputLayoutOverrideParams>())};
Expand Down Expand Up @@ -66,51 +66,52 @@ struct TTIRToTTNNBackendPipelineOptions
Option<llvm::StringMap<OutputLayoutOverrideParams>,
OutputLayoutOverrideParser>
overrideOutputLayout{
*this, "override-output-layout",
*this, OptionNames::overrideOutputLayout,
llvm::cl::desc("Override output tensor layout for specific ops."),
llvm::cl::init(llvm::StringMap<OutputLayoutOverrideParams>())};

// If this option is true, run memory layout analysis.
//
Option<bool> memoryLayoutAnalysisEnabled{
*this, "memory-layout-analysis-enabled",
*this, OptionNames::memoryLayoutAnalysisEnabled,
llvm::cl::desc("Enable memory layout optimization."),
llvm::cl::init(false)};

// If this option is true, insert memory reconfiguration ops.
//
Option<bool> memReconfigEnabled{
*this, "memreconfig-enabled",
*this, OptionNames::memReconfigEnabled,
llvm::cl::desc("Memory layout reconfiguration pass."),
llvm::cl::init(true)};

// Specify policy for memory layout analysis.
//
Option<MemoryLayoutAnalysisPolicyType, MemoryLayoutAnalysisPolicyTypeParser>
memoryLayoutAnalysisPolicy{
*this, "memory-layout-analysis-policy",
*this, OptionNames::memoryLayoutAnalysisPolicy,
llvm::cl::desc("Specify policy for memory layout analysis."),
llvm::cl::init(MemoryLayoutAnalysisPolicyType::DFSharding)};

// Option to provide a system descriptor flatbuffer file to compile
// against.
//
Option<std::string> systemDescPath{
*this, "system-desc-path",
*this, OptionNames::systemDescPath,
llvm::cl::desc(
"Pass in a system descriptor flatbuffer to compile against."),
llvm::cl::init("")};

// Option to override maximum number of legal layouts for grid analysis
//
Option<int64_t> maxLegalLayouts{
*this, "max-legal-layouts",
*this, OptionNames::maxLegalLayouts,
llvm::cl::desc(
"Override maximum number of legal layouts for grid analysis."),
llvm::cl::init(64)};

ListOption<int64_t> meshShape{
*this, "mesh-shape", llvm::cl::desc("Set the multi-device mesh shape.")};
*this, OptionNames::meshShape,
llvm::cl::desc("Set the multi-device mesh shape.")};

// Option to enable/disable the workaround pass.
//
Expand Down
21 changes: 16 additions & 5 deletions include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
#ifndef TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H
#define TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H

#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"
#include <iostream>
#include <string>
#include <unordered_map>

#include "ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h"

Expand Down Expand Up @@ -65,11 +68,19 @@ class OptimizerOverridesHandler {
TensorMemoryLayout, tt::ttnn::Layout,
tt::DataType);

private:
// Options for the TTIR to TTNN backend pipeline,
// we use them to extract the names and the default values.
TTIRToTTNNBackendPipelineOptions pipelineOptions;
// Wrapper methods we use to expose the adders to the python bindings
std::unordered_map<std::string, InputLayoutOverrideParams>
getInputLayoutOverridesPybindWrapper() const;
std::unordered_map<std::string, OutputLayoutOverrideParams>
getOutputLayoutOverridesPybindWrapper() const;

// Wrapper methods we use to expose the adders to the python bindings
void addInputLayoutOverridePybindWrapper(std::string, std::vector<int64_t> &);
void addOutputLayoutOverridePybindWrapper(std::string, std::vector<int64_t> &,
BufferType, TensorMemoryLayout,
tt::ttnn::Layout, tt::DataType);

private:
// Flags for enabling/disabling the optimizer passes
bool enableOptimizer = false;

Expand Down
17 changes: 17 additions & 0 deletions include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#ifndef TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H
#define TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H

#include <string_view>

#include <llvm/Support/CommandLine.h>

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
Expand All @@ -13,6 +15,21 @@

namespace mlir::tt::ttnn {

// Canonical command-line option names for the TTIR-to-TTNN backend pipeline.
// Keeping them in one place lets the pipeline option declarations and any
// code that reconstructs option strings (e.g. an overrides handler) share a
// single source of truth instead of repeating string literals.
struct OptionNames {

  // Enables the optimizer pass ("enable-optimizer").
  static constexpr StringRef optimizerPassEnabled = "enable-optimizer";
  // Manually inserts memory-reconfig ops for an op's operands.
  static constexpr StringRef overrideInputLayout = "insert-memreconfig";
  // Overrides the output tensor layout of specific ops.
  static constexpr StringRef overrideOutputLayout = "override-output-layout";
  // Enables the memory layout analysis.
  static constexpr StringRef memoryLayoutAnalysisEnabled =
      "memory-layout-analysis-enabled";
  // Enables the memory layout reconfiguration pass.
  static constexpr StringRef memReconfigEnabled = "memreconfig-enabled";
  // Selects the policy used by memory layout analysis.
  static constexpr StringRef memoryLayoutAnalysisPolicy =
      "memory-layout-analysis-policy";
  // Path to a system descriptor flatbuffer to compile against.
  static constexpr StringRef systemDescPath = "system-desc-path";
  // Cap on the number of legal layouts considered by grid analysis.
  static constexpr StringRef maxLegalLayouts = "max-legal-layouts";
  // Multi-device mesh shape.
  static constexpr StringRef meshShape = "mesh-shape";
};

struct OutputLayoutOverrideParams {

SmallVector<int64_t, 2> grid;
Expand Down
82 changes: 57 additions & 25 deletions lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,65 +81,79 @@ OptimizerOverridesHandler::getOutputLayoutOverrides() const {
return outputLayoutOverrides;
}

std::unordered_map<std::string, InputLayoutOverrideParams>
OptimizerOverridesHandler::getInputLayoutOverridesPybindWrapper() const {
std::unordered_map<std::string, InputLayoutOverrideParams>
inputLayoutOverridesWrapper;
for (auto &entry : inputLayoutOverrides) {
inputLayoutOverridesWrapper[entry.getKey().str()] = entry.getValue();
}
return inputLayoutOverridesWrapper;
}

std::unordered_map<std::string, OutputLayoutOverrideParams>
OptimizerOverridesHandler::getOutputLayoutOverridesPybindWrapper() const {
std::unordered_map<std::string, OutputLayoutOverrideParams>
outputLayoutOverridesWrapper;
for (auto &entry : outputLayoutOverrides) {
outputLayoutOverridesWrapper[entry.getKey().str()] = entry.getValue();
}
return outputLayoutOverridesWrapper;
}

std::string OptimizerOverridesHandler::toString() const {

std::string options = "";

if (enableOptimizer) {
options += std::string(pipelineOptions.optimizerPassEnabled.getArgStr()) +
"=true ";
options += OptionNames::optimizerPassEnabled.str() + "=true ";
}

if (enableMemoryReconfig) {
options +=
std::string(pipelineOptions.memReconfigEnabled.getArgStr()) + "=true ";
options += OptionNames::memReconfigEnabled.str() + "=true ";
}

if (enableMemoryLayoutAnalysis) {
options +=
std::string(pipelineOptions.memoryLayoutAnalysisEnabled.getArgStr()) +
"=true ";
options += OptionNames::memoryLayoutAnalysisEnabled.str() + "=true ";
}

if (enableMemoryLayoutAnalysisPolicy) {
options +=
std::string(pipelineOptions.memoryLayoutAnalysisPolicy.getArgStr()) +
MemoryLayoutAnalysisPolicyTypeParser::toString(
memoryLayoutAnalysisPolicy) +
" ";
options += OptionNames::memoryLayoutAnalysisPolicy.str() + "=" +
MemoryLayoutAnalysisPolicyTypeParser::toString(
memoryLayoutAnalysisPolicy) +
" ";
}

// Create input layout overrides.
// Example: insert-memreconfig=input0=0:1,input1=0,input2=0:1:2
// Example:
// insert-memreconfig=input0=0:1,input1=0,input2=0:1:2
if (inputLayoutOverrides.size() > 0) {
options += std::string(pipelineOptions.overrideInputLayout.getArgStr()) +
"=" + InputLayoutOverrideParser::toString(inputLayoutOverrides) +
" ";
options += OptionNames::overrideInputLayout.str() + "=" +
InputLayoutOverrideParser::toString(inputLayoutOverrides) + " ";
}

// Create output layout overrides.
// Example:
// override-output-layout=op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:fp16
// override-output-layout=op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:fp16
// Example:
// override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32"
// override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32"
if (outputLayoutOverrides.size() > 0) {
options +=
std::string(pipelineOptions.overrideOutputLayout.getArgStr()) + "=" +
OutputLayoutOverrideParser::toString(outputLayoutOverrides) + " ";
options += OptionNames::overrideOutputLayout.str() + "=" +
OutputLayoutOverrideParser::toString(outputLayoutOverrides) +
" ";
}

if (systemDescPath.size() > 0) {
options += std::string(pipelineOptions.systemDescPath.getArgStr()) +
systemDescPath + " ";
options += OptionNames::systemDescPath.str() + "=" + systemDescPath + " ";
}

if (maxLegalLayouts > 0) {
options += std::string(pipelineOptions.maxLegalLayouts.getArgStr()) +
options += OptionNames::maxLegalLayouts.str() + "=" +
std::to_string(maxLegalLayouts) + " ";
}

if (meshShape.size() > 0) {
options += std::string(pipelineOptions.meshShape.getArgStr()) + "=";
options += OptionNames::meshShape.str() + "=";
for (int64_t meshShapeValue : meshShape) {
options += std::to_string(meshShapeValue) + ",";
}
Expand Down Expand Up @@ -175,4 +189,22 @@ void OptimizerOverridesHandler::addOutputLayoutOverride(
std::move(grid), bufferType, tensorMemoryLayout, memoryLayout, dataType};
}

void OptimizerOverridesHandler::addInputLayoutOverridePybindWrapper(
std::string opName, std::vector<int64_t> &operandIdxes) {
StringRef opNameStringRef(opName);
SmallVector<int64_t> operandIdxesSmallVector(operandIdxes.begin(),
operandIdxes.end());
addInputLayoutOverride(opNameStringRef, operandIdxesSmallVector);
}

void OptimizerOverridesHandler::addOutputLayoutOverridePybindWrapper(
std::string opName, std::vector<int64_t> &grid, BufferType bufferType,
TensorMemoryLayout tensorMemoryLayout, tt::ttnn::Layout memoryLayout,
tt::DataType dataType) {
StringRef opNameStringRef(opName);
SmallVector<int64_t> gridSmallVector(grid.begin(), grid.end());
addOutputLayoutOverride(opNameStringRef, gridSmallVector, bufferType,
tensorMemoryLayout, memoryLayout, dataType);
}

} // namespace mlir::tt::ttnn
7 changes: 7 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ declare_mlir_python_sources(TTMLIRPythonSources.Overrides
SOURCES overrides.py
)

declare_mlir_python_sources(TTMLIRPythonSources.OptimizerOverrides
ROOT_DIR "${TTMLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TTMLIRPythonSources
SOURCES optimizer_overrides.py
)

declare_mlir_python_sources(TTMLIRPythonSources.Passes
ROOT_DIR "${TTMLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TTMLIRPythonSources
Expand All @@ -87,6 +93,7 @@ declare_mlir_python_extension(TTMLIRPythonExtensions.Main
TTKernelModule.cpp
TTNNModule.cpp
Overrides.cpp
OptimizerOverrides.cpp
Passes.cpp
EMBED_CAPI_LINK_LIBS
MLIRCAPITransforms
Expand Down
Loading

0 comments on commit 223d244

Please sign in to comment.