@@ -1117,40 +1117,43 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
11171117}
11181118
11191119// ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed
1120- namespace {
1121-
1122- static std::unordered_map<int , void *> g_nz_workspace_map;
1123- static std::unordered_map<int , size_t > g_nz_workspace_allocated_map;
1124-
1125- void release_nz_workspace (int device) {
1126- auto it = g_nz_workspace_map.find (device);
1127- if (it != g_nz_workspace_map.end () && it->second ) {
1128- aclrtFree (it->second );
1129- g_nz_workspace_map.erase (it);
1130- g_nz_workspace_allocated_map.erase (device);
1120+ class NzWorkspace {
1121+ public:
1122+ NzWorkspace () : ptr_(nullptr ), allocated_(0 ) {}
1123+
1124+ // 初始化 / 重置为无效
1125+ void init () {
1126+ if (ptr_) {
1127+ aclrtFree (ptr_);
1128+ ptr_ = nullptr ;
1129+ allocated_ = 0 ;
11311130 }
11321131 }
11331132
1134- void relloc_nz_workspace (int device, size_t new_size) {
1135- void * &workspace = g_nz_workspace_map[device];
1136- size_t &allocated = g_nz_workspace_allocated_map[device];
1137-
1138- if (new_size > allocated) {
1139- if (workspace) {
1140- aclrtFree (workspace);
1141- workspace = nullptr ;
1142- }
1143- ACL_CHECK (aclrtMalloc (&workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1144- allocated = new_size;
1133+ void realloc (size_t new_size) {
1134+ if (new_size > allocated_) {
1135+ init ();
1136+ ACL_CHECK (aclrtMalloc (&ptr_, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1137+ allocated_ = new_size;
11451138 }
11461139 }
11471140
1148- void * get_nz_workspace (int device) {
1149- auto it = g_nz_workspace_map.find (device);
1150- return (it != g_nz_workspace_map.end ()) ? it->second : nullptr ;
1141+ void * get () const { return ptr_; }
1142+
1143+ private:
1144+ void * ptr_;
1145+ size_t allocated_;
1146+ };
1147+
1148+ static std::array<NzWorkspace, GGML_CANN_MAX_DEVICES> g_nz_workspaces;
1149+
1150+ inline NzWorkspace& get_workspace (int device) {
1151+ if (device < 0 || device >= static_cast <int >(g_nz_workspaces.size ())) {
1152+ throw std::out_of_range (" device id out of range" );
11511153 }
1154+ return g_nz_workspaces[device];
1155+ }
11521156
1153- } // namespace
11541157
11551158/* *
11561159 * @brief Convert tensor weights to NZ format using Ascend CANN API.
@@ -1176,9 +1179,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device)
11761179 ACL_CHECK (aclnnTransMatmulWeightGetWorkspaceSize (weightTransposed,
11771180 &workspaceSize, &executor));
11781181 // Avoid frequent malloc/free of the workspace.
1179- relloc_nz_workspace (device, workspaceSize);
1182+ get_workspace (device). realloc ( workspaceSize);
11801183
1181- void * g_nz_workspace = get_nz_workspace (device);
1184+ void * g_nz_workspace = get_workspace (device). get ( );
11821185
11831186 ACL_CHECK (aclnnTransMatmulWeight (g_nz_workspace, workspaceSize, executor, nullptr ));
11841187 ACL_CHECK (aclDestroyTensor (weightTransposed));
@@ -2259,7 +2262,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
22592262 ggml_backend_cann_context* cann_ctx =
22602263 (ggml_backend_cann_context*)backend->context ;
22612264 ggml_cann_set_device (cann_ctx->device );
2262- release_nz_workspace (cann_ctx->device );
2265+ get_workspace (cann_ctx->device ). init ( );
22632266
22642267#ifdef USE_ACL_GRAPH
22652268 bool use_cann_graph = true ;
0 commit comments