@@ -1355,15 +1355,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
     std::vector<int32_t> ids;
     std::vector<ggml_bitset_t> used_ids;
 
-    for (int i = 0; i < sched->n_splits; i++) {
-        struct ggml_backend_sched_split * split = &splits[i];
+    for (int split_id = 0; split_id < sched->n_splits; split_id++) {
+        struct ggml_backend_sched_split * split = &splits[split_id];
         int split_backend_id = split->backend_id;
         ggml_backend_t split_backend = sched->backends[split_backend_id];
 
         // copy the input tensors to the split backend
-        for (int j = 0; j < split->n_inputs; j++) {
-            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
-            struct ggml_tensor * input = split->inputs[j];
+        for (int input_id = 0; input_id < split->n_inputs; input_id++) {
+            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
+            struct ggml_tensor * input = split->inputs[input_id];
             struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
 
             if (input->flags & GGML_TENSOR_FLAG_INPUT) {
@@ -1398,17 +1398,30 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
 
                     // get the ids
                     ggml_tensor * ids_tensor = node->src[2];
+                    ggml_backend_t ids_backend = split_backend;
+
+                    // if the ids tensor is also an input of the split, it may not have been copied yet to the split backend
+                    // in that case, we use the original ids tensor
+                    for (int i = input_id + 1; i < split->n_inputs; i++) {
+                        if (ids_tensor == tensor_copy(split->inputs[i], split_backend_id, sched->cur_copy)) {
+                            ids_tensor = split->inputs[i];
+                            ids_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[i]);
+                            break;
+                        }
+                    }
+
                     if (ids_tensor != prev_ids_tensor) {
                         ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t));
-                        ggml_backend_tensor_get_async(split_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
-                        ggml_backend_synchronize(split_backend);
+                        ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
+                        ggml_backend_synchronize(ids_backend);
 
                         // find the used experts
                         used_ids.clear();
                         used_ids.resize(ggml_bitset_size(n_expert));
                         for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) {
                             for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) {
                                 int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)];
+                                GGML_ASSERT(id >= 0 && id < n_expert);
                                 ggml_bitset_set(used_ids.data(), id);
                             }
                         }
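
The two fixes in this hunk are related: the readback has to be issued on the backend that actually holds the ids data, so if the ids tensor is a split input that has not been copied to the split backend yet, the diff reads it from its original backend (ids_backend) instead of split_backend; the new GGML_ASSERT then guards the bitset write against out-of-range expert ids.

As context for the used_ids bookkeeping, here is a minimal, self-contained sketch of the bitset pattern. The helper names and implementations below are illustrative stand-ins modeled on ggml's ggml_bitset_* helpers (32-bit words, size rounded up), not the scheduler's actual code:

#include <cstdint>
#include <cstdio>
#include <vector>

typedef uint32_t ggml_bitset_t;

// one bit per expert, packed into 32-bit words (assumed layout)
static size_t bitset_size(size_t n)                         { return (n + 31) >> 5; }
static void   bitset_set(ggml_bitset_t * b, size_t i)       { b[i >> 5] |= 1u << (i & 31); }
static bool   bitset_get(const ggml_bitset_t * b, size_t i) { return b[i >> 5] & (1u << (i & 31)); }

int main() {
    const int n_expert = 64;
    std::vector<ggml_bitset_t> used_ids(bitset_size(n_expert)); // zero-initialized

    // mark each expert id read back from the router output, as the
    // nested i1/i0 loop above does for ids_tensor
    const int32_t router_ids[] = {3, 3, 17, 42};
    for (int32_t id : router_ids) {
        bitset_set(used_ids.data(), id);
    }

    // a later pass can skip any expert whose bit is still clear
    for (int e = 0; e < n_expert; e++) {
        if (bitset_get(used_ids.data(), e)) {
            printf("expert %d is used\n", e);
        }
    }
    return 0;
}

Marking duplicates is idempotent (3 appears twice above), which is why a bitset fits here: the scheduler only needs set membership per expert, not counts.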