diff --git a/runtime/lib/common/system_desc.cpp b/runtime/lib/common/system_desc.cpp index ad86edf47..f11dcc09c 100644 --- a/runtime/lib/common/system_desc.cpp +++ b/runtime/lib/common/system_desc.cpp @@ -262,7 +262,9 @@ std::pair<::tt::runtime::SystemDesc, DeviceIds> getCurrentSystemDesc() { ::tt::tt_metal::distributed::MeshShape meshShape = {1, numDevices}; std::shared_ptr<::tt::tt_metal::distributed::MeshDevice> meshDevice = ::tt::tt_metal::distributed::MeshDevice::create( - meshShape, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, + ::tt::tt_metal::distributed::MeshDeviceConfig{.mesh_shape = + meshShape}, + DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, ::tt::tt_metal::DispatchCoreType::WORKER); std::exception_ptr eptr = nullptr; std::unique_ptr<::tt::runtime::SystemDesc> desc; @@ -271,7 +273,7 @@ std::pair<::tt::runtime::SystemDesc, DeviceIds> getCurrentSystemDesc() { } catch (...) { eptr = std::current_exception(); } - meshDevice->close_devices(); + meshDevice->close(); if (eptr) { std::rethrow_exception(eptr); } diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 0e84c17ce..f2702b52a 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -74,7 +74,8 @@ Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs, size_t l1SmallSizeValue = l1SmallSize.value_or(DEFAULT_L1_SMALL_SIZE); std::shared_ptr<::tt::tt_metal::distributed::MeshDevice> meshDevice = ::tt::tt_metal::distributed::MeshDevice::create( - grid, l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs, + ::tt::tt_metal::distributed::MeshDeviceConfig{.mesh_shape = grid}, + l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs, ::tt::tt_metal::DispatchCoreType::WORKER); return Device(std::static_pointer_cast(meshDevice), @@ -92,7 +93,7 @@ void closeDevice(Device device) { ::tt::tt_metal::detail::DumpDeviceProfileResults(ttmetalDevice); } #endif - ttmetalMeshDevice.close_devices(); + ttmetalMeshDevice.close(); } void deallocateBuffers(Device deviceHandle) { diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 58b694482..72378f881 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -184,7 +184,8 @@ Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs, ::tt::tt_metal::distributed::MeshShape grid = {1, deviceIds.size()}; size_t l1SmallSizeValue = l1SmallSize.value_or(kL1SmallSize); std::shared_ptr<::ttnn::MeshDevice> meshDevice = ::ttnn::MeshDevice::create( - grid, l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs, + ::tt::tt_metal::distributed::MeshDeviceConfig{.mesh_shape = grid}, + l1SmallSizeValue, DEFAULT_TRACE_REGION_SIZE, numHWCQs, ::tt::tt_metal::DispatchCoreType::WORKER); bool enableAsync = debug::Env::get().enableAsyncTTNN; @@ -205,7 +206,7 @@ void closeDevice(Device device) { } #endif - ttnnMeshDevice.close_devices(); + ttnnMeshDevice.close(); } void deallocateBuffers(Device deviceHandle) { diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index b1bf234e6..42ebb830b 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -1,6 +1,6 @@ include(ExternalProject) -set(TT_METAL_VERSION "5d4c047dacf2606dd56c7b4d51d5049bf2c6846a") +set(TT_METAL_VERSION "3a9637d52b003d5d5eda455cc72fbeb75a689f32") if ("$ENV{ARCH_NAME}" STREQUAL "grayskull") set(ARCH_NAME "grayskull")