diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index 187fa5542f..aa7808bac0 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -35,7 +35,8 @@ tt::target::DataType getTensorDataType(Tensor tensor); size_t getNumAvailableDevices(); -Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1); +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1, + std::optional l1SmallSize = std::nullopt); void closeDevice(Device device); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index b1007d4057..081ef02fe6 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -86,7 +86,8 @@ tt::target::DataType getTensorDataType(Tensor tensor); size_t getNumAvailableDevices(); -Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1); +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1, + std::optional l1SmallSize = std::nullopt); void closeDevice(Device device); diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 2f278ffc1c..7725e4b565 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -71,7 +71,8 @@ tt::target::DataType getTensorDataType(Tensor tensor); size_t getNumAvailableDevices(); -Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1); +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1, + std::optional l1SmallSize = std::nullopt); void closeDevice(Device device); diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index c25cfed51b..b0ac1ee43e 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -216,16 +216,17 @@ size_t getNumAvailableDevices() { LOG_FATAL("runtime is not enabled"); } -Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) { +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs, + std::optional l1SmallSize) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs); + return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs, l1SmallSize); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs); + return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs, l1SmallSize); } #endif LOG_FATAL("runtime is not enabled"); diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 9cca242a58..0e84c17cef 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -66,13 +66,15 @@ size_t getNumAvailableDevices() { return ::tt::tt_metal::GetNumAvailableDevices(); } -Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) { +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs, + std::optional l1SmallSize) { LOG_ASSERT(deviceIds.size(), "No devices specified"); ::tt::tt_metal::distributed::MeshShape grid = {1, deviceIds.size()}; + size_t l1SmallSizeValue = l1SmallSize.value_or(DEFAULT_L1_SMALL_SIZE); std::shared_ptr<::tt::tt_metal::distributed::MeshDevice> meshDevice = ::tt::tt_metal::distributed::MeshDevice::create( - grid, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, numHWCQs, + grid, l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs, ::tt::tt_metal::DispatchCoreType::WORKER); return Device(std::static_pointer_cast(meshDevice), diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index c527a94d7e..58b694482e 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -178,11 +178,13 @@ size_t getNumAvailableDevices() { return ::tt::tt_metal::GetNumAvailableDevices(); } -Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) { +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs, + std::optional l1SmallSize) { LOG_ASSERT(deviceIds.size(), "No devices specified"); ::tt::tt_metal::distributed::MeshShape grid = {1, deviceIds.size()}; + size_t l1SmallSizeValue = l1SmallSize.value_or(kL1SmallSize); std::shared_ptr<::ttnn::MeshDevice> meshDevice = ::ttnn::MeshDevice::create( - grid, kL1SmallSize, DEFAULT_TRACE_REGION_SIZE, numHWCQs, + grid, l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs, ::tt::tt_metal::DispatchCoreType::WORKER); bool enableAsync = debug::Env::get().enableAsyncTTNN; diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 4c3eb8c690..f55b6b81e1 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -99,6 +99,7 @@ PYBIND11_MODULE(_C, m) { "Get the number of available devices"); m.def("open_device", &tt::runtime::openDevice, py::arg("device_ids"), py::arg("num_hw_cqs") = size_t{1}, + py::arg("l1_small_size") = py::none(), "Open a mesh of devices for execution"); m.def("close_device", &tt::runtime::closeDevice, "Close a mesh device"); m.def("to_host", &tt::runtime::toHost, py::arg("tensor"),