@@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
+        // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+        if (r_l[il] == nullptr) continue;
 
         // Write key type
         const int32_t r_type_i = (int32_t)r_l[il]->type;
@@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
 
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            if (s_l[il] == nullptr) continue;
 
             // Write value type
             const int32_t s_type_i = (int32_t)s_l[il]->type;
@@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t mem_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            if (s_l[il] == nullptr) continue;
+
             const uint32_t n_embd_s = hparams.n_embd_s();
 
             // Write value type
@@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
+        // skip null layers
+        if (r_l[il] == nullptr) continue;
 
         // Read type of key
         int32_t r_type_i_ref;
@@ -978,11 +987,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if (s_l[il] == nullptr) continue;
 
             // Read type of value
             int32_t s_type_i_ref;
             io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
             const int32_t s_type_i = (int32_t)s_l[il]->type;
+
             if (s_type_i != s_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                 return false;
@@ -1005,6 +1017,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
     } else {
        // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if (s_l[il] == nullptr) continue;
+
             const uint32_t n_embd_s = hparams.n_embd_s();
 
             // Read type of value
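Outside the patch itself, a minimal self-contained sketch of the invariant the change relies on: the writer and the reader must skip exactly the same null layers, otherwise the per-layer type tags read back are matched against the wrong layer and the state load fails. The names Tensor, write_layers and read_layers below are hypothetical and are not the llama.cpp API; this is an illustration under that assumption, not the actual implementation.

    // Sketch: symmetric null-layer skipping between serialize and deserialize.
    // Hypothetical types/functions, not llama.cpp code.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Tensor { int32_t type; };

    // Writer: emits one type tag per non-null layer, in layer order.
    static void write_layers(std::vector<int32_t> & stream, const std::vector<Tensor *> & layers) {
        for (const Tensor * t : layers) {
            if (t == nullptr) continue; // skip null layers, mirrored by the reader
            stream.push_back(t->type);
        }
    }

    // Reader: applies the same skip; a different condition would desynchronize the stream.
    static bool read_layers(const std::vector<int32_t> & stream, const std::vector<Tensor *> & layers) {
        size_t pos = 0;
        for (size_t il = 0; il < layers.size(); ++il) {
            if (layers[il] == nullptr) continue; // same condition as the writer
            if (pos >= stream.size() || stream[pos++] != layers[il]->type) {
                std::fprintf(stderr, "mismatched type at layer %zu\n", il);
                return false;
            }
        }
        return pos == stream.size();
    }

    int main() {
        Tensor a{0}, b{1};
        std::vector<Tensor *> layers = { &a, nullptr, &b }; // layer 1 has no recurrent state
        std::vector<int32_t> stream;
        write_layers(stream, layers);
        return read_layers(stream, layers) ? 0 : 1;
    }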