Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Optimizer/TTNN] Migrating Optimizer to TTNN with minimal possible fu… #944

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion include/ttmlir/Dialect/TTIR/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#ifndef TTMLIR_DIALECT_TTIR_TRANSFORMS_PASSES_H
#define TTMLIR_DIALECT_TTIR_TRANSFORMS_PASSES_H

#include "ttmlir/Dialect/TT/Utils/OverrideParams.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

Expand Down
26 changes: 0 additions & 26 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -108,32 +108,6 @@ def TTIRAllocate: Pass<"ttir-allocate", "::mlir::ModuleOp"> {
}];
}

// Declares the `ttir-optimizer` pass over `::mlir::ModuleOp`.
// Each Option<> entry below lists, in order: the C++ member name, the
// command-line flag, the C++ type, the default value, and the help text.
def TTIROptimizer: Pass<"ttir-optimizer", "::mlir::ModuleOp"> {
let summary = "Determine op configurations for maximum performance.";
let description = [{
Go through the ops, set sharding specs for each op based on sharding analysis,
by updating layout attribute of each op.
}];
let options = [
// Output layout overrides; the default is an empty map, i.e. no overrides.
// NOTE(review): presumably keyed by op name/location — confirm against
// the LayoutOverrideParams parser.
Option<"overrideOutputLayout", "override-output-layout",
"llvm::StringMap<LayoutOverrideParams>",
/*default=*/"llvm::StringMap<LayoutOverrideParams>()",
"Override output tensor layout for specific ops.">,
// Sharding is opt-in: the flag defaults to false.
Option<"shardingPassEnabled", "sharding-pass-enabled",
"bool",
/*default=*/"false",
"Enable sharding pass.">,
// Resharding also defaults to false; per its help text it is temporarily
// disabled until all shard spec types are supported.
Option<"reshardingEnabled", "resharding-enabled",
"bool",
/*default=*/"false",
"Resharding pass. Temp disabled till we support all types of shard specs.">,
// Cap on the number of legal layouts used by grid analysis; defaults to 64.
Option<"maxLegalLayouts", "max-legal-layouts",
"int64_t",
/*default=*/"64",
"Override maximum number of legal layouts for grid analysis.">
];
}

def TTIRLoadSystemDesc: Pass<"ttir-load-system-desc", "::mlir::ModuleOp"> {
let summary = "Load system desc.";
let description = [{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_DFSHARDINGPOLICY_H
#define TTMLIR_DIALECT_TTIR_ANALYSIS_DFSHARDINGPOLICY_H
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_DFSHARDINGPOLICY_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_DFSHARDINGPOLICY_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTIR/Analysis/ShardChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/ShardChainConfig.h"

namespace mlir::tt::ttir {
namespace mlir::tt::ttnn {

// Processes ops in DFS-schedulable order and builds shard chain configs.
// The schedule is also produced as a side effect of sharding.
Expand All @@ -17,14 +17,15 @@ class DFShardingPolicy {
private:
Operation *rootOp;
std::vector<ShardChainConfig> *shardChainConfigs;
llvm::DenseMap<Operation *, std::vector<LayoutAttr>> legalLayouts;
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
unsigned usableL1CacheSize = 0;

public:
DFShardingPolicy(
Operation *rootOp, std::vector<ShardChainConfig> &shardChainConfigs,
const llvm::DenseMap<Operation *, std::vector<LayoutAttr>> &legalLayouts,
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
: rootOp(rootOp), shardChainConfigs(&shardChainConfigs),
Expand All @@ -34,6 +35,6 @@ class DFShardingPolicy {
void run();
};

} // namespace mlir::tt::ttir
} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_DFSHARDINGPOLICY_H
#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_DFSHARDINGPOLICY_H
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_EDGE_H
#define TTMLIR_DIALECT_TTIR_ANALYSIS_EDGE_H
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_EDGE_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_EDGE_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"

namespace mlir::tt::ttir {
namespace mlir::tt::ttnn {
struct Edge {
Operation *producerOp = nullptr;
Operation *consumerOp = nullptr;
Expand All @@ -23,11 +23,11 @@ struct Edge {
}
};

} // namespace mlir::tt::ttir
} // namespace mlir::tt::ttnn

namespace std {
template <> struct hash<mlir::tt::ttir::Edge> {
size_t operator()(const mlir::tt::ttir::Edge &edge) const noexcept {
template <> struct hash<mlir::tt::ttnn::Edge> {
size_t operator()(const mlir::tt::ttnn::Edge &edge) const noexcept {
llvm::hash_code code = llvm::hash_value(edge.operandIndex);
code = llvm::hash_combine(code, edge.producerOp);
code = llvm::hash_combine(code, edge.consumerOp);
Expand All @@ -36,4 +36,4 @@ template <> struct hash<mlir::tt::ttir::Edge> {
};
} // namespace std

#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_EDGE_H
#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_EDGE_H
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_LEGALGRIDANALYSIS_H
#define TTMLIR_DIALECT_TTIR_ANALYSIS_LEGALGRIDANALYSIS_H
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TT/Utils/OverrideParams.h"
#include "ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h"
#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"
#include "llvm/ADT/StringMap.h"

namespace mlir::tt::ttir {
namespace mlir::tt::ttnn {

struct LegalGridAnalysisInput {
ChipDescAttr chipDesc;
Expand Down Expand Up @@ -43,15 +43,15 @@ struct LegalGridAnalysisInput {
};

class LegalGridAnalysis
: public TTIRAnalysis<LegalGridAnalysisInput, std::vector<LayoutAttr>> {
: public TTNNAnalysis<LegalGridAnalysisInput, std::vector<tt::LayoutAttr>> {
private:
void analysisImplementation() override;
bool applyOverrides() override;

public:
LegalGridAnalysis(Operation *op) : TTIRAnalysis(op) {}
LegalGridAnalysis(Operation *op) : TTNNAnalysis(op) {}
};

} // namespace mlir::tt::ttir
} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_LEGALGRIDANALYSIS_H
#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_OPCONFIGANALYSIS_H
#define TTMLIR_DIALECT_TTIR_ANALYSIS_OPCONFIGANALYSIS_H
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_OPCONFIGANALYSIS_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_OPCONFIGANALYSIS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/Analysis/TTIRAnalysis.h"
#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"

namespace mlir::tt::ttir {
namespace mlir::tt::ttnn {

struct OpConfigAnalysisInput {
llvm::DenseMap<Operation *, std::vector<LayoutAttr>> legalGrids;
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalGrids;

OpConfigAnalysisInput() : legalGrids() {}

OpConfigAnalysisInput(
const llvm::DenseMap<Operation *, std::vector<LayoutAttr>> &&legalGrids)
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&&legalGrids)
: legalGrids(std::move(legalGrids)) {}

OpConfigAnalysisInput(
const llvm::DenseMap<Operation *, std::vector<LayoutAttr>> &legalGrids)
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalGrids)
: legalGrids(legalGrids) {}

bool operator==(const OpConfigAnalysisInput &rhs) const {
Expand All @@ -35,16 +37,16 @@ struct OpConfigAnalysisInput {
// Determine optimal configuration for each op.
//
class OpConfigAnalysis
: public TTIRAnalysis<OpConfigAnalysisInput,
llvm::DenseMap<Operation *, LayoutAttr>> {
: public TTNNAnalysis<OpConfigAnalysisInput,
llvm::DenseMap<Operation *, tt::LayoutAttr>> {

private:
void analysisImplementation() override;
bool applyOverrides() override;

public:
OpConfigAnalysis(Operation *op) : TTIRAnalysis(op) {}
OpConfigAnalysis(Operation *op) : TTNNAnalysis(op) {}
};
} // namespace mlir::tt::ttir
} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_OPCONFIGANALYSIS_H
#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_OPCONFIGANALYSIS_H
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDCHAINCONFIG_H
#define TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDCHAINCONFIG_H
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDCHAINCONFIG_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDCHAINCONFIG_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/Analysis/ShardSolver.h"
#include "ttmlir/Dialect/TTNN/Analysis/ShardSolver.h"
#include <unordered_set>

namespace mlir::tt::ttir {
namespace mlir::tt::ttnn {

struct ShardSpec {
Operation *op;
uint tensorSplitFactor;
LayoutAttr layout;
tt::LayoutAttr layout;
};

// Enum to track the state of the shard chain.
Expand All @@ -36,12 +36,14 @@ class ShardChainConfig {
public:
ShardChainConfig() : shardSpecs(), state() {}

ShardSolver resolve(
const llvm::DenseMap<Operation *, std::vector<LayoutAttr>> &legalLayouts,
unsigned usableL1CacheSize);
ShardSolver
resolve(const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
unsigned usableL1CacheSize);
void build();
void complete(const llvm::DenseMap<Operation *, LayoutAttr> &selectedOpLayout,
std::unordered_set<Edge> &reshardedEdges);
void
complete(const llvm::DenseMap<Operation *, tt::LayoutAttr> &selectedOpLayout,
std::unordered_set<Edge> &reshardedEdges);

bool isEmpty() { return shardSpecs.empty(); }
void addShardSpec(ShardSpec &&spec) {
Expand All @@ -56,6 +58,6 @@ class ShardChainConfig {
}
};

} // namespace mlir::tt::ttir
} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTIR_ANALYSIS_SHARDCHAINCONFIG_H
#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_SHARDCHAINCONFIG_H
Loading
Loading