Skip to content

Commit

Permalink
Merge pull request #100 from FalkorDB/fix-gemini-config
Browse files Browse the repository at this point in the history
Fixes for Gemini API Configuration Usage
  • Loading branch information
galshubeli authored Mar 5, 2025
2 parents d109884 + 5be9461 commit 1421dc4
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 41 deletions.
4 changes: 2 additions & 2 deletions graphrag_sdk/models/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ def ask(self, message: str) -> GenerationResponse:
{"role": "system", "content": self.system_instruction},
{"role": "user", "content": message[:14385]},
],
max_tokens=self.generation_config.max_output_tokens,
max_tokens=self.generation_config.max_tokens,
temperature=self.generation_config.temperature,
top_p=self.generation_config.top_p,
top_k=self.generation_config.top_k,
stop=self.generation_config.stop_sequences,
stop=self.generation_config.stop,
)
return self._parse_generate_content_response(response)

Expand Down
42 changes: 27 additions & 15 deletions graphrag_sdk/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,28 @@ def __init__(
system_instruction: Optional[str] = None,
):
"""
Initialize the GoogleGenerativeModel with required parameters.
Initializes the GoogleGenerativeModel with the specified parameters.
Args:
model_name (str): Name of the GoogleAI model.
model_name (str): The name of the GoogleAI model to use.
generation_config (Optional[GoogleGenerationConfig]): Configuration settings for generation.
system_instruction (Optional[str]): System-level instruction for the model.
If not provided, a default instance of `GoogleGenerationConfig` is used.
system_instruction (Optional[str]): An optional system-level instruction to guide the model’s behavior.
Raises:
TypeError: If `generation_config` is provided but is not an instance of `GoogleGenerationConfig`.
"""
if generation_config is not None and not isinstance(generation_config, GoogleGenerationConfig):
raise TypeError(
"generation_config must be an instance of GoogleGenerationConfig "
"(from google.generativeai import GenerationConfig as GoogleGenerationConfig)."
)

self._model_name = model_name
self._generation_config = generation_config
self._generation_config = generation_config or GoogleGenerationConfig()
self._system_instruction = system_instruction

# Configure the API key for Google Generative AI
configure(api_key=os.environ["GOOGLE_API_KEY"])

def start_chat(self, system_instruction: Optional[str] = None) -> GenerativeModelChatSession:
Expand All @@ -54,13 +67,7 @@ def start_chat(self, system_instruction: Optional[str] = None) -> GenerativeMode
self._model = GoogleGenerativeModel(
self._model_name,
generation_config=(
GoogleGenerationConfig(
temperature=self._generation_config.temperature,
top_p=self._generation_config.top_p,
top_k=self._generation_config.top_k,
max_output_tokens=self._generation_config.max_output_tokens,
stop_sequences=self._generation_config.stop_sequences,
)
self._generation_config
if self._generation_config is not None
else None
),
Expand Down Expand Up @@ -102,7 +109,13 @@ def to_json(self) -> dict:
"""
return {
"model_name": self._model_name,
"generation_config": self._generation_config.to_json(),
"generation_config": {
"temperature": self._generation_config.temperature,
"top_p": self._generation_config.top_p,
"max_output_tokens": self._generation_config.max_output_tokens,
"stop_sequences": self._generation_config.stop_sequences,
"response_mime_type": self._generation_config.response_mime_type,
},
"system_instruction": self._system_instruction,
}

Expand All @@ -117,9 +130,7 @@ def from_json(json: dict) -> "GenerativeModel":
"""
return GeminiGenerativeModel(
model_name=json["model_name"],
generation_config=GenerativeModelConfig.from_json(
json["generation_config"]
),
generation_config=GoogleGenerationConfig(**json["generation_config"]),
system_instruction=json["system_instruction"],
)

Expand Down Expand Up @@ -170,6 +181,7 @@ def _adjust_generation_config(self, output_method: OutputMethod) -> dict:
"response_mime_type": "application/json",
"temperature": 0
}

return self._model._generation_config

def delete_last_message(self):
Expand Down
47 changes: 28 additions & 19 deletions graphrag_sdk/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,53 @@ class OutputMethod(Enum):

class GenerativeModelConfig:
"""
Configuration for a generative model
Configuration for a generative model.
This configuration follows OpenAI-style parameter naming but is designed to be compatible with other generative models.
Args:
temperature (Optional[float]): The temperature to use for sampling.
top_p (Optional[float]): The top-p value to use for sampling.
top_k (Optional[int]): The top-k value to use for sampling.
max_output_tokens (Optional[int]): The maximum number of tokens to generate.
stop_sequences (Optional[list[str]]): The stop sequences to use for sampling.
response_format (Optional[dict]): The format of the response.
Examples:
>>> config = GenerativeModelConfig(temperature=0.5, top_p=0.9, top_k=50, max_output_tokens=100, stop_sequences=[".", "?", "!"])
temperature (Optional[float]): Controls the randomness of the output. Higher values (e.g., 1.0) make responses more random,
while lower values (e.g., 0.1) make them more deterministic.
top_p (Optional[float]): Nucleus sampling parameter. A value of 0.9 considers only the top 90% of probability mass.
top_k (Optional[int]): Limits sampling to the top-k most probable tokens.
max_tokens (Optional[int]): The maximum number of tokens the model is allowed to generate in a response.
stop (Optional[list[str]]): A list of stop sequences that signal the model to stop generating further tokens.
response_format (Optional[dict]): Specifies the desired format of the response, if supported by the model.
Example:
>>> config = GenerativeModelConfig(
... temperature=0.5,
... top_p=0.9,
... top_k=50,
... max_tokens=100,
... stop=[".", "?", "!"]
... )
"""

def __init__(
self,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
max_output_tokens: Optional[int] = None,
stop_sequences: Optional[list[str]] = None,
max_tokens: Optional[int] = None,
stop: Optional[list[str]] = None,
response_format: Optional[dict] = None,
):
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.max_output_tokens = max_output_tokens
self.stop_sequences = stop_sequences
self.max_tokens = max_tokens
self.stop = stop
self.response_format = response_format

def __str__(self) -> str:
return f"GenerativeModelConfig(temperature={self.temperature}, top_p={self.top_p}, top_k={self.top_k}, max_output_tokens={self.max_output_tokens}, stop_sequences={self.stop_sequences})"
return f"GenerativeModelConfig(temperature={self.temperature}, top_p={self.top_p}, top_k={self.top_k}, max_tokens={self.max_tokens}, stop={self.stop})"

def to_json(self) -> dict:
return {
"temperature": self.temperature,
"top_p": self.top_p,
"max_tokens": self.max_output_tokens,
"stop": self.stop_sequences,
"max_tokens": self.max_tokens,
"stop": self.stop,
"response_format": self.response_format,
}

Expand All @@ -61,8 +70,8 @@ def from_json(json: dict) -> "GenerativeModelConfig":
temperature=json.get("temperature"),
top_p=json.get("top_p"),
top_k=json.get("top_k"),
max_output_tokens=json.get("max_tokens"),
stop_sequences=json.get("stop"),
max_tokens=json.get("max_tokens"),
stop=json.get("stop"),
)


Expand Down
3 changes: 2 additions & 1 deletion tests/test_kg_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from graphrag_sdk.attribute import Attribute, AttributeType
from graphrag_sdk.models.gemini import GeminiGenerativeModel
from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig
from google.generativeai import GenerationConfig as GoogleGenerationConfig


load_dotenv()
Expand Down Expand Up @@ -71,7 +72,7 @@ def setUpClass(cls):

cls.graph_name = "IMDB_gemini"

model = GeminiGenerativeModel(model_name="gemini-1.5-flash-001")
model = GeminiGenerativeModel(model_name="gemini-1.5-flash-001", generation_config=GoogleGenerationConfig(temperature=0))
cls.kg = KnowledgeGraph(
name=cls.graph_name,
ontology=cls.ontology,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_kg_litellm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from graphrag_sdk.ontology import Ontology
from graphrag_sdk.attribute import Attribute, AttributeType
from graphrag_sdk.models.litellm import LiteModel
from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig
from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig, GenerativeModelConfig

load_dotenv()

Expand Down Expand Up @@ -68,7 +68,7 @@ def setUpClass(cls):
)
)
cls.graph_name = "IMDB_openai"
model = LiteModel(model_name="gpt-4o")
model = LiteModel(model_name="gpt-4o", generation_config=GenerativeModelConfig(temperature=0))
cls.kg = KnowledgeGraph(
name=cls.graph_name,
ontology=cls.ontology,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_kg_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from graphrag_sdk.ontology import Ontology
from graphrag_sdk.attribute import Attribute, AttributeType
from graphrag_sdk.models.openai import OpenAiGenerativeModel
from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig
from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig, GenerativeModelConfig

load_dotenv()

Expand Down Expand Up @@ -68,7 +68,7 @@ def setUpClass(cls):
)
)
cls.graph_name = "IMDB_openai"
model = OpenAiGenerativeModel(model_name="gpt-3.5-turbo-0125")
model = OpenAiGenerativeModel(model_name="gpt-3.5-turbo-0125", generation_config=GenerativeModelConfig(temperature=0))
cls.kg = KnowledgeGraph(
name=cls.graph_name,
ontology=cls.ontology,
Expand Down

0 comments on commit 1421dc4

Please sign in to comment.