Skip to content

Commit

Permalink
Add l1 small size as an optional parameter to open device (#1781)
Browse files Browse the repository at this point in the history
L1 small size is a device property that must be set when the device is
opened. This change adds an optional parameter so that the user can set it
as needed.

A future improvement would be to expose an API that lets the user calculate
the L1 small size directly, or to add a field to the device attribute in
the MLIR that the compiler calculates and populates.
  • Loading branch information
jnie-TT authored Jan 15, 2025
1 parent d129a9e commit f902beb
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 10 deletions.
3 changes: 2 additions & 1 deletion runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ tt::target::DataType getTensorDataType(Tensor tensor);

size_t getNumAvailableDevices();

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1);
Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1,
std::optional<size_t> l1SmallSize = std::nullopt);

void closeDevice(Device device);

Expand Down
3 changes: 2 additions & 1 deletion runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ tt::target::DataType getTensorDataType(Tensor tensor);

size_t getNumAvailableDevices();

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1);
Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1,
std::optional<size_t> l1SmallSize = std::nullopt);

void closeDevice(Device device);

Expand Down
3 changes: 2 additions & 1 deletion runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ tt::target::DataType getTensorDataType(Tensor tensor);

size_t getNumAvailableDevices();

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1);
Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1,
std::optional<size_t> l1SmallSize = std::nullopt);

void closeDevice(Device device);

Expand Down
7 changes: 4 additions & 3 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,17 @@ size_t getNumAvailableDevices() {
LOG_FATAL("runtime is not enabled");
}

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) {
Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs,
std::optional<size_t> l1SmallSize) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs);
return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs, l1SmallSize);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs);
return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs, l1SmallSize);
}
#endif
LOG_FATAL("runtime is not enabled");
Expand Down
6 changes: 4 additions & 2 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ size_t getNumAvailableDevices() {
return ::tt::tt_metal::GetNumAvailableDevices();
}

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) {
Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs,
std::optional<size_t> l1SmallSize) {
LOG_ASSERT(deviceIds.size(), "No devices specified");

::tt::tt_metal::distributed::MeshShape grid = {1, deviceIds.size()};
size_t l1SmallSizeValue = l1SmallSize.value_or(DEFAULT_L1_SMALL_SIZE);
std::shared_ptr<::tt::tt_metal::distributed::MeshDevice> meshDevice =
::tt::tt_metal::distributed::MeshDevice::create(
grid, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, numHWCQs,
grid, l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs,
::tt::tt_metal::DispatchCoreType::WORKER);

return Device(std::static_pointer_cast<void>(meshDevice),
Expand Down
6 changes: 4 additions & 2 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,13 @@ size_t getNumAvailableDevices() {
return ::tt::tt_metal::GetNumAvailableDevices();
}

Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) {
Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs,
std::optional<size_t> l1SmallSize) {
LOG_ASSERT(deviceIds.size(), "No devices specified");
::tt::tt_metal::distributed::MeshShape grid = {1, deviceIds.size()};
size_t l1SmallSizeValue = l1SmallSize.value_or(kL1SmallSize);
std::shared_ptr<::ttnn::MeshDevice> meshDevice = ::ttnn::MeshDevice::create(
grid, kL1SmallSize, DEFAULT_TRACE_REGION_SIZE, numHWCQs,
grid, l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs,
::tt::tt_metal::DispatchCoreType::WORKER);

bool enableAsync = debug::Env::get().enableAsyncTTNN;
Expand Down
1 change: 1 addition & 0 deletions runtime/tools/python/ttrt/runtime/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ PYBIND11_MODULE(_C, m) {
"Get the number of available devices");
m.def("open_device", &tt::runtime::openDevice, py::arg("device_ids"),
py::arg("num_hw_cqs") = size_t{1},
py::arg("l1_small_size") = py::none(),
"Open a mesh of devices for execution");
m.def("close_device", &tt::runtime::closeDevice, "Close a mesh device");
m.def("to_host", &tt::runtime::toHost, py::arg("tensor"),
Expand Down

0 comments on commit f902beb

Please sign in to comment.