#2133: Modified existing golden functionality code to support multi-device tensors for CCL operations #2163

Draft · wants to merge 1 commit into base: main
7 changes: 6 additions & 1 deletion include/ttmlir/Target/Common/debug_info.fbs
@@ -10,9 +10,14 @@ table GoldenTensor {
   data: [uint8];
 }
 
+table GoldenDevice {
+  device: uint32;
+  value: GoldenTensor;
+}
+
 table GoldenKV {
   key: string;
-  value: GoldenTensor;
+  value: [GoldenDevice];
 }

table GoldenInfo {
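This schema change is the core of the PR: a GoldenKV entry no longer maps a location key to a single GoldenTensor but to a vector of GoldenDevice entries, one golden per participating device. A minimal C++ sketch of the equivalent in-memory shape (the aliases are illustrative only, assuming mlir::tt::GoldenTensor is in scope; they are not part of this change):

#include <cstdint>
#include <string>
#include <unordered_map>

// device id -> golden tensor recorded for that device
using DeviceGoldenMap = std::unordered_map<std::uint32_t, mlir::tt::GoldenTensor>;
// op location -> per-device goldens
using GoldenMap = std::unordered_map<std::string, DeviceGoldenMap>;

This is exactly the nested map type the translation entry points below now accept.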
11 changes: 10 additions & 1 deletion include/ttmlir/Target/TTMetal/TTMetalToFlatbuffer.h
@@ -15,7 +15,16 @@ namespace mlir::tt::ttmetal {
// stream.
LogicalResult translateTTMetalToFlatbuffer(
     Operation *op, llvm::raw_ostream &os,
-    std::unordered_map<std::string, GoldenTensor> goldenMap = {});
+    /* The golden map has the following structure:
+       {
+         loc: {
+           device_id: GoldenTensor
+         }
+       }
+    */
+    std::unordered_map<std::string,
+                       std::unordered_map<std::uint32_t, GoldenTensor>>
+        goldenMap = {});
} // namespace mlir::tt::ttmetal

#endif
4 changes: 2 additions & 2 deletions include/ttmlir/Target/TTNN/TTNNToFlatbuffer.h
@@ -14,15 +14,15 @@ namespace mlir::tt::ttnn {
// Convert a TTNNIR operation to a flatbuffer
std::shared_ptr<void> ttnnToFlatbuffer(
Operation *op,
-    const std::unordered_map<std::string, GoldenTensor> &goldenMap = {},
+    const std::unordered_map<std::string, std::unordered_map<std::uint32_t, GoldenTensor>> &goldenMap = {},
const std::vector<std::pair<std::string, std::string>> &moduleCache = {});

// Convert a TTNNIR operation to a flatbuffer
// This function signature is required in order to register the conversion in
// mlir translation framework
LogicalResult translateTTNNToFlatbuffer(
Operation *op, llvm::raw_ostream &os,
-    const std::unordered_map<std::string, GoldenTensor> &goldenMap = {},
+    const std::unordered_map<std::string, std::unordered_map<std::uint32_t, GoldenTensor>> &goldenMap = {},
const std::vector<std::pair<std::string, std::string>> &moduleCache = {});
} // namespace mlir::tt::ttnn

9 changes: 7 additions & 2 deletions lib/Target/TTMetal/TTMetalToFlatbuffer.cpp
@@ -258,7 +258,10 @@ Value getOperandThroughDPSOps(Value value) {
}

static std::shared_ptr<void> translateModuleToFlatbuffer(
-    Operation *op, std::unordered_map<std::string, GoldenTensor> goldenMap) {
+    Operation *op,
+    std::unordered_map<std::string,
+                       std::unordered_map<std::uint32_t, GoldenTensor>>
+        goldenMap) {
::flatbuffers::FlatBufferBuilder fbb;
FlatbufferObjectCache cache(&fbb);

@@ -452,7 +455,9 @@ static std::shared_ptr<void> translateModuleToFlatbuffer(

LogicalResult translateTTMetalToFlatbuffer(
Operation *op, llvm::raw_ostream &os,
-    std::unordered_map<std::string, GoldenTensor> goldenMap) {
+    std::unordered_map<std::string,
+                       std::unordered_map<std::uint32_t, GoldenTensor>>
+        goldenMap) {
std::shared_ptr<void> data = translateModuleToFlatbuffer(op, goldenMap);
std::size_t size = ::flatbuffers::GetSizePrefixedBufferLength(
static_cast<const uint8_t *>(data.get()));
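Both translation entry points now thread the same nested {loc: {device_id: GoldenTensor}} map through to serialization. One asymmetry worth noting in the diff: the TTMetal signatures take the map by value while the TTNN ones take it by const reference; if GoldenTensor owns its data, the by-value copies could get expensive for large golden sets, so const reference may be the safer default here too.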
29 changes: 21 additions & 8 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -1444,7 +1444,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,

std::shared_ptr<void> ttnnToFlatbuffer(
Operation *op,
-    const std::unordered_map<std::string, GoldenTensor> &goldenMap,
+    const std::unordered_map<std::string,
+                             std::unordered_map<std::uint32_t, GoldenTensor>>
+        &goldenMap,
const std::vector<std::pair<std::string, std::string>> &moduleCache) {
ModuleOp module = dyn_cast<ModuleOp>(op);
assert(module && "Expected ModuleOp as top level operation");
@@ -1469,13 +1471,24 @@ std::shared_ptr<void> ttnnToFlatbuffer(
std::vector<::flatbuffers::Offset<::tt::target::GoldenKV>> goldenKVList;
goldenKVList.reserve(goldenMap.size());

-  for (auto element : goldenMap) {
-    std::vector<std::uint8_t> dataTensor = element.second.convertDataToVector();
-    auto goldenTensor = ::tt::target::CreateGoldenTensorDirect(
-        fbb, element.second.name.c_str(), &element.second.shape,
-        &element.second.strides, element.second.dtype, &dataTensor);
+  for (auto locMap : goldenMap) {
+    std::vector<::flatbuffers::Offset<::tt::target::GoldenDevice>>
+        goldenDeviceList;
+    goldenDeviceList.reserve(locMap.second.size());
+
+    for (auto tensorMap : locMap.second) {
+      std::vector<std::uint8_t> dataTensor =
+          tensorMap.second.convertDataToVector();
+      auto goldenTensor = ::tt::target::CreateGoldenTensorDirect(
+          fbb, tensorMap.second.name.c_str(), &tensorMap.second.shape,
+          &tensorMap.second.strides, tensorMap.second.dtype, &dataTensor);
+      auto goldenDevice =
+          ::tt::target::CreateGoldenDevice(fbb, tensorMap.first, goldenTensor);
+      goldenDeviceList.push_back(goldenDevice);
+    }
 
     auto goldenKV = ::tt::target::CreateGoldenKVDirect(
-        fbb, element.first.c_str(), goldenTensor);
+        fbb, locMap.first.c_str(), &goldenDeviceList);
     goldenKVList.push_back(goldenKV);
   }

@@ -1522,7 +1535,7 @@ std::shared_ptr<void> ttnnToFlatbuffer(

LogicalResult translateTTNNToFlatbuffer(
Operation *op, llvm::raw_ostream &os,
-    const std::unordered_map<std::string, GoldenTensor> &goldenMap,
+    const std::unordered_map<std::string, std::unordered_map<std::uint32_t, GoldenTensor>> &goldenMap,
const std::vector<std::pair<std::string, std::string>> &moduleCache) {
std::shared_ptr<void> data = ttnnToFlatbuffer(op, goldenMap, moduleCache);
std::size_t size = ::flatbuffers::GetSizePrefixedBufferLength(
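The rewritten loop serializes two levels instead of one: for each location it builds a GoldenDevice per (device id, tensor) pair, then wraps the resulting list in a single GoldenKV. A hedged sketch of how a caller might assemble the nested map before invoking the translation; makeGolden is a hypothetical helper standing in for real GoldenTensor construction, and moduleOp is assumed to be in scope:

#include <memory>
#include <unordered_map>

std::unordered_map<std::string,
                   std::unordered_map<std::uint32_t, mlir::tt::GoldenTensor>>
    goldenMap;

// One op location carrying goldens for two devices taking part in a CCL op.
goldenMap["all_gather_loc"][0] = makeGolden(/*deviceId=*/0); // hypothetical helper
goldenMap["all_gather_loc"][1] = makeGolden(/*deviceId=*/1); // hypothetical helper

std::shared_ptr<void> binary =
    mlir::tt::ttnn::ttnnToFlatbuffer(moduleOp, goldenMap);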
8 changes: 5 additions & 3 deletions python/Passes.cpp
@@ -175,8 +175,7 @@ void populatePassesModule(py::module &m) {
m.def(
"ttnn_to_flatbuffer_file",
[](MlirModule module, std::string &filepath,
-       const std::unordered_map<std::string, mlir::tt::GoldenTensor>
-           &goldenMap = {},
+       const std::unordered_map<std::string, std::unordered_map<std::uint32_t, mlir::tt::GoldenTensor>> &goldenMap = {},
const std::vector<std::pair<std::string, std::string>> &moduleCache =
{}) {
mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module));
@@ -201,7 +200,10 @@

m.def("ttmetal_to_flatbuffer_file",
[](MlirModule module, std::string &filepath,
-        std::unordered_map<std::string, mlir::tt::GoldenTensor> goldenMap) {
+        std::unordered_map<
+            std::string,
+            std::unordered_map<std::uint32_t, mlir::tt::GoldenTensor>>
+            &goldenMap) {
mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module));
std::error_code fileError;
llvm::raw_fd_ostream file(filepath, fileError);
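On the Python side these bindings now expect a nested mapping rather than a flat one. Assuming pybind11's STL casters are in play here (the existing flat-map binding already required them), a Python dict of the shape {loc: {device_id: GoldenTensor}} converts to the nested unordered_map automatically, which is exactly what get_golden_map below now produces.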
8 changes: 7 additions & 1 deletion python/test_infra/ttir_builder.py
@@ -171,14 +171,20 @@ def generate_input_golden(
     def get_golden_map(self) -> Dict:
         golden_info = {}
         for name, golden_tensor in self.id_golden_map.items():
+            golden_device_info = {}
             golden_tensor = golden_tensor.contiguous()
-            golden_info[name] = create_golden_tensor(
+
+            # for now, assume all golden tensors live on device 0 (todo: tapspatel - extend to multichip)
+            golden_device_info[0] = create_golden_tensor(
                 name,
                 list(golden_tensor.tensor.shape),
                 list(golden_tensor.tensor.stride()),
                 DataType.Float32,
                 golden_tensor.tensor.data_ptr(),
             )
+
+            golden_info[name] = golden_device_info
+
         return golden_info

# ----- Private helpers -----
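Note that the builder currently pins every golden to device 0, with the inline TODO tracking the extension to multichip, so the nested structure is in place but single-device in practice. The effective return type of get_golden_map changes from Dict[str, GoldenTensor] to Dict[str, Dict[int, GoldenTensor]]; any caller that indexed the old flat map needs a matching update.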
8 changes: 5 additions & 3 deletions runtime/include/tt/runtime/detail/ttmetal.h
@@ -65,10 +65,12 @@ std::string getOpDebugString(OpContext opContextHandle);

std::string getOpLocInfo(OpContext opContextHandle);

-Tensor getOpOutputTensor(OpContext opContextHandle,
-                         CallbackContext programContextHandle);
+std::unordered_map<std::uint32_t, Tensor>
+getOpOutputTensor(OpContext opContextHandle,
+                  CallbackContext programContextHandle);
 
-std::vector<float> getTensorData(Tensor tensor);
+std::unordered_map<std::uint32_t, std::vector<float>>
+getTensorData(std::unordered_map<std::uint32_t, Tensor> tensor_map);

using InputBuffer =
std::tuple<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>,
8 changes: 5 additions & 3 deletions runtime/include/tt/runtime/detail/ttnn.h
@@ -130,10 +130,12 @@ std::string getOpDebugString(OpContext opContextHandle);

std::string getOpLocInfo(OpContext opContextHandle);

-Tensor getOpOutputTensor(OpContext opContextHandle,
-                         CallbackContext programContextHandle);
+std::unordered_map<std::uint32_t, Tensor>
+getOpOutputTensor(OpContext opContextHandle,
+                  CallbackContext programContextHandle);
 
-std::vector<float> getTensorData(Tensor tensor);
+std::unordered_map<std::uint32_t, std::vector<float>>
+getTensorData(std::unordered_map<std::uint32_t, Tensor> tensor_map);

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
8 changes: 5 additions & 3 deletions runtime/include/tt/runtime/runtime.h
@@ -117,10 +117,12 @@ std::string getOpDebugString(OpContext opContextHandle);

std::string getOpLocInfo(OpContext opContextHandle);

-Tensor getOpOutputTensor(OpContext opContextHandle,
-                         CallbackContext programContextHandle);
+std::unordered_map<std::uint32_t, Tensor>
+getOpOutputTensor(OpContext opContextHandle,
+                  CallbackContext programContextHandle);
 
-std::vector<float> getTensorData(Tensor tensor);
+std::unordered_map<std::uint32_t, std::vector<float>>
+getTensorData(std::unordered_map<std::uint32_t, Tensor> tensor_map);

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
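With getOpOutputTensor and getTensorData now keyed by device id, a post-op debug hook can compare each device's output against its recorded golden. A minimal sketch of such a hook, assuming the namespaces follow the headers above and leaving the actual comparison metric (e.g. PCC) out:

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include "tt/runtime/runtime.h"

void compareWithGoldens(tt::runtime::Binary &binary,
                        tt::runtime::OpContext opCtx,
                        tt::runtime::CallbackContext programCtx) {
  std::string loc = tt::runtime::getOpLocInfo(opCtx);
  // device id -> recorded golden; empty when nothing was captured for loc
  auto goldens = binary.getDebugInfoGolden(loc);
  if (goldens.empty()) {
    return;
  }
  // device id -> live output tensor, then device id -> host float data
  auto outputs = tt::runtime::getOpOutputTensor(opCtx, programCtx);
  auto outputData = tt::runtime::getTensorData(outputs);
  for (const auto &[deviceId, data] : outputData) {
    auto it = goldens.find(deviceId);
    if (it == goldens.end()) {
      continue; // no golden recorded for this device
    }
    // compare data against it->second here (metric intentionally omitted)
  }
}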
5 changes: 4 additions & 1 deletion runtime/include/tt/runtime/types.h
@@ -64,6 +64,7 @@ struct RuntimeCheckedObjectImpl {
std::shared_ptr<void> handle;
::tt::runtime::DeviceRuntime associatedRuntime;

+  RuntimeCheckedObjectImpl() = default;
RuntimeCheckedObjectImpl(std::shared_ptr<void> handle,
::tt::runtime::DeviceRuntime runtime)
: handle(handle), associatedRuntime(runtime) {}
@@ -128,7 +129,8 @@ struct Binary : public Flatbuffer {

std::vector<TensorDesc> getProgramInputs(std::uint32_t programIndex) const;
std::vector<TensorDesc> getProgramOutputs(std::uint32_t programIndex) const;
-  const ::tt::target::GoldenTensor *getDebugInfoGolden(std::string &loc) const;
+  std::unordered_map<std::uint32_t, const ::tt::target::GoldenTensor *>
+  getDebugInfoGolden(std::string &loc) const;
};

struct Device : public detail::RuntimeCheckedObjectImpl {
@@ -142,6 +144,7 @@ struct Event : public detail::RuntimeCheckedObjectImpl {
struct Tensor : public detail::RuntimeCheckedObjectImpl {
std::shared_ptr<void> data;
Event event;
+  Tensor() = default;
Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
DeviceRuntime runtime)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
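The two defaulted constructors are not incidental: std::unordered_map::operator[] value-initializes the mapped type before assigning to it, so Tensor (and its RuntimeCheckedObjectImpl base) must be default-constructible once tensors are stored as map values, as the new getOpOutputTensor return type requires. A one-line illustration, with existing standing in for a real tensor:

std::unordered_map<std::uint32_t, tt::runtime::Tensor> perDevice;
perDevice[0] = existing; // operator[] default-constructs a Tensor first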
24 changes: 14 additions & 10 deletions runtime/lib/binary.cpp
@@ -111,21 +111,25 @@ std::vector<TensorDesc> getProgramOutputs(Flatbuffer binary,
return outputs;
}

-const ::tt::target::GoldenTensor *getDebugInfoGolden(Flatbuffer binary,
-                                                     std::string &loc) {
+std::unordered_map<std::uint32_t, const ::tt::target::GoldenTensor *>
+getDebugInfoGolden(Flatbuffer binary, std::string &loc) {
+  std::unordered_map<std::uint32_t, const ::tt::target::GoldenTensor *>
+      goldenTensorDeviceMap;
+
   auto const *programs = getBinary(binary)->programs();
   for (auto const *program : *programs) {
     for (const ::tt::target::GoldenKV *goldenKV :
          *program->debug_info()->golden_info()->golden_map()) {
       if (std::string(goldenKV->key()->c_str()) == loc) {
-        return goldenKV->value();
+        for (const ::tt::target::GoldenDevice *goldenDevice :
+             *goldenKV->value()) {
+          goldenTensorDeviceMap[goldenDevice->device()] = goldenDevice->value();
+        }
       }
     }
   }
 
-  LOG_WARNING("Golden information not found");
-  return nullptr;
+  // Warn only when no golden entry matched the requested location.
+  if (goldenTensorDeviceMap.empty()) {
+    LOG_WARNING("Golden information not found");
+  }
+  return goldenTensorDeviceMap;
}

} // namespace ttnn
@@ -198,10 +202,10 @@ std::vector<TensorDesc> getProgramOutputs(Flatbuffer binary,
return outputs;
}

-const ::tt::target::GoldenTensor *getDebugInfoGolden(Flatbuffer binary,
-                                                     std::string &loc) {
+std::unordered_map<std::uint32_t, const ::tt::target::GoldenTensor *>
+getDebugInfoGolden(Flatbuffer binary, std::string &loc) {
   LOG_WARNING("Debug golden information not enabled for metal yet!");
-  return nullptr;
+  return {};
}

} // namespace metal
@@ -368,7 +372,7 @@ Binary::getProgramOutputs(std::uint32_t programIndex) const {
LOG_FATAL("Unsupported binary format");
}

-const ::tt::target::GoldenTensor *
+std::unordered_map<std::uint32_t, const ::tt::target::GoldenTensor *>
std::unordered_map<std::uint32_t, const ::tt::target::GoldenTensor *>
Binary::getDebugInfoGolden(std::string &loc) const {
if (::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier(
handle.get())) {
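The failure signal changes here as well: a missing golden used to come back as a nullptr and is now an empty map, so callers that tested the pointer should test for emptiness instead (as the hook sketched after the runtime headers above does). The TTNN path also aggregates every GoldenDevice entry under the matching location key rather than returning the first hit.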
12 changes: 7 additions & 5 deletions runtime/lib/runtime.cpp
@@ -437,8 +437,9 @@ std::string getOpLocInfo(OpContext opContextHandle) {
throw std::runtime_error("runtime is not enabled");
}

-Tensor getOpOutputTensor(OpContext opContextHandle,
-                         CallbackContext programContextHandle) {
+std::unordered_map<std::uint32_t, Tensor>
+getOpOutputTensor(OpContext opContextHandle,
+                  CallbackContext programContextHandle) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getOpOutputTensor(opContextHandle,
@@ -455,16 +456,17 @@ Tensor getOpOutputTensor(OpContext opContextHandle,
LOG_FATAL("runtime is not enabled");
}

-std::vector<float> getTensorData(Tensor tensor) {
+std::unordered_map<std::uint32_t, std::vector<float>>
+getTensorData(std::unordered_map<std::uint32_t, Tensor> tensor_map) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
-    return ::tt::runtime::ttnn::getTensorData(tensor);
+    return ::tt::runtime::ttnn::getTensorData(tensor_map);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
-    return ::tt::runtime::ttmetal::getTensorData(tensor);
+    return ::tt::runtime::ttmetal::getTensorData(tensor_map);
}
#endif

14 changes: 6 additions & 8 deletions runtime/lib/ttmetal/runtime.cpp
@@ -30,10 +30,6 @@ static ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) {
return ::tt::target::metal::GetSizePrefixedTTMetalBinary(binary.handle.get());
}

-static Tensor createNullTensor() {
-  return Tensor(nullptr, nullptr, DeviceRuntime::TTMetal);
-}
-
static tt::runtime::MemoryView
createMemoryView(tt::tt_metal::detail::MemoryView const &memoryView) {
return tt::runtime::MemoryView{
@@ -348,14 +344,16 @@ std::string getOpLocInfo(OpContext opContextHandle) {
return "";
}

-Tensor getOpOutputTensor(OpContext opContextHandle,
-                         CallbackContext programContextHandle) {
+std::unordered_map<std::uint32_t, Tensor>
+getOpOutputTensor(OpContext opContextHandle,
+                  CallbackContext programContextHandle) {
   // Not implemented
   LOG_WARNING("obtaining op output tensor for metal runtime not implemented");
-  return createNullTensor();
+  return {};
 }
 
-std::vector<float> getTensorData(Tensor tensor) {
+std::unordered_map<std::uint32_t, std::vector<float>>
+getTensorData(std::unordered_map<std::uint32_t, Tensor> tensor_map) {
// Not implemented
LOG_WARNING("obtaining tensor data for metal runtime not implemented");
return {};
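For the metal runtime the new signatures are stubbed out: getOpOutputTensor and getTensorData log a warning and return empty maps. That is also what lets the PR delete the createNullTensor helper, since an empty map rather than a sentinel null tensor now represents the absence of output.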