Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
93 changes: 72 additions & 21 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pickle
from typing import Any
from typing import Optional
from typing import overload
import uuid

from google.genai import types
Expand Down Expand Up @@ -413,36 +414,86 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
class DatabaseSessionService(BaseSessionService):
"""A session service that uses a database for storage."""

def __init__(self, db_url: str, **kwargs: Any):
"""Initializes the database session service with a database URL."""
# 1. Create DB engine for db connection
# 2. Create all tables based on schema
# 3. Initialize all properties
try:
db_engine = create_async_engine(db_url, **kwargs)
@overload
def __init__(
self,
db_url: str,
**kwargs: Any,
) -> None:
"""Initializes the database session service with a database URL.

if db_engine.dialect.name == "sqlite":
# Set sqlite pragma to enable foreign keys constraints
event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma)
Args:
db_url: Database URL string for creating a new engine.
**kwargs: Additional keyword arguments passed to create_async_engine.
"""

except Exception as e:
if isinstance(e, ArgumentError):
raise ValueError(
f"Invalid database URL format or argument '{db_url}'."
) from e
if isinstance(e, ImportError):
@overload
def __init__(
self,
*,
db_engine: AsyncEngine,
) -> None:
"""Initializes the database session service with an existing SQLAlchemy AsyncEngine.

Args:
db_engine: Existing SQLAlchemy AsyncEngine instance to use.
"""

def __init__(
self,
db_url: Optional[str] = None,
db_engine: Optional[AsyncEngine] = None,
**kwargs: Any,
):
"""Initializes the database session service.

Args:
db_url: Database URL string for creating a new engine. Mutually exclusive
with db_engine.
db_engine: Existing AsyncEngine instance. Mutually exclusive with db_url.
**kwargs: Additional keyword arguments passed to create_async_engine when
db_url is provided. Ignored when db_engine is provided.

Raises:
ValueError: If neither or both db_url and db_engine are provided, or if
engine creation fails.
"""
if (db_url is None) == (db_engine is None):
raise ValueError(
"Exactly one of 'db_url' or 'db_engine' must be provided."
)

# 1. Create or use provided DB engine for db connection
# 2. Create all tables based on schema
# 3. Initialize all properties
if db_engine is not None:
engine = db_engine
else:
try:
engine = create_async_engine(db_url, **kwargs)

if engine.dialect.name == "sqlite":
# Set sqlite pragma to enable foreign keys constraints
event.listen(engine.sync_engine, "connect", set_sqlite_pragma)

except Exception as e:
if isinstance(e, ArgumentError):
raise ValueError(
f"Invalid database URL format or argument '{db_url}'."
) from e
if isinstance(e, ImportError):
raise ValueError(
f"Database related module not found for URL '{db_url}'."
) from e
raise ValueError(
f"Database related module not found for URL '{db_url}'."
f"Failed to create database engine for URL '{db_url}'"
) from e
raise ValueError(
f"Failed to create database engine for URL '{db_url}'"
) from e

# Get the local timezone
local_timezone = get_localzone()
logger.info("Local timezone: %s", local_timezone)

self.db_engine: AsyncEngine = db_engine
self.db_engine: AsyncEngine = engine
self.metadata: MetaData = MetaData()

# DB session factory method
Expand Down
Loading