Skip to content

Commit

Permalink
#1683: Updated memory collection and report of dram and l1 using new …
Browse files Browse the repository at this point in the history
…metal APIs
  • Loading branch information
tapspatel committed Jan 16, 2025
1 parent 6a76e4a commit d8bbeb0
Show file tree
Hide file tree
Showing 12 changed files with 262 additions and 219 deletions.
7 changes: 6 additions & 1 deletion docs/src/ttrt.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ ttrt run out.ttnn --save-artifacts --artifact-dir /path/to/some/dir
ttrt run out.ttnn --load-kernels-from-disk
ttrt run out.ttnn --enable-async-ttnn
ttrt run out.ttnn --result-file result.json
ttrt run out.ttnn --golden
ttrt run out.ttnn --disable-golden
ttrt run out.ttnn --save-golden-tensors
ttrt run out.ttnn --debugger
ttrt run out.ttnn --memory --save-artifacts
ttrt run out.ttnn --memory --check-memory-leak
```

### query
Expand Down Expand Up @@ -219,6 +223,7 @@ ttrt perf /dir/of/flatbuffers --loops 10 --host-only
ttrt perf /dir/of/flatbuffers --log-file ttrt.log --host-only
ttrt perf --save-artifacts --artifact-dir /path/to/some/dir
ttrt perf out.ttnn --result-file result.json
ttrt run out.ttnn --memory
```

To use the Tracy GUI, run the following instructions on your macbook. You can upload your .tracy file into the GUI to view the profiled dumps.
Expand Down
3 changes: 3 additions & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ void deallocateBuffers(Device device);

void dumpMemoryReport(Device device);

std::unordered_map<tt::runtime::MemoryBufferType, tt::runtime::MemoryView>
getMemoryView(Device device, int deviceID = 0);

void wait(Event event);

void wait(Tensor tensor);
Expand Down
3 changes: 3 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ void deallocateBuffers(Device device);

void dumpMemoryReport(Device device);

std::unordered_map<tt::runtime::MemoryBufferType, tt::runtime::MemoryView>
getMemoryView(Device device, int deviceID = 0);

void wait(Event event);

void wait(Tensor tensor);
Expand Down
13 changes: 13 additions & 0 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc();
namespace detail {
void deallocateBuffers(Device device);
void dumpMemoryReport(Device device);

/*
This function get the memory view per device
{
"DRAM": MemoryView,
"L1": MemoryView,
"L1Small": MemoryView,
"Trace": MemoryView
}
*/
std::unordered_map<tt::runtime::MemoryBufferType, tt::runtime::MemoryView>
getMemoryView(Device device, int deviceID = 0);

} // namespace detail

DeviceRuntime getCurrentRuntime();
Expand Down
23 changes: 23 additions & 0 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,20 @@
#pragma clang diagnostic pop

namespace tt::runtime {
/*
MemoryBlockTable is a list of memory blocks in the following format:
[{"blockID": "0", "address": "0", "size": "0", "prevID": "0", "nextID": "0",
"allocated": true}] address: bytes size: bytes
*/
using MemoryBlockTable =
std::vector<std::unordered_map<std::string, std::string>>;

enum class MemoryBufferType {
DRAM,
L1,
L1_SMALL,
TRACE,
};

enum class DeviceRuntime {
Disabled,
Expand Down Expand Up @@ -146,6 +160,15 @@ struct OpContext : public detail::RuntimeCheckedObjectImpl {
using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;
};

struct MemoryView {
std::uint64_t numBanks = 0;
size_t totalBytesPerBank = 0;
size_t totalBytesAllocatedPerBank = 0;
size_t totalBytesFreePerBank = 0;
size_t largestContiguousBytesFreePerBank = 0;
MemoryBlockTable blockTable;
};

} // namespace tt::runtime

#endif
17 changes: 17 additions & 0 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@ void dumpMemoryReport(Device device) {

LOG_FATAL("runtime is not enabled");
}

std::unordered_map<tt::runtime::MemoryBufferType, tt::runtime::MemoryView>
getMemoryView(Device device, int deviceID) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getMemoryView(device, deviceID);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getMemoryView(device, deviceID);
}
#endif

LOG_FATAL("runtime is not enabled");
}
} // namespace detail

DeviceRuntime getCurrentRuntime() {
Expand Down
43 changes: 43 additions & 0 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@ static Tensor createNullTensor() {
return Tensor(nullptr, nullptr, DeviceRuntime::TTMetal);
}

static tt::runtime::MemoryView
createMemoryView(tt::tt_metal::detail::MemoryView const &memoryView) {
return tt::runtime::MemoryView{
.numBanks = memoryView.num_banks,
.totalBytesPerBank = memoryView.total_bytes_per_bank,
.totalBytesAllocatedPerBank = memoryView.total_bytes_allocated_per_bank,
.totalBytesFreePerBank = memoryView.total_bytes_free_per_bank,
.largestContiguousBytesFreePerBank =
memoryView.largest_contiguous_bytes_free_per_bank,
.blockTable = memoryView.block_table,
};
}

Tensor createTensor(std::shared_ptr<void> data,
std::vector<std::uint32_t> const &shape,
std::vector<std::uint32_t> const &stride,
Expand Down Expand Up @@ -116,6 +129,36 @@ void dumpMemoryReport(Device deviceHandle) {
}
}

std::unordered_map<tt::runtime::MemoryBufferType, tt::runtime::MemoryView>
getMemoryView(Device deviceHandle, int deviceID) {
std::unordered_map<tt::runtime::MemoryBufferType, tt::runtime::MemoryView>
memoryMap;
::tt::tt_metal::distributed::MeshDevice &meshDevice =
deviceHandle.as<::tt::tt_metal::distributed::MeshDevice>(
DeviceRuntime::TTMetal);

auto device = meshDevice.get_device(deviceID);

auto dramMemoryView = ::tt::tt_metal::detail::GetMemoryView(
device, tt::tt_metal::BufferType::DRAM);
auto l1MemoryView = ::tt::tt_metal::detail::GetMemoryView(
device, tt::tt_metal::BufferType::L1);
auto l1SmallMemoryView = ::tt::tt_metal::detail::GetMemoryView(
device, tt::tt_metal::BufferType::L1_SMALL);
auto traceMemoryView = ::tt::tt_metal::detail::GetMemoryView(
device, tt::tt_metal::BufferType::TRACE);

memoryMap[tt::runtime::MemoryBufferType::DRAM] =
createMemoryView(dramMemoryView);
memoryMap[tt::runtime::MemoryBufferType::L1] = createMemoryView(l1MemoryView);
memoryMap[tt::runtime::MemoryBufferType::L1_SMALL] =
createMemoryView(l1SmallMemoryView);
memoryMap[tt::runtime::MemoryBufferType::TRACE] =
createMemoryView(traceMemoryView);

return memoryMap;
}

void wait(Event event) {
Events events = event.as<Events>(DeviceRuntime::TTMetal);
for (auto e : events) {
Expand Down
42 changes: 42 additions & 0 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@ static DeviceVariant getTargetDevice(::ttnn::MeshDevice &meshDevice) {
return std::ref(meshDevice);
}

static tt::runtime::MemoryView
createMemoryView(tt::tt_metal::detail::MemoryView const &memoryView) {
return tt::runtime::MemoryView{
.numBanks = memoryView.num_banks,
.totalBytesPerBank = memoryView.total_bytes_per_bank,
.totalBytesAllocatedPerBank = memoryView.total_bytes_allocated_per_bank,
.totalBytesFreePerBank = memoryView.total_bytes_free_per_bank,
.largestContiguousBytesFreePerBank =
memoryView.largest_contiguous_bytes_free_per_bank,
.blockTable = memoryView.block_table,
};
}

static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) {
bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier(
binary.handle.get());
Expand Down Expand Up @@ -225,6 +238,35 @@ void dumpMemoryReport(Device deviceHandle) {
}
}

std::unordered_map<tt::runtime::MemoryBufferType, tt::runtime::MemoryView>
getMemoryView(Device deviceHandle, int deviceID) {
std::unordered_map<tt::runtime::MemoryBufferType, tt::runtime::MemoryView>
memoryMap;
::ttnn::MeshDevice &meshDevice =
deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN);

auto device = meshDevice.get_device(deviceID);

auto dramMemoryView = ::tt::tt_metal::detail::GetMemoryView(
device, tt::tt_metal::BufferType::DRAM);
auto l1MemoryView = ::tt::tt_metal::detail::GetMemoryView(
device, tt::tt_metal::BufferType::L1);
auto l1SmallMemoryView = ::tt::tt_metal::detail::GetMemoryView(
device, tt::tt_metal::BufferType::L1_SMALL);
auto traceMemoryView = ::tt::tt_metal::detail::GetMemoryView(
device, tt::tt_metal::BufferType::TRACE);

memoryMap[tt::runtime::MemoryBufferType::DRAM] =
createMemoryView(dramMemoryView);
memoryMap[tt::runtime::MemoryBufferType::L1] = createMemoryView(l1MemoryView);
memoryMap[tt::runtime::MemoryBufferType::L1_SMALL] =
createMemoryView(l1SmallMemoryView);
memoryMap[tt::runtime::MemoryBufferType::TRACE] =
createMemoryView(traceMemoryView);

return memoryMap;
}

void wait(Event event) {
// Nothing to do for ttnn runtime
LOG_ASSERT(event.matchesRuntime(DeviceRuntime::TTNN));
Expand Down
Loading

0 comments on commit d8bbeb0

Please sign in to comment.