@@ -18316,6 +18316,20 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                             std::max((uint32_t) 1, cparams.n_seq_max),
                             cparams.n_seq_max);
                 } else if (llm_arch_is_hybrid(arch)) {
+
+                    // The main difference between hybrid architectures is the
+                    // layer filters, so pick the right one here
+                    llama_memory_hybrid::layer_filter_cb filter_attn = nullptr;
+                    llama_memory_hybrid::layer_filter_cb filter_recr = nullptr;
+                    if (arch == LLM_ARCH_FALCON_H1) {
+                        filter_attn = [&](int32_t) { return true; };
+                        filter_recr = [&](int32_t) { return true; };
+                    } else if (arch == LLM_ARCH_NEMOTRONH) {
+                        filter_attn = [&](int32_t il) {
+                            return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
+                        };
+                    }
+
                     const auto padding = llama_kv_cache::get_padding(cparams);

                     cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
@@ -18335,8 +18349,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* n_seq_max         */ cparams.n_seq_max,
                         /* offload           */ cparams.offload_kqv,
                         /* unified           */ cparams.kv_unified,
-                        /* filter_attn       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
-                        /* filter_recr       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
+                        /* filter_attn       */ std::move(filter_attn),
+                        /* filter_recr       */ std::move(filter_recr));
                 } else {
                     const auto padding = llama_kv_cache::get_padding(cparams);

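For context on what these callbacks are for: the hybrid cache asks each filter which layers it should cover, so FALCON_H1 keeps every layer in both caches, while the NEMOTRONH filter restricts the attention cache to the non-recurrent layers that have no feed-forward block. The sketch below is a minimal standalone model of that mechanism, not llama.cpp code: the toy_hparams and toy_hybrid_memory types are made up for illustration, and the idea that a null filter falls back to an is_recurrent()-based split is an assumption about how llama_memory_hybrid behaves by default.

// Standalone sketch, not llama.cpp code: the toy_* names below are
// hypothetical stand-ins used only to illustrate the filter mechanism.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

// Plays the role of llama_memory_hybrid::layer_filter_cb.
using layer_filter_cb = std::function<bool(int32_t)>;

// Hypothetical stand-in for the hparams queries used by the filters.
struct toy_hparams {
    std::vector<bool>     recurrent; // true for mamba/SSM layers
    std::vector<uint32_t> ff;        // 0 for attention/SSM layers, > 0 for MLP-only layers

    bool     is_recurrent(int32_t il) const { return recurrent[il]; }
    uint32_t n_ff        (int32_t il) const { return ff[il]; }
};

// Hypothetical hybrid cache: each layer lands in the attention cache, the
// recurrent cache, both, or neither, depending on the two filters. Assumes
// that a null filter falls back to an is_recurrent()-based default.
struct toy_hybrid_memory {
    std::vector<int32_t> attn_layers;
    std::vector<int32_t> recr_layers;

    toy_hybrid_memory(const toy_hparams & hp, int32_t n_layer,
                      layer_filter_cb filter_attn, layer_filter_cb filter_recr) {
        if (!filter_attn) { filter_attn = [&hp](int32_t il) { return !hp.is_recurrent(il); }; }
        if (!filter_recr) { filter_recr = [&hp](int32_t il) { return  hp.is_recurrent(il); }; }
        for (int32_t il = 0; il < n_layer; ++il) {
            if (filter_attn(il)) { attn_layers.push_back(il); }
            if (filter_recr(il)) { recr_layers.push_back(il); }
        }
    }
};

int main() {
    // NemotronH-style layout: interleaved mamba (recurrent), attention and
    // MLP-only layers.
    toy_hparams hp;
    hp.recurrent = { true, false, true, false,  true, false };
    hp.ff        = {    0,     0,    0,  4096,     0,  4096 };
    //              mamba  attn  mamba  MLP    mamba  MLP

    // Same shape as the NEMOTRONH branch in the diff: the attention cache
    // only covers non-recurrent layers that have no feed-forward block.
    layer_filter_cb filter_attn = [&hp](int32_t il) {
        return !hp.is_recurrent(il) && hp.n_ff(il) == 0;
    };

    toy_hybrid_memory mem(hp, 6, std::move(filter_attn), nullptr);

    printf("attn layers:"); for (int32_t il : mem.attn_layers) { printf(" %d", il); } printf("\n");
    printf("recr layers:"); for (int32_t il : mem.recr_layers) { printf(" %d", il); } printf("\n");
    // -> attn layers: 1
    //    recr layers: 0 2 4
    return 0;
}

In this toy layout, the default (null) attention filter would have counted the two MLP-only layers (3 and 5) as attention layers and reserved KV cache for them, which is presumably what the NEMOTRONH-specific filter is meant to avoid.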