diff --git a/docs/config.json b/docs/config.json index 5ed3d0543..adada0e5f 100644 --- a/docs/config.json +++ b/docs/config.json @@ -726,7 +726,7 @@ "namespace": { "type": "string", "nullable": true, - "default": "lightspeed-stack", + "default": "public", "description": "Database namespace", "title": "Name space" }, diff --git a/src/models/config.py b/src/models/config.py index 23c23a4fe..ace465bde 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -218,7 +218,7 @@ class PostgreSQLDatabaseConfiguration(ConfigurationBase): ) namespace: Optional[str] = Field( - "lightspeed-stack", + "public", title="Name space", description="Database namespace", ) diff --git a/src/quota/connect_pg.py b/src/quota/connect_pg.py index 7e5eb96e2..fbbf95109 100644 --- a/src/quota/connect_pg.py +++ b/src/quota/connect_pg.py @@ -25,6 +25,10 @@ def connect_pg(config: PostgreSQLDatabaseConfiguration) -> Any: psycopg2.Error: If establishing the database connection fails. """ logger.info("Connecting to PostgreSQL storage") + namespace = "public" + if config.namespace is not None: + namespace = config.namespace + try: connection = psycopg2.connect( host=config.host, @@ -35,6 +39,7 @@ def connect_pg(config: PostgreSQLDatabaseConfiguration) -> Any: sslmode=config.ssl_mode, # sslrootcert=config.ca_cert_path, gssencmode=config.gss_encmode, + options=f"-c search_path={namespace}", ) if connection is not None: connection.autocommit = True diff --git a/src/quota/revokable_quota_limiter.py b/src/quota/revokable_quota_limiter.py index 42065b3ea..49340e2cd 100644 --- a/src/quota/revokable_quota_limiter.py +++ b/src/quota/revokable_quota_limiter.py @@ -8,7 +8,8 @@ from quota.quota_exceed_error import QuotaExceedError from quota.quota_limiter import QuotaLimiter from quota.sql import ( - CREATE_QUOTA_TABLE, + CREATE_QUOTA_TABLE_PG, + CREATE_QUOTA_TABLE_SQLITE, UPDATE_AVAILABLE_QUOTA_PG, UPDATE_AVAILABLE_QUOTA_SQLITE, SELECT_QUOTA_PG, @@ -185,7 +186,10 @@ def _initialize_tables(self) -> None: """Initialize tables used by quota limiter.""" logger.info("Initializing tables for quota limiter") cursor = self.connection.cursor() - cursor.execute(CREATE_QUOTA_TABLE) + if self.sqlite_connection_config is not None: + cursor.execute(CREATE_QUOTA_TABLE_SQLITE) + elif self.postgres_connection_config is not None: + cursor.execute(CREATE_QUOTA_TABLE_PG) cursor.close() self.connection.commit() diff --git a/src/quota/sql.py b/src/quota/sql.py index 93a970999..0f30e98f8 100644 --- a/src/quota/sql.py +++ b/src/quota/sql.py @@ -1,6 +1,19 @@ """SQL commands used by quota management package.""" -CREATE_QUOTA_TABLE = """ +CREATE_QUOTA_TABLE_PG = """ + CREATE TABLE IF NOT EXISTS quota_limits ( + id text NOT NULL, + subject char(1) NOT NULL, + quota_limit int NOT NULL, + available int, + updated_at timestamp with time zone, + revoked_at timestamp with time zone, + PRIMARY KEY(id, subject) + ); + """ + + +CREATE_QUOTA_TABLE_SQLITE = """ CREATE TABLE IF NOT EXISTS quota_limits ( id text NOT NULL, subject char(1) NOT NULL, diff --git a/src/runners/quota_scheduler.py b/src/runners/quota_scheduler.py index dcf2ba039..3d9fe7903 100644 --- a/src/runners/quota_scheduler.py +++ b/src/runners/quota_scheduler.py @@ -1,6 +1,6 @@ """User and cluster quota scheduler runner.""" -from typing import Any +from typing import Any, Optional from threading import Thread from time import sleep @@ -17,7 +17,8 @@ from quota.connect_sqlite import connect_sqlite from quota.sql import ( - CREATE_QUOTA_TABLE, + CREATE_QUOTA_TABLE_PG, + CREATE_QUOTA_TABLE_SQLITE, INCREASE_QUOTA_STATEMENT_PG, INCREASE_QUOTA_STATEMENT_SQLITE, RESET_QUOTA_STATEMENT_PG, @@ -59,7 +60,15 @@ def quota_scheduler(config: QuotaHandlersConfiguration) -> bool: logger.warning("Can not connect to database, skipping") return False - init_tables(connection) + create_quota_table: Optional[str] = None + if config.postgres is not None: + create_quota_table = CREATE_QUOTA_TABLE_PG + elif config.sqlite is not None: + create_quota_table = CREATE_QUOTA_TABLE_SQLITE + + if create_quota_table is not None: + init_tables(connection, create_quota_table) + period = config.scheduler.period increase_quota_statement = get_increase_quota_statement(config) @@ -296,17 +305,18 @@ def connect(config: QuotaHandlersConfiguration) -> Any: return None -def init_tables(connection: Any) -> None: +def init_tables(connection: Any, create_quota_table: str) -> None: """ Create the quota table required by the quota limiter on the provided database connection. Parameters: connection (Any): A DB-API compatible connection on which the quota table(s) will be created; changes are committed before returning. + create_quota_table (str): Command used to create table with quota. """ logger.info("Initializing tables for quota limiter") cursor = connection.cursor() - cursor.execute(CREATE_QUOTA_TABLE) + cursor.execute(create_quota_table) cursor.close() connection.commit() diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py index 38177a8a7..a1b3353a4 100644 --- a/tests/unit/models/config/test_dump_configuration.py +++ b/tests/unit/models/config/test_dump_configuration.py @@ -167,7 +167,7 @@ def test_dump_configuration(tmp_path: Path) -> None: "password": "**********", "ssl_mode": "require", "gss_encmode": "disable", - "namespace": "lightspeed-stack", + "namespace": "public", "ca_cert_path": None, }, }, @@ -467,7 +467,7 @@ def test_dump_configuration_with_quota_limiters(tmp_path: Path) -> None: "password": "**********", "ssl_mode": "require", "gss_encmode": "disable", - "namespace": "lightspeed-stack", + "namespace": "public", "ca_cert_path": None, }, }, @@ -542,6 +542,7 @@ def test_dump_configuration_byok(tmp_path: Path) -> None: ca_cert_path=None, ssl_mode="require", gss_encmode="disable", + namespace="foo", ), ), mcp_servers=[], @@ -652,7 +653,7 @@ def test_dump_configuration_byok(tmp_path: Path) -> None: "password": "**********", "ssl_mode": "require", "gss_encmode": "disable", - "namespace": "lightspeed-stack", + "namespace": "foo", "ca_cert_path": None, }, }, @@ -681,3 +682,169 @@ def test_dump_configuration_byok(tmp_path: Path) -> None: "enable_token_history": False, }, } + + +def test_dump_configuration_pg_namespace(tmp_path: Path) -> None: + """ + Test that the Configuration object can be serialized to a JSON file and + that the resulting file contains all expected sections and values. + + Please note that redaction process is not in place. + """ + cfg = Configuration( + name="test_name", + service=ServiceConfiguration( + tls_config=TLSConfiguration( + tls_certificate_path=Path("tests/configuration/server.crt"), + tls_key_path=Path("tests/configuration/server.key"), + tls_key_password=Path("tests/configuration/password"), + ), + cors=CORSConfiguration( + allow_origins=["foo_origin", "bar_origin", "baz_origin"], + allow_credentials=False, + allow_methods=["foo_method", "bar_method", "baz_method"], + allow_headers=["foo_header", "bar_header", "baz_header"], + ), + ), + llama_stack=LlamaStackConfiguration( + use_as_library_client=True, + library_client_config_path="tests/configuration/run.yaml", + api_key=SecretStr("whatever"), + ), + user_data_collection=UserDataCollection( + feedback_enabled=False, feedback_storage=None + ), + database=DatabaseConfiguration( + sqlite=None, + postgres=PostgreSQLDatabaseConfiguration( + db="lightspeed_stack", + user="ls_user", + password=SecretStr("ls_password"), + port=5432, + ca_cert_path=None, + ssl_mode="require", + gss_encmode="disable", + namespace="foo", + ), + ), + mcp_servers=[], + customization=None, + inference=InferenceConfiguration( + default_provider="default_provider", + default_model="default_model", + ), + ) + assert cfg is not None + dump_file = tmp_path / "test.json" + cfg.dump(dump_file) + + with open(dump_file, "r", encoding="utf-8") as fin: + content = json.load(fin) + # content should be loaded + assert content is not None + + # all sections must exists + assert "name" in content + assert "service" in content + assert "llama_stack" in content + assert "user_data_collection" in content + assert "mcp_servers" in content + assert "authentication" in content + assert "authorization" in content + assert "customization" in content + assert "inference" in content + assert "database" in content + assert "byok_rag" in content + assert "quota_handlers" in content + + # check the whole deserialized JSON file content + assert content == { + "name": "test_name", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + "tls_config": { + "tls_certificate_path": "tests/configuration/server.crt", + "tls_key_password": "tests/configuration/password", + "tls_key_path": "tests/configuration/server.key", + }, + "cors": { + "allow_credentials": False, + "allow_headers": [ + "foo_header", + "bar_header", + "baz_header", + ], + "allow_methods": [ + "foo_method", + "bar_method", + "baz_method", + ], + "allow_origins": [ + "foo_origin", + "bar_origin", + "baz_origin", + ], + }, + }, + "llama_stack": { + "url": None, + "use_as_library_client": True, + "api_key": "**********", + "library_client_config_path": "tests/configuration/run.yaml", + }, + "user_data_collection": { + "feedback_enabled": False, + "feedback_storage": None, + "transcripts_enabled": False, + "transcripts_storage": None, + }, + "mcp_servers": [], + "authentication": { + "module": "noop", + "skip_tls_verification": False, + "k8s_ca_cert_path": None, + "k8s_cluster_api": None, + "jwk_config": None, + "api_key_config": None, + "rh_identity_config": None, + }, + "customization": None, + "inference": { + "default_provider": "default_provider", + "default_model": "default_model", + }, + "database": { + "sqlite": None, + "postgres": { + "host": "localhost", + "port": 5432, + "db": "lightspeed_stack", + "user": "ls_user", + "password": "**********", + "ssl_mode": "require", + "gss_encmode": "disable", + "ca_cert_path": None, + "namespace": "foo", + }, + }, + "authorization": None, + "conversation_cache": { + "memory": None, + "postgres": None, + "sqlite": None, + "type": None, + }, + "byok_rag": [], + "quota_handlers": { + "sqlite": None, + "postgres": None, + "limiters": [], + "scheduler": {"period": 1}, + "enable_token_history": False, + }, + } diff --git a/tests/unit/models/config/test_postgresql_database_configuration.py b/tests/unit/models/config/test_postgresql_database_configuration.py index 556790be8..b25a7405e 100644 --- a/tests/unit/models/config/test_postgresql_database_configuration.py +++ b/tests/unit/models/config/test_postgresql_database_configuration.py @@ -27,7 +27,25 @@ def test_postgresql_database_configuration() -> None: assert c.password.get_secret_value() == "password" assert c.ssl_mode == POSTGRES_DEFAULT_SSL_MODE assert c.gss_encmode == POSTGRES_DEFAULT_GSS_ENCMODE - assert c.namespace == "lightspeed-stack" + assert c.namespace == "public" + assert c.ca_cert_path is None + + +def test_postgresql_database_configuration_namespace_specification() -> None: + """Test the PostgreSQLDatabaseConfiguration model.""" + # pylint: disable=no-member + c = PostgreSQLDatabaseConfiguration( + db="db", user="user", password="password", namespace="foo" + ) + assert c is not None + assert c.host == "localhost" + assert c.port == 5432 + assert c.db == "db" + assert c.user == "user" + assert c.password.get_secret_value() == "password" + assert c.ssl_mode == POSTGRES_DEFAULT_SSL_MODE + assert c.gss_encmode == POSTGRES_DEFAULT_GSS_ENCMODE + assert c.namespace == "foo" assert c.ca_cert_path is None diff --git a/tests/unit/quota/test_connect_pg.py b/tests/unit/quota/test_connect_pg.py index 114499f01..26c8c5d32 100644 --- a/tests/unit/quota/test_connect_pg.py +++ b/tests/unit/quota/test_connect_pg.py @@ -13,7 +13,10 @@ def test_connect_pg_when_connection_established(mocker: MockerFixture) -> None: """Test the connection to PostgreSQL database.""" # any correct PostgreSQL configuration can be used configuration = PostgreSQLDatabaseConfiguration( - db="db", user="user", password="password" + db="db", + user="user", + password="password", + namespace="foo", ) # do not use connection to real PostgreSQL instance @@ -28,7 +31,11 @@ def test_connect_pg_when_connection_error(mocker: MockerFixture) -> None: """Test the connection to PostgreSQL database.""" # any correct PostgreSQL configuration can be used configuration = PostgreSQLDatabaseConfiguration( - host="foo", db="db", user="user", password="password" + host="foo", + db="db", + user="user", + password="password", + namespace="foo", ) # do not use connection to real PostgreSQL instance diff --git a/tests/unit/quota/test_quota_limiter_factory.py b/tests/unit/quota/test_quota_limiter_factory.py index bd72a27be..0276b7394 100644 --- a/tests/unit/quota/test_quota_limiter_factory.py +++ b/tests/unit/quota/test_quota_limiter_factory.py @@ -28,7 +28,10 @@ def test_quota_limiters_no_limiters_pg_storage() -> None: """Test the quota limiters creating when no limiters are specified.""" configuration = QuotaHandlersConfiguration() configuration.postgres = PostgreSQLDatabaseConfiguration( - db="test", user="user", password="password" + db="test", + user="user", + password="password", + namespace="foo", ) configuration.limiters = None limiters = QuotaLimiterFactory.quota_limiters(configuration) @@ -50,7 +53,10 @@ def test_quota_limiters_empty_limiters_pg_storage() -> None: """Test the quota limiters creating when no limiters are specified.""" configuration = QuotaHandlersConfiguration() configuration.postgres = PostgreSQLDatabaseConfiguration( - db="test", user="user", password="password" + db="test", + user="user", + password="password", + namespace="foo", ) configuration.limiters = [] limiters = QuotaLimiterFactory.quota_limiters(configuration) @@ -74,7 +80,10 @@ def test_quota_limiters_user_quota_limiter_postgres_storage( """Test the quota limiters creating when one limiter is specified.""" configuration = QuotaHandlersConfiguration() configuration.postgres = PostgreSQLDatabaseConfiguration( - db="test", user="user", password="password" + db="test", + user="user", + password="password", + namespace="foo", ) configuration.limiters = [ QuotaLimiterConfiguration( @@ -118,7 +127,10 @@ def test_quota_limiters_cluster_quota_limiter_postgres_storage( """Test the quota limiters creating when one limiter is specified.""" configuration = QuotaHandlersConfiguration() configuration.postgres = PostgreSQLDatabaseConfiguration( - db="test", user="user", password="password" + db="test", + user="user", + password="password", + namespace="foo", ) configuration.limiters = [ QuotaLimiterConfiguration( @@ -160,7 +172,10 @@ def test_quota_limiters_two_limiters(mocker: MockerFixture) -> None: """Test the quota limiters creating when two limiters are specified.""" configuration = QuotaHandlersConfiguration() configuration.postgres = PostgreSQLDatabaseConfiguration( - db="test", user="user", password="password" + db="test", + user="user", + password="password", + namespace="foo", ) configuration.limiters = [ QuotaLimiterConfiguration( @@ -190,7 +205,10 @@ def test_quota_limiters_invalid_limiter_type(mocker: MockerFixture) -> None: """Test the quota limiters creating when invalid limiter type is specified.""" configuration = QuotaHandlersConfiguration() configuration.postgres = PostgreSQLDatabaseConfiguration( - db="test", user="user", password="password" + db="test", + user="user", + password="password", + namespace="foo", ) configuration.limiters = [ QuotaLimiterConfiguration(