Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@
"namespace": {
"type": "string",
"nullable": true,
"default": "lightspeed-stack",
"default": "public",
"description": "Database namespace",
"title": "Name space"
},
Expand Down
2 changes: 1 addition & 1 deletion src/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class PostgreSQLDatabaseConfiguration(ConfigurationBase):
)

namespace: Optional[str] = Field(
"lightspeed-stack",
"public",
title="Name space",
description="Database namespace",
)
Expand Down
5 changes: 5 additions & 0 deletions src/quota/connect_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def connect_pg(config: PostgreSQLDatabaseConfiguration) -> Any:
psycopg2.Error: If establishing the database connection fails.
"""
logger.info("Connecting to PostgreSQL storage")
namespace = "public"
if config.namespace is not None:
namespace = config.namespace

try:
connection = psycopg2.connect(
host=config.host,
Expand All @@ -35,6 +39,7 @@ def connect_pg(config: PostgreSQLDatabaseConfiguration) -> Any:
sslmode=config.ssl_mode,
# sslrootcert=config.ca_cert_path,
gssencmode=config.gss_encmode,
options=f"-c search_path={namespace}",
)
if connection is not None:
connection.autocommit = True
Expand Down
8 changes: 6 additions & 2 deletions src/quota/revokable_quota_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from quota.quota_exceed_error import QuotaExceedError
from quota.quota_limiter import QuotaLimiter
from quota.sql import (
CREATE_QUOTA_TABLE,
CREATE_QUOTA_TABLE_PG,
CREATE_QUOTA_TABLE_SQLITE,
UPDATE_AVAILABLE_QUOTA_PG,
UPDATE_AVAILABLE_QUOTA_SQLITE,
SELECT_QUOTA_PG,
Expand Down Expand Up @@ -185,7 +186,10 @@ def _initialize_tables(self) -> None:
"""Initialize tables used by quota limiter."""
logger.info("Initializing tables for quota limiter")
cursor = self.connection.cursor()
cursor.execute(CREATE_QUOTA_TABLE)
if self.sqlite_connection_config is not None:
cursor.execute(CREATE_QUOTA_TABLE_SQLITE)
elif self.postgres_connection_config is not None:
cursor.execute(CREATE_QUOTA_TABLE_PG)
cursor.close()
self.connection.commit()

Expand Down
15 changes: 14 additions & 1 deletion src/quota/sql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
"""SQL commands used by quota management package."""

CREATE_QUOTA_TABLE = """
CREATE_QUOTA_TABLE_PG = """
CREATE TABLE IF NOT EXISTS quota_limits (
id text NOT NULL,
subject char(1) NOT NULL,
quota_limit int NOT NULL,
available int,
updated_at timestamp with time zone,
revoked_at timestamp with time zone,
PRIMARY KEY(id, subject)
);
"""


CREATE_QUOTA_TABLE_SQLITE = """
CREATE TABLE IF NOT EXISTS quota_limits (
id text NOT NULL,
subject char(1) NOT NULL,
Expand Down
20 changes: 15 additions & 5 deletions src/runners/quota_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""User and cluster quota scheduler runner."""

from typing import Any
from typing import Any, Optional
from threading import Thread
from time import sleep

Expand All @@ -17,7 +17,8 @@
from quota.connect_sqlite import connect_sqlite

from quota.sql import (
CREATE_QUOTA_TABLE,
CREATE_QUOTA_TABLE_PG,
CREATE_QUOTA_TABLE_SQLITE,
INCREASE_QUOTA_STATEMENT_PG,
INCREASE_QUOTA_STATEMENT_SQLITE,
RESET_QUOTA_STATEMENT_PG,
Expand Down Expand Up @@ -59,7 +60,15 @@ def quota_scheduler(config: QuotaHandlersConfiguration) -> bool:
logger.warning("Can not connect to database, skipping")
return False

init_tables(connection)
create_quota_table: Optional[str] = None
if config.postgres is not None:
create_quota_table = CREATE_QUOTA_TABLE_PG
elif config.sqlite is not None:
create_quota_table = CREATE_QUOTA_TABLE_SQLITE

if create_quota_table is not None:
init_tables(connection, create_quota_table)

period = config.scheduler.period

increase_quota_statement = get_increase_quota_statement(config)
Expand Down Expand Up @@ -296,17 +305,18 @@ def connect(config: QuotaHandlersConfiguration) -> Any:
return None


def init_tables(connection: Any) -> None:
def init_tables(connection: Any, create_quota_table: str) -> None:
"""
Create the quota table required by the quota limiter on the provided database connection.

Parameters:
connection (Any): A DB-API compatible connection on which the quota
table(s) will be created; changes are committed before returning.
create_quota_table (str): Command used to create table with quota.
"""
logger.info("Initializing tables for quota limiter")
cursor = connection.cursor()
cursor.execute(CREATE_QUOTA_TABLE)
cursor.execute(create_quota_table)
cursor.close()
connection.commit()

Expand Down
173 changes: 170 additions & 3 deletions tests/unit/models/config/test_dump_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_dump_configuration(tmp_path: Path) -> None:
"password": "**********",
"ssl_mode": "require",
"gss_encmode": "disable",
"namespace": "lightspeed-stack",
"namespace": "public",
"ca_cert_path": None,
},
},
Expand Down Expand Up @@ -467,7 +467,7 @@ def test_dump_configuration_with_quota_limiters(tmp_path: Path) -> None:
"password": "**********",
"ssl_mode": "require",
"gss_encmode": "disable",
"namespace": "lightspeed-stack",
"namespace": "public",
"ca_cert_path": None,
},
},
Expand Down Expand Up @@ -542,6 +542,7 @@ def test_dump_configuration_byok(tmp_path: Path) -> None:
ca_cert_path=None,
ssl_mode="require",
gss_encmode="disable",
namespace="foo",
),
),
mcp_servers=[],
Expand Down Expand Up @@ -652,7 +653,7 @@ def test_dump_configuration_byok(tmp_path: Path) -> None:
"password": "**********",
"ssl_mode": "require",
"gss_encmode": "disable",
"namespace": "lightspeed-stack",
"namespace": "foo",
"ca_cert_path": None,
},
},
Expand Down Expand Up @@ -681,3 +682,169 @@ def test_dump_configuration_byok(tmp_path: Path) -> None:
"enable_token_history": False,
},
}


def test_dump_configuration_pg_namespace(tmp_path: Path) -> None:
"""
Test that the Configuration object can be serialized to a JSON file and
that the resulting file contains all expected sections and values.

Please note that redaction process is not in place.
"""
cfg = Configuration(
name="test_name",
service=ServiceConfiguration(
tls_config=TLSConfiguration(
tls_certificate_path=Path("tests/configuration/server.crt"),
tls_key_path=Path("tests/configuration/server.key"),
tls_key_password=Path("tests/configuration/password"),
),
cors=CORSConfiguration(
allow_origins=["foo_origin", "bar_origin", "baz_origin"],
allow_credentials=False,
allow_methods=["foo_method", "bar_method", "baz_method"],
allow_headers=["foo_header", "bar_header", "baz_header"],
),
),
llama_stack=LlamaStackConfiguration(
use_as_library_client=True,
library_client_config_path="tests/configuration/run.yaml",
api_key=SecretStr("whatever"),
),
user_data_collection=UserDataCollection(
feedback_enabled=False, feedback_storage=None
),
database=DatabaseConfiguration(
sqlite=None,
postgres=PostgreSQLDatabaseConfiguration(
db="lightspeed_stack",
user="ls_user",
password=SecretStr("ls_password"),
port=5432,
ca_cert_path=None,
ssl_mode="require",
gss_encmode="disable",
namespace="foo",
),
),
mcp_servers=[],
customization=None,
inference=InferenceConfiguration(
default_provider="default_provider",
default_model="default_model",
),
)
assert cfg is not None
dump_file = tmp_path / "test.json"
cfg.dump(dump_file)

with open(dump_file, "r", encoding="utf-8") as fin:
content = json.load(fin)
# content should be loaded
assert content is not None

# all sections must exists
assert "name" in content
assert "service" in content
assert "llama_stack" in content
assert "user_data_collection" in content
assert "mcp_servers" in content
assert "authentication" in content
assert "authorization" in content
assert "customization" in content
assert "inference" in content
assert "database" in content
assert "byok_rag" in content
assert "quota_handlers" in content

# check the whole deserialized JSON file content
assert content == {
"name": "test_name",
"service": {
"host": "localhost",
"port": 8080,
"auth_enabled": False,
"workers": 1,
"color_log": True,
"access_log": True,
"tls_config": {
"tls_certificate_path": "tests/configuration/server.crt",
"tls_key_password": "tests/configuration/password",
"tls_key_path": "tests/configuration/server.key",
},
"cors": {
"allow_credentials": False,
"allow_headers": [
"foo_header",
"bar_header",
"baz_header",
],
"allow_methods": [
"foo_method",
"bar_method",
"baz_method",
],
"allow_origins": [
"foo_origin",
"bar_origin",
"baz_origin",
],
},
},
"llama_stack": {
"url": None,
"use_as_library_client": True,
"api_key": "**********",
"library_client_config_path": "tests/configuration/run.yaml",
},
"user_data_collection": {
"feedback_enabled": False,
"feedback_storage": None,
"transcripts_enabled": False,
"transcripts_storage": None,
},
"mcp_servers": [],
"authentication": {
"module": "noop",
"skip_tls_verification": False,
"k8s_ca_cert_path": None,
"k8s_cluster_api": None,
"jwk_config": None,
"api_key_config": None,
"rh_identity_config": None,
},
"customization": None,
"inference": {
"default_provider": "default_provider",
"default_model": "default_model",
},
"database": {
"sqlite": None,
"postgres": {
"host": "localhost",
"port": 5432,
"db": "lightspeed_stack",
"user": "ls_user",
"password": "**********",
"ssl_mode": "require",
"gss_encmode": "disable",
"ca_cert_path": None,
"namespace": "foo",
},
},
"authorization": None,
"conversation_cache": {
"memory": None,
"postgres": None,
"sqlite": None,
"type": None,
},
"byok_rag": [],
"quota_handlers": {
"sqlite": None,
"postgres": None,
"limiters": [],
"scheduler": {"period": 1},
"enable_token_history": False,
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,25 @@ def test_postgresql_database_configuration() -> None:
assert c.password.get_secret_value() == "password"
assert c.ssl_mode == POSTGRES_DEFAULT_SSL_MODE
assert c.gss_encmode == POSTGRES_DEFAULT_GSS_ENCMODE
assert c.namespace == "lightspeed-stack"
assert c.namespace == "public"
assert c.ca_cert_path is None


def test_postgresql_database_configuration_namespace_specification() -> None:
"""Test the PostgreSQLDatabaseConfiguration model."""
# pylint: disable=no-member
c = PostgreSQLDatabaseConfiguration(
db="db", user="user", password="password", namespace="foo"
)
assert c is not None
assert c.host == "localhost"
assert c.port == 5432
assert c.db == "db"
assert c.user == "user"
assert c.password.get_secret_value() == "password"
assert c.ssl_mode == POSTGRES_DEFAULT_SSL_MODE
assert c.gss_encmode == POSTGRES_DEFAULT_GSS_ENCMODE
assert c.namespace == "foo"
assert c.ca_cert_path is None


Expand Down
11 changes: 9 additions & 2 deletions tests/unit/quota/test_connect_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ def test_connect_pg_when_connection_established(mocker: MockerFixture) -> None:
"""Test the connection to PostgreSQL database."""
# any correct PostgreSQL configuration can be used
configuration = PostgreSQLDatabaseConfiguration(
db="db", user="user", password="password"
db="db",
user="user",
password="password",
namespace="foo",
)

# do not use connection to real PostgreSQL instance
Expand All @@ -28,7 +31,11 @@ def test_connect_pg_when_connection_error(mocker: MockerFixture) -> None:
"""Test the connection to PostgreSQL database."""
# any correct PostgreSQL configuration can be used
configuration = PostgreSQLDatabaseConfiguration(
host="foo", db="db", user="user", password="password"
host="foo",
db="db",
user="user",
password="password",
namespace="foo",
)

# do not use connection to real PostgreSQL instance
Expand Down
Loading
Loading