Skip to content

Commit

Permalink
update-config-usage
Browse files Browse the repository at this point in the history
  • Loading branch information
galshubeli committed Mar 5, 2025
1 parent 5189050 commit 5be9461
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 24 deletions.
49 changes: 27 additions & 22 deletions graphrag_sdk/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,28 @@ def __init__(
system_instruction: Optional[str] = None,
):
"""
Initialize the GoogleGenerativeModel with required parameters.
Initializes the GoogleGenerativeModel with the specified parameters.
Args:
model_name (str): Name of the GoogleAI model.
model_name (str): The name of the GoogleAI model to use.
generation_config (Optional[GoogleGenerationConfig]): Configuration settings for generation.
system_instruction (Optional[str]): System-level instruction for the model.
If not provided, a default instance of `GoogleGenerationConfig` is used.
system_instruction (Optional[str]): An optional system-level instruction to guide the model’s behavior.
Raises:
TypeError: If `generation_config` is provided but is not an instance of `GoogleGenerationConfig`.
"""
if generation_config is not None and not isinstance(generation_config, GoogleGenerationConfig):
raise TypeError(
"generation_config must be an instance of GoogleGenerationConfig "
"(from google.generativeai import GenerationConfig as GoogleGenerationConfig)."
)

self._model_name = model_name
self._generation_config = generation_config or GenerativeModelConfig()
self._generation_config = generation_config or GoogleGenerationConfig()
self._system_instruction = system_instruction

# Configure the API key for Google Generative AI
configure(api_key=os.environ["GOOGLE_API_KEY"])

def start_chat(self, system_instruction: Optional[str] = None) -> GenerativeModelChatSession:
Expand All @@ -54,13 +67,7 @@ def start_chat(self, system_instruction: Optional[str] = None) -> GenerativeMode
self._model = GoogleGenerativeModel(
self._model_name,
generation_config=(
GoogleGenerationConfig(
temperature=self._generation_config.temperature,
top_p=self._generation_config.top_p,
top_k=self._generation_config.top_k,
max_output_tokens=self._generation_config.max_tokens,
stop_sequences=self._generation_config.stop,
)
self._generation_config
if self._generation_config is not None
else None
),
Expand Down Expand Up @@ -102,7 +109,13 @@ def to_json(self) -> dict:
"""
return {
"model_name": self._model_name,
"generation_config": self._generation_config.to_json(),
"generation_config": {
"temperature": self._generation_config.temperature,
"top_p": self._generation_config.top_p,
"max_output_tokens": self._generation_config.max_output_tokens,
"stop_sequences": self._generation_config.stop_sequences,
"response_mime_type": self._generation_config.response_mime_type,
},
"system_instruction": self._system_instruction,
}

Expand All @@ -117,9 +130,7 @@ def from_json(json: dict) -> "GenerativeModel":
"""
return GeminiGenerativeModel(
model_name=json["model_name"],
generation_config=GenerativeModelConfig.from_json(
json["generation_config"]
),
generation_config=GoogleGenerationConfig(**json["generation_config"]),
system_instruction=json["system_instruction"],
)

Expand Down Expand Up @@ -171,13 +182,7 @@ def _adjust_generation_config(self, output_method: OutputMethod) -> dict:
"temperature": 0
}

config = self._model._generation_config.to_json()

# Convert OpenAI-style config keys to GoogleAI-style keys
config["max_output_tokens"] = config.pop("max_tokens")
config["response_mime_type"] = config.pop("response_format")
config["stop_sequences"] = config.pop("stop")
return config
return self._model._generation_config

def delete_last_message(self):
"""
Expand Down
5 changes: 3 additions & 2 deletions tests/test_kg_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from graphrag_sdk.relation import Relation
from graphrag_sdk.attribute import Attribute, AttributeType
from graphrag_sdk.models.gemini import GeminiGenerativeModel
from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig, GenerativeModelConfig
from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig
from google.generativeai import GenerationConfig as GoogleGenerationConfig


load_dotenv()
Expand Down Expand Up @@ -71,7 +72,7 @@ def setUpClass(cls):

cls.graph_name = "IMDB_gemini"

model = GeminiGenerativeModel(model_name="gemini-1.5-flash-001", generation_config=GenerativeModelConfig(temperature=0))
model = GeminiGenerativeModel(model_name="gemini-1.5-flash-001", generation_config=GoogleGenerationConfig(temperature=0))
cls.kg = KnowledgeGraph(
name=cls.graph_name,
ontology=cls.ontology,
Expand Down

0 comments on commit 5be9461

Please sign in to comment.