Commit ad4cdd6

Support OpenAI reasoning models (#1841)

* Update tiktoken
* Add max_completion_tokens to model config
* Update/remove outdated comments
* Remove max_tokens from report generation
* Remove max_tokens from entity summarization
* Remove logit_bias from graph extraction
* Remove logit_bias from claim extraction
* Swap params if reasoning model
* Add reasoning model support to basic search
* Add reasoning model support for local and global search
* Support reasoning models with dynamic community selection
* Support reasoning models in DRIFT search
* Remove unused num_threads entry
* Semver
* Update openai
* Add reasoning_effort param

1 parent 74ad1d4 commit ad4cdd6

File tree: 60 files changed (+424, -616 lines)
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Support OpenAI reasoning models."
+}

graphrag/config/defaults.py

Lines changed: 13 additions & 30 deletions

@@ -41,13 +41,7 @@ class BasicSearchDefaults:
     """Default values for basic search."""

     prompt: None = None
-    text_unit_prop: float = 0.5
-    conversation_history_max_turns: int = 5
-    temperature: float = 0
-    top_p: float = 1
-    n: int = 1
-    max_tokens: int = 12_000
-    llm_max_tokens: int = 2000
+    k: int = 10
     chat_model_id: str = DEFAULT_CHAT_MODEL_ID
     embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
@@ -104,13 +98,10 @@ class DriftSearchDefaults:

     prompt: None = None
     reduce_prompt: None = None
-    temperature: float = 0
-    top_p: float = 1
-    n: int = 1
-    max_tokens: int = 12_000
     data_max_tokens: int = 12_000
-    reduce_max_tokens: int = 2_000
+    reduce_max_tokens: None = None
     reduce_temperature: float = 0
+    reduce_max_completion_tokens: None = None
     concurrency: int = 32
     drift_k_followups: int = 20
     primer_folds: int = 5
@@ -124,7 +115,8 @@ class DriftSearchDefaults:
     local_search_temperature: float = 0
     local_search_top_p: float = 1
     local_search_n: int = 1
-    local_search_llm_max_gen_tokens: int = 4_096
+    local_search_llm_max_gen_tokens = None
+    local_search_llm_max_gen_completion_tokens = None
     chat_model_id: str = DEFAULT_CHAT_MODEL_ID
     embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
@@ -168,7 +160,6 @@ class ExtractClaimsDefaults:
     )
     max_gleanings: int = 1
     strategy: None = None
-    encoding_model: None = None
     model_id: str = DEFAULT_CHAT_MODEL_ID
@@ -182,7 +173,6 @@ class ExtractGraphDefaults:
     )
     max_gleanings: int = 1
     strategy: None = None
-    encoding_model: None = None
     model_id: str = DEFAULT_CHAT_MODEL_ID
@@ -228,20 +218,14 @@ class GlobalSearchDefaults:
     map_prompt: None = None
     reduce_prompt: None = None
     knowledge_prompt: None = None
-    temperature: float = 0
-    top_p: float = 1
-    n: int = 1
-    max_tokens: int = 12_000
+    max_context_tokens: int = 12_000
     data_max_tokens: int = 12_000
-    map_max_tokens: int = 1000
-    reduce_max_tokens: int = 2000
-    concurrency: int = 32
-    dynamic_search_llm: str = "gpt-4o-mini"
+    map_max_length: int = 1000
+    reduce_max_length: int = 2000
     dynamic_search_threshold: int = 1
     dynamic_search_keep_parent: bool = False
     dynamic_search_num_repeats: int = 1
     dynamic_search_use_summary: bool = False
-    dynamic_search_concurrent_coroutines: int = 16
     dynamic_search_max_level: int = 2
     chat_model_id: str = DEFAULT_CHAT_MODEL_ID
@@ -271,8 +255,10 @@ class LanguageModelDefaults:
     api_key: None = None
     auth_type = AuthType.APIKey
     encoding_model: str = ""
-    max_tokens: int = 4000
+    max_tokens: int | None = None
     temperature: float = 0
+    max_completion_tokens: int | None = None
+    reasoning_effort: str | None = None
     top_p: float = 1
     n: int = 1
     frequency_penalty: float = 0.0
@@ -305,11 +291,7 @@ class LocalSearchDefaults:
     conversation_history_max_turns: int = 5
     top_k_entities: int = 10
     top_k_relationships: int = 10
-    temperature: float = 0
-    top_p: float = 1
-    n: int = 1
-    max_tokens: int = 12_000
-    llm_max_tokens: int = 2000
+    max_context_tokens: int = 12_000
     chat_model_id: str = DEFAULT_CHAT_MODEL_ID
     embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
@@ -364,6 +346,7 @@ class SummarizeDescriptionsDefaults:

     prompt: None = None
     max_length: int = 500
+    max_input_tokens: int = 4_000
     strategy: None = None
     model_id: str = DEFAULT_CHAT_MODEL_ID
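In LanguageModelDefaults, max_tokens now defaults to None and sits alongside the new max_completion_tokens and reasoning_effort fields, so a request builder can include only the caps a given model accepts. A minimal sketch of that pattern, assuming a dict-building helper (to_request_kwargs is illustrative, not this commit's code):

```python
# Illustrative only: mirrors the new defaults above; to_request_kwargs is a
# hypothetical helper, not code from this commit.
from dataclasses import dataclass

@dataclass
class LanguageModelDefaults:
    max_tokens: int | None = None             # legacy generation cap, now unset
    max_completion_tokens: int | None = None  # reasoning-model cap (includes reasoning tokens)
    reasoning_effort: str | None = None       # "low" | "medium" | "high"
    temperature: float = 0
    top_p: float = 1
    n: int = 1

def to_request_kwargs(cfg: LanguageModelDefaults) -> dict:
    """Build API kwargs, dropping anything left at its None default."""
    raw = {
        "max_tokens": cfg.max_tokens,
        "max_completion_tokens": cfg.max_completion_tokens,
        "reasoning_effort": cfg.reasoning_effort,
        "temperature": cfg.temperature,
        "top_p": cfg.top_p,
        "n": cfg.n,
    }
    return {k: v for k, v in raw.items() if v is not None}

# With the defaults, none of the token/effort params reach the API:
assert to_request_kwargs(LanguageModelDefaults()) == {
    "temperature": 0, "top_p": 1, "n": 1
}
```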

graphrag/config/models/basic_search_config.py

Lines changed: 3 additions & 27 deletions

@@ -23,31 +23,7 @@ class BasicSearchConfig(BaseModel):
         description="The model ID to use for text embeddings.",
         default=graphrag_config_defaults.basic_search.embedding_model_id,
     )
-    text_unit_prop: float = Field(
-        description="The text unit proportion.",
-        default=graphrag_config_defaults.basic_search.text_unit_prop,
-    )
-    conversation_history_max_turns: int = Field(
-        description="The conversation history maximum turns.",
-        default=graphrag_config_defaults.basic_search.conversation_history_max_turns,
-    )
-    temperature: float = Field(
-        description="The temperature to use for token generation.",
-        default=graphrag_config_defaults.basic_search.temperature,
-    )
-    top_p: float = Field(
-        description="The top-p value to use for token generation.",
-        default=graphrag_config_defaults.basic_search.top_p,
-    )
-    n: int = Field(
-        description="The number of completions to generate.",
-        default=graphrag_config_defaults.basic_search.n,
-    )
-    max_tokens: int = Field(
-        description="The maximum tokens.",
-        default=graphrag_config_defaults.basic_search.max_tokens,
-    )
-    llm_max_tokens: int = Field(
-        description="The LLM maximum tokens.",
-        default=graphrag_config_defaults.basic_search.llm_max_tokens,
+    k: int = Field(
+        description="The number of text units to include in search context.",
+        default=graphrag_config_defaults.basic_search.k,
     )
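After this change BasicSearchConfig carries retrieval settings only; sampling parameters and token caps are inherited from the language model config. A usage sketch, assuming the import path shown above and that default construction succeeds:

```python
# Usage sketch; the import path follows the diff's own module layout.
from graphrag.config.models.basic_search_config import BasicSearchConfig

cfg = BasicSearchConfig()        # k defaults to 10 (see defaults.py above)
print(cfg.k)                     # -> 10
wide = BasicSearchConfig(k=25)   # retrieve more text units per query
```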

graphrag/config/models/community_reports_config.py

Lines changed: 0 additions & 1 deletion

@@ -50,7 +50,6 @@ def resolved_strategy(
         return self.strategy or {
             "type": CreateCommunityReportsStrategyType.graph_intelligence,
             "llm": model_config.model_dump(),
-            "num_threads": model_config.concurrent_requests,
             "graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
                 encoding="utf-8"
             )

graphrag/config/models/drift_search_config.py

Lines changed: 12 additions & 18 deletions

@@ -27,28 +27,12 @@ class DRIFTSearchConfig(BaseModel):
         description="The model ID to use for drift search.",
         default=graphrag_config_defaults.drift_search.embedding_model_id,
     )
-    temperature: float = Field(
-        description="The temperature to use for token generation.",
-        default=graphrag_config_defaults.drift_search.temperature,
-    )
-    top_p: float = Field(
-        description="The top-p value to use for token generation.",
-        default=graphrag_config_defaults.drift_search.top_p,
-    )
-    n: int = Field(
-        description="The number of completions to generate.",
-        default=graphrag_config_defaults.drift_search.n,
-    )
-    max_tokens: int = Field(
-        description="The maximum context size in tokens.",
-        default=graphrag_config_defaults.drift_search.max_tokens,
-    )
     data_max_tokens: int = Field(
         description="The data llm maximum tokens.",
         default=graphrag_config_defaults.drift_search.data_max_tokens,
     )

-    reduce_max_tokens: int = Field(
+    reduce_max_tokens: int | None = Field(
         description="The reduce llm maximum tokens response to produce.",
         default=graphrag_config_defaults.drift_search.reduce_max_tokens,
     )
@@ -58,6 +42,11 @@ class DRIFTSearchConfig(BaseModel):
         default=graphrag_config_defaults.drift_search.reduce_temperature,
     )

+    reduce_max_completion_tokens: int | None = Field(
+        description="The reduce llm maximum tokens response to produce.",
+        default=graphrag_config_defaults.drift_search.reduce_max_completion_tokens,
+    )
+
     concurrency: int = Field(
         description="The number of concurrent requests.",
         default=graphrag_config_defaults.drift_search.concurrency,
@@ -123,7 +112,12 @@ class DRIFTSearchConfig(BaseModel):
         default=graphrag_config_defaults.drift_search.local_search_n,
     )

-    local_search_llm_max_gen_tokens: int = Field(
+    local_search_llm_max_gen_tokens: int | None = Field(
         description="The maximum number of generated tokens for the LLM in local search.",
         default=graphrag_config_defaults.drift_search.local_search_llm_max_gen_tokens,
     )
+
+    local_search_llm_max_gen_completion_tokens: int | None = Field(
+        description="The maximum number of generated tokens for the LLM in local search.",
+        default=graphrag_config_defaults.drift_search.local_search_llm_max_gen_completion_tokens,
+    )
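The paired reduce_max_tokens / reduce_max_completion_tokens fields support the commit message's "Swap params if reasoning model" step: conventional models take max_tokens, while o-series reasoning models require max_completion_tokens instead and reject fixed sampling parameters. A hedged sketch of how such a swap might look (reduce_llm_params and is_reasoning_model are hypothetical helpers, not the commit's actual implementation):

```python
# Hedged sketch of "Swap params if reasoning model" (commit message); both
# helpers below are hypothetical, not graphrag's actual code.

def is_reasoning_model(model_name: str) -> bool:
    # Assumption: o1/o3-style names identify OpenAI reasoning models.
    return model_name.startswith(("o1", "o3"))

def reduce_llm_params(cfg: "DRIFTSearchConfig", model_name: str) -> dict:
    """Pick the token cap a given model actually accepts."""
    if is_reasoning_model(model_name):
        # Reasoning models reject max_tokens (and fixed sampling params);
        # max_completion_tokens caps visible output plus reasoning tokens.
        params: dict = {}
        if cfg.reduce_max_completion_tokens is not None:
            params["max_completion_tokens"] = cfg.reduce_max_completion_tokens
    else:
        params = {"temperature": cfg.reduce_temperature}
        if cfg.reduce_max_tokens is not None:
            params["max_tokens"] = cfg.reduce_max_tokens
    return params
```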

graphrag/config/models/extract_claims_config.py

Lines changed: 0 additions & 6 deletions

@@ -38,24 +38,18 @@ class ClaimExtractionConfig(BaseModel):
         description="The override strategy to use.",
         default=graphrag_config_defaults.extract_claims.strategy,
     )
-    encoding_model: str | None = Field(
-        default=graphrag_config_defaults.extract_claims.encoding_model,
-        description="The encoding model to use.",
-    )

     def resolved_strategy(
         self, root_dir: str, model_config: LanguageModelConfig
     ) -> dict:
         """Get the resolved claim extraction strategy."""
         return self.strategy or {
             "llm": model_config.model_dump(),
-            "num_threads": model_config.concurrent_requests,
             "extraction_prompt": (Path(root_dir) / self.prompt).read_text(
                 encoding="utf-8"
             )
             if self.prompt
             else None,
             "claim_description": self.description,
             "max_gleanings": self.max_gleanings,
-            "encoding_name": model_config.encoding_model,
         }
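Dropping num_threads and encoding_name does not lose information: both values already travel inside the strategy's "llm" entry, because model_config.model_dump() serializes every LanguageModelConfig field, including concurrent_requests and encoding_model. A sketch of the observable difference, assuming populated claims_config and chat_model instances:

```python
# Sketch only: claims_config and chat_model are assumed, pre-built instances.
strategy = claims_config.resolved_strategy(root_dir=".", model_config=chat_model)

assert "num_threads" not in strategy              # removed by this commit
assert "encoding_name" not in strategy            # removed by this commit
assert "concurrent_requests" in strategy["llm"]   # still present via model_dump()
assert "encoding_model" in strategy["llm"]        # still present via model_dump()
```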

graphrag/config/models/extract_graph_config.py

Lines changed: 0 additions & 6 deletions

@@ -34,10 +34,6 @@ class ExtractGraphConfig(BaseModel):
         description="Override the default entity extraction strategy",
         default=graphrag_config_defaults.extract_graph.strategy,
     )
-    encoding_model: str | None = Field(
-        default=graphrag_config_defaults.extract_graph.encoding_model,
-        description="The encoding model to use.",
-    )

     def resolved_strategy(
         self, root_dir: str, model_config: LanguageModelConfig
@@ -50,12 +46,10 @@ def resolved_strategy(
         return self.strategy or {
             "type": ExtractEntityStrategyType.graph_intelligence,
             "llm": model_config.model_dump(),
-            "num_threads": model_config.concurrent_requests,
             "extraction_prompt": (Path(root_dir) / self.prompt).read_text(
                 encoding="utf-8"
             )
             if self.prompt
             else None,
             "max_gleanings": self.max_gleanings,
-            "encoding_name": model_config.encoding_model,
         }

graphrag/config/models/global_search_config.py

Lines changed: 8 additions & 32 deletions

@@ -27,44 +27,24 @@ class GlobalSearchConfig(BaseModel):
         description="The global search general prompt to use.",
         default=graphrag_config_defaults.global_search.knowledge_prompt,
     )
-    temperature: float = Field(
-        description="The temperature to use for token generation.",
-        default=graphrag_config_defaults.global_search.temperature,
-    )
-    top_p: float = Field(
-        description="The top-p value to use for token generation.",
-        default=graphrag_config_defaults.global_search.top_p,
-    )
-    n: int = Field(
-        description="The number of completions to generate.",
-        default=graphrag_config_defaults.global_search.n,
-    )
-    max_tokens: int = Field(
+    max_context_tokens: int = Field(
         description="The maximum context size in tokens.",
-        default=graphrag_config_defaults.global_search.max_tokens,
+        default=graphrag_config_defaults.global_search.max_context_tokens,
     )
     data_max_tokens: int = Field(
         description="The data llm maximum tokens.",
         default=graphrag_config_defaults.global_search.data_max_tokens,
     )
-    map_max_tokens: int = Field(
-        description="The map llm maximum tokens.",
-        default=graphrag_config_defaults.global_search.map_max_tokens,
+    map_max_length: int = Field(
+        description="The map llm maximum response length in words.",
+        default=graphrag_config_defaults.global_search.map_max_length,
     )
-    reduce_max_tokens: int = Field(
-        description="The reduce llm maximum tokens.",
-        default=graphrag_config_defaults.global_search.reduce_max_tokens,
-    )
-    concurrency: int = Field(
-        description="The number of concurrent requests.",
-        default=graphrag_config_defaults.global_search.concurrency,
+    reduce_max_length: int = Field(
+        description="The reduce llm maximum response length in words.",
+        default=graphrag_config_defaults.global_search.reduce_max_length,
     )

     # configurations for dynamic community selection
-    dynamic_search_llm: str = Field(
-        description="LLM model to use for dynamic community selection",
-        default=graphrag_config_defaults.global_search.dynamic_search_llm,
-    )
     dynamic_search_threshold: int = Field(
         description="Rating threshold in include a community report",
         default=graphrag_config_defaults.global_search.dynamic_search_threshold,
@@ -81,10 +61,6 @@ class GlobalSearchConfig(BaseModel):
         description="Use community summary instead of full_context",
         default=graphrag_config_defaults.global_search.dynamic_search_use_summary,
     )
-    dynamic_search_concurrent_coroutines: int = Field(
-        description="Number of concurrent coroutines to rate community reports",
-        default=graphrag_config_defaults.global_search.dynamic_search_concurrent_coroutines,
-    )
     dynamic_search_max_level: int = Field(
         description="The maximum level of community hierarchy to consider if none of the processed communities are relevant",
         default=graphrag_config_defaults.global_search.dynamic_search_max_level,
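The rename from map_max_tokens/reduce_max_tokens to map_max_length/reduce_max_length also changes the unit: response size is now steered as a word count in the prompt rather than enforced as a hard token cap, which reasoning models would partly spend on hidden reasoning tokens. An illustrative sketch (the prompt wording is an assumption, not the repo's actual template):

```python
# Hypothetical prompt suffix; graphrag's real map/reduce prompt templates differ.
def length_instruction(max_length: int) -> str:
    return f"Limit your answer to at most {max_length} words."

map_suffix = length_instruction(1000)     # map_max_length default
reduce_suffix = length_instruction(2000)  # reduce_max_length default
```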

graphrag/config/models/language_model_config.py

Lines changed: 9 additions & 1 deletion

@@ -223,14 +223,22 @@ def _validate_deployment_name(self) -> None:
         default=language_model_defaults.responses,
         description="Static responses to use in mock mode.",
     )
-    max_tokens: int = Field(
+    max_tokens: int | None = Field(
         description="The maximum number of tokens to generate.",
         default=language_model_defaults.max_tokens,
     )
     temperature: float = Field(
         description="The temperature to use for token generation.",
         default=language_model_defaults.temperature,
     )
+    max_completion_tokens: int | None = Field(
+        description="The maximum number of tokens to consume. This includes reasoning tokens for the o* reasoning models.",
+        default=language_model_defaults.max_completion_tokens,
+    )
+    reasoning_effort: str | None = Field(
+        description="Level of effort OpenAI reasoning models should expend. Supported options are 'low', 'medium', 'high'; and OAI defaults to 'medium'.",
+        default=language_model_defaults.reasoning_effort,
+    )
     top_p: float = Field(
         description="The top-p value to use for token generation.",
         default=language_model_defaults.top_p,
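With the two new fields in place, a reasoning-model entry might be configured as below. This is illustrative only: the type value, model name, and minimal field set are assumptions, and the config's validators may require auth or deployment details not shown here:

```python
# Illustrative only; the field names come from the diff above, but everything
# else is an assumption (validators may demand more fields).
from graphrag.config.models.language_model_config import LanguageModelConfig

o_series = LanguageModelConfig(
    type="openai_chat",             # assumed model-type identifier
    model="o1",                     # any OpenAI reasoning model
    max_completion_tokens=16_000,   # caps visible output plus hidden reasoning tokens
    reasoning_effort="medium",      # "low" | "medium" | "high"
    # max_tokens stays at its None default: o-series models reject it
)
```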

graphrag/config/models/local_search_config.py

Lines changed: 2 additions & 18 deletions

@@ -43,23 +43,7 @@ class LocalSearchConfig(BaseModel):
         description="The top k mapped relations.",
         default=graphrag_config_defaults.local_search.top_k_relationships,
     )
-    temperature: float = Field(
-        description="The temperature to use for token generation.",
-        default=graphrag_config_defaults.local_search.temperature,
-    )
-    top_p: float = Field(
-        description="The top-p value to use for token generation.",
-        default=graphrag_config_defaults.local_search.top_p,
-    )
-    n: int = Field(
-        description="The number of completions to generate.",
-        default=graphrag_config_defaults.local_search.n,
-    )
-    max_tokens: int = Field(
+    max_context_tokens: int = Field(
         description="The maximum tokens.",
-        default=graphrag_config_defaults.local_search.max_tokens,
-    )
-    llm_max_tokens: int = Field(
-        description="The LLM maximum tokens.",
-        default=graphrag_config_defaults.local_search.llm_max_tokens,
+        default=graphrag_config_defaults.local_search.max_context_tokens,
     )
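As with basic search, max_context_tokens is a context budget rather than a generation cap. A self-contained sketch of enforcing such a budget with tiktoken (which this commit updates); the trimming loop and encoding choice are illustrative, not graphrag's actual code:

```python
# Hedged sketch: trim context chunks to a token budget with tiktoken.
# The o200k_base encoding choice is an assumption for o-series/gpt-4o models.
import tiktoken

def trim_context(chunks: list[str], max_context_tokens: int = 12_000) -> list[str]:
    """Keep whole chunks, in order, until the token budget is exhausted."""
    enc = tiktoken.get_encoding("o200k_base")
    kept: list[str] = []
    used = 0
    for chunk in chunks:
        n = len(enc.encode(chunk))
        if used + n > max_context_tokens:
            break
        kept.append(chunk)
        used += n
    return kept
```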
