Fix tests
odjuricicTT committed Feb 21, 2025
1 parent 83c8e8e commit 95e1140
Showing 3 changed files with 44 additions and 30 deletions.
51 changes: 22 additions & 29 deletions lib/Dialect/TTNN/Analysis/ShardSolver.cpp
@@ -130,24 +130,39 @@ bool ShardSolver::resolveStep() {
// << "\n Consumer layout " <<
// consumerLayouts[consumerId]
// << "\n\n";
if (reshardOnEdge) {
// TODO(odjuricic): This should read from results of previous
// resolve instead of accepting all.
//
assert(producerId <=
std::numeric_limits<decltype(Path::producerId)>::max());
assert(consumerId <=
std::numeric_limits<decltype(Path::consumerId)>::max());
paths.push_back(Path(producerId, consumerId));
edgeProducerBitset.set(producerId);
edgeConsumerBitset.set(consumerId);
continue;
}

llvm::Expected<bool> shardCompatible = checkShardCompatible(
producerOp->getResult(0), producerLayouts[producerId], consumerOp,
consumerLayouts[consumerId]);

if (!shardCompatible) {
std::string error = llvm::toString(shardCompatible.takeError());
if (!errorCount.count(error)) {
errorCount.insert({error, 0});
}
errorCount[error]++;
} else if (reshardOnEdge) {
if (shardCompatible && shardCompatible.get()) {
assert(producerId <=
std::numeric_limits<decltype(Path::producerId)>::max());
assert(consumerId <=
std::numeric_limits<decltype(Path::consumerId)>::max());
paths.push_back(Path(producerId, consumerId));
edgeProducerBitset.set(producerId);
edgeConsumerBitset.set(consumerId);

} else {
std::string error = llvm::toString(shardCompatible.takeError());
if (!errorCount.count(error)) {
errorCount.insert({error, 0});
}
errorCount[error]++;
}
}
}
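
Side note on the new control flow: `checkShardCompatible` returns an `llvm::Expected<bool>`, so there are three outcomes to route, not two. Below is a minimal, self-contained sketch of the same pattern; `queryConstraint` and its error message are hypothetical stand-ins for the real backend call. One quirk worth knowing: when the query succeeds but reports `false`, `takeError()` hands back a success `Error`, so the counted message is empty.

```cpp
#include "llvm/Support/Error.h"
#include <map>
#include <string>

// Hypothetical stand-in for checkShardCompatible: it can fail with an
// Error, or succeed carrying true/false compatibility.
static llvm::Expected<bool> queryConstraint(int mode) {
  if (mode == 0) {
    return llvm::createStringError(llvm::inconvertibleErrorCode(),
                                   "op model rejected layout pair");
  }
  return mode > 1; // Success: compatible iff mode > 1.
}

int main() {
  std::map<std::string, int> errorCount;
  for (int mode : {0, 1, 2}) {
    llvm::Expected<bool> compatible = queryConstraint(mode);
    if (compatible && compatible.get()) {
      // Accept the (producer, consumer) layout pair, as resolveStep() does.
    } else {
      // toString() consumes the Error; for a successful-but-false result
      // it yields an empty message.
      ++errorCount[llvm::toString(compatible.takeError())];
    }
  }
  return 0;
}
```

If `errorCount` is a standard map type, `operator[]` alone value-initializes the counter, so the explicit `count()`/`insert()` guard in the patch is redundant but harmless.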
@@ -157,10 +172,6 @@ bool ShardSolver::resolveStep() {

// No valid paths found for this edge, mark it for resharding.
//
if (reshardOnEdge) {
return false;
}

if (!insertReshard(edge)) {
return false;
}
@@ -236,14 +247,6 @@ bool ShardSolver::preprocessFirstOp() {
// Add constraint check
Operation *firstOp = shardSpecs->front().op;

// if (llvm::isa<tt::ttnn::MatmulOp>(firstOp)) {
// auto loc = llvm::dyn_cast<NameLoc>(firstOp->getLoc());
// if (loc.getName() == "index_262.dc.matmul.3") {
// llvm::errs() << loc;
// }
// llvm::errs() << "MatmulOp\n";
// }

if (memReconfigEdges.count(
Edge(firstOp->getOperand(0).getDefiningOp(), firstOp, 0)) > 0) {
return true;
@@ -296,17 +299,7 @@ bool ShardSolver::preprocessFirstOp() {
}

if (!hasValidLayout) {
// Print all consumer and producer layouts:
//
// firstOp->emitError() << "No valid output layout found for DRAM input!";

// llvm::errs() << "First op layouts: " << firstOpLayouts.size() << "\n";
// for (auto layout : firstOpLayouts) {
// llvm::errs() << "\t" << layout << "\n";
// }

// Insert reshard edge for the first op to start the chain.
//
Edge shardChainInputEdge = Edge(firstOp->getOperand(0).getDefiningOp(),
firstOp, 0 /*operandIndex*/);

4 changes: 4 additions & 0 deletions test/lit.cfg.py
@@ -46,6 +46,10 @@
# system_desc_path: The system desc that is to be used to generate the binary files.
config.system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")

# Run optimizer tests serially: lit executes at most one test from the
# "optimizer" parallelism group at a time. This needs to be done for the
# optimizer subdirectories only.
lit_config.parallelism_groups["optimizer"] = 1
config.parallelism_group = "optimizer"

# set features based on system
system_desc = None
if config.system_desc_path:
19 changes: 18 additions & 1 deletion test/unittests/Optimizer/TestShardSolver.cpp
@@ -17,6 +17,7 @@
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

using namespace mlir::tt::ttnn;

@@ -42,7 +43,13 @@ class ShardSolverBase : public ::testing::Test {
}

mlir::RankedTensorType getTensorRankedType() {
return mlir::RankedTensorType::get(getTensorShape(), builder.getF32Type());
return mlir::RankedTensorType::get(
getTensorShape(), builder.getF32Type(),
TTNNLayoutAttr::get(&context, getTensorShape(), builder.getF32Type(),
BufferType::DRAM,
mlir::tt::GridAttr::get(&context, {1, 1}),
mlir::tt::ttnn::TensorMemoryLayoutAttr::get(
&context, TensorMemoryLayout::Interleaved)));
}

mlir::Value createEmptyTensor() {
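
The fixture's tensor type now carries an explicit TTNNLayoutAttr encoding (DRAM buffer, interleaved memory layout, 1x1 grid) rather than a bare f32 tensor type. A sketch of the consuming side, assuming layouts are read back through the standard `getEncoding()` accessor; `startsInterleavedInDRAM` is a hypothetical helper, not part of the test:

```cpp
// Hypothetical helper: recover the TTNN layout from a tensor type's
// encoding attribute; returns false when no TTNN encoding is attached.
static bool startsInterleavedInDRAM(mlir::RankedTensorType type) {
  auto layout = mlir::dyn_cast_or_null<TTNNLayoutAttr>(type.getEncoding());
  return layout && layout.hasInterleavedDRAMTensorMemoryLayout();
}
```

Tensors built by `getTensorRankedType()` above should satisfy this check, which is what lets the relaxed mock at the bottom of this file accept the chain's DRAM inputs.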
@@ -63,6 +70,10 @@ class ShardSolverBase : public ::testing::Test {
mlir::TypeRange(input), mlir::TypeRange(output));
func = builder.create<mlir::func::FuncOp>(builder.getUnknownLoc(), "test",
funcType);
func->setAttr(
mlir::tt::DeviceAttr::name,
mlir::tt::DeviceAttr::get(
&context, mlir::tt::SystemDescAttr::getDefault(&context)));

mlir::Block *block = func.addEntryBlock();
block->addArgument(getTensorRankedType(), builder.getUnknownLoc());
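
The solver also expects a device to be attached to the enclosing function, hence the new DeviceAttr. Retrieval is a one-liner; a sketch, assuming consumers look the attribute up by its registered name:

```cpp
// Sketch: read the device attribute back off the func op.
// Operation::getAttrOfType is standard MLIR; names mirror the fixture above.
static mlir::tt::DeviceAttr lookupDevice(mlir::func::FuncOp func) {
  return func->getAttrOfType<mlir::tt::DeviceAttr>(mlir::tt::DeviceAttr::name);
}
```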
@@ -225,6 +236,12 @@ TEST_F(ShardSolverBase, VerifyProduceMaxCoreUsage) {
TTNNLayoutAttr const &producerLayout,
mlir::Operation *consumerOp,
TTNNLayoutAttr const &consumerLayout) {
// Interleaved to sharded is always supported.
//
if (producerLayout.hasInterleavedDRAMTensorMemoryLayout()) {
return true;
}

// Simple shard compat assumption. Try to keep same shard layout.
//
if (producerLayout.getMemLayout() != consumerLayout.getMemLayout()) {
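Putting the last hunk in context: the test's compatibility oracle now short-circuits for DRAM-interleaved producers before applying its same-layout rule. A condensed sketch of the logic after this change; the hunk is truncated above, so everything past the `getMemLayout()` comparison (including the exact error message) is an assumption:

```cpp
// Condensed sketch of the test's mock compatibility check. The tail of
// the real mock is not shown in the diff; the error path here is assumed.
static llvm::Expected<bool>
mockShardCompatible(TTNNLayoutAttr producerLayout,
                    TTNNLayoutAttr consumerLayout) {
  // Interleaved-to-sharded is always supported, so accept DRAM
  // interleaved producers unconditionally.
  if (producerLayout.hasInterleavedDRAMTensorMemoryLayout()) {
    return true;
  }
  // Simple shard compat assumption: keep the same shard layout.
  if (producerLayout.getMemLayout() != consumerLayout.getMemLayout()) {
    return llvm::createStringError(llvm::inconvertibleErrorCode(),
                                   "shard layouts differ");
  }
  return true;
}
```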
