File tree Expand file tree Collapse file tree 3 files changed +17
-3
lines changed Expand file tree Collapse file tree 3 files changed +17
-3
lines changed Original file line number Diff line number Diff line change @@ -28,9 +28,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
2828
2929 std::vector<const llama_grammar_element *> grammar_rules (result->parsed_grammar .c_rules ());
3030
31- result-> grammar = llama_grammar_init (
31+ struct llama_grammar * grammar = llama_grammar_init (
3232 grammar_rules.data (),
3333 grammar_rules.size (), result->parsed_grammar .symbol_ids .at (" root" ));
34+ if (grammar == nullptr ) {
35+ throw std::runtime_error (" Failed to initialize llama_grammar" );
36+ }
37+ result->grammar = grammar;
3438 }
3539
3640 result->prev .resize (params.n_prev );
@@ -59,9 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
5963 if (!ctx->parsed_grammar .rules .empty ()) {
6064 std::vector<const llama_grammar_element *> grammar_rules (ctx->parsed_grammar .c_rules ());
6165
62- ctx-> grammar = llama_grammar_init (
66+ struct llama_grammar * grammar = llama_grammar_init (
6367 grammar_rules.data (),
6468 grammar_rules.size (), ctx->parsed_grammar .symbol_ids .at (" root" ));
69+ if (grammar == nullptr ) {
70+ throw std::runtime_error (" Failed to initialize llama_grammar" );
71+ }
72+ ctx->grammar = grammar;
6573 }
6674
6775 std::fill (ctx->prev .begin (), ctx->prev .end (), 0 );
Original file line number Diff line number Diff line change @@ -101,7 +101,9 @@ int main(int argc, char** argv) {
101101 auto grammar = llama_grammar_init (
102102 grammar_rules.data (),
103103 grammar_rules.size (), parsed_grammar.symbol_ids .at (" root" ));
104-
104+ if (grammar == nullptr ) {
105+ throw std::runtime_error (" Failed to initialize llama_grammar" );
106+ }
105107 // Read the input file
106108 std::string input_str;
107109 {
Original file line number Diff line number Diff line change @@ -116,6 +116,10 @@ int main()
116116 std::vector<const llama_grammar_element *> grammar_rules (parsed_grammar.c_rules ());
117117 grammar = llama_grammar_init (
118118 grammar_rules.data (), grammar_rules.size (), parsed_grammar.symbol_ids .at (" root" ));
119+ if (grammar == nullptr )
120+ {
121+ throw std::runtime_error (" Failed to initialize llama_grammar" );
122+ }
119123
120124 std::vector<std::vector<llama_grammar_element>> expected_stacks = {
121125 {
You can’t perform that action at this time.
0 commit comments