From e274ae183c0a7caeeea4b2f6f48a588ae222ee7a Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 16 Aug 2024 15:03:17 -0700 Subject: [PATCH 1/3] Get running --- alembic.ini | 116 ------------------ alembic/env.py | 77 ------------ alembic/script.py.mako | 26 ---- .../7423a09421e9_add_user_id_to_logs.py | 47 ------- r2r/base/abstractions/search.py | 2 +- r2r/base/api/models/__init__.py | 2 - r2r/base/api/models/ingestion/responses.py | 2 +- r2r/base/api/models/management/responses.py | 8 -- r2r/base/api/models/retrieval/responses.py | 6 +- r2r/cli/cli.py | 1 - r2r/cli/commands/server_operations.py | 26 ---- r2r/main/api/client/auth.py | 1 - r2r/main/api/client/client.py | 4 +- r2r/main/api/client/restructure.py | 2 +- r2r/main/api/routes/management/base.py | 11 +- r2r/main/services/management_service.py | 83 ++++++------- r2r/providers/database/vector.py | 8 +- 17 files changed, 55 insertions(+), 367 deletions(-) delete mode 100644 alembic.ini delete mode 100644 alembic/env.py delete mode 100644 alembic/script.py.mako delete mode 100644 alembic/versions/7423a09421e9_add_user_id_to_logs.py diff --git a/alembic.ini b/alembic.ini deleted file mode 100644 index 4e036789e..000000000 --- a/alembic.ini +++ /dev/null @@ -1,116 +0,0 @@ -# A generic, single database configuration. - -[alembic] -# path to migration scripts -# Use forward slashes (/) also on windows to provide an os agnostic path -script_location = alembic - -# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s -# Uncomment the line below if you want the files to be prepended with date and time -# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file -# for all available tokens -# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s - -# sys.path path, will be prepended to sys.path if present. -# defaults to the current working directory. -prepend_sys_path = . - -# timezone to use when rendering the date within the migration file -# as well as the filename. -# If specified, requires the python>=3.9 or backports.zoneinfo library. -# Any required deps can installed by adding `alembic[tz]` to the pip requirements -# string value is passed to ZoneInfo() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version location specification; This defaults -# to alembic/versions. When using multiple version -# directories, initial revisions must be specified with --version-path. -# The path separator used here should be the separator specified by "version_path_separator" below. -# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions - -# version path separator; As mentioned above, this is the character used to split -# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. -# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. -# Valid values for version_path_separator are: -# -# version_path_separator = : -# version_path_separator = ; -# version_path_separator = space -version_path_separator = os # Use os.pathsep. Default configuration used for new projects. - -# set to 'true' to search source files recursively -# in each "version_locations" directory -# new in Alembic version 1.10 -# recursive_version_locations = false - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = utf-8 - -sqlalchemy.url = sqlite:///local.sqlite - - -[post_write_hooks] -# post_write_hooks defines scripts or Python functions that are run -# on newly generated revision scripts. See the documentation for further -# detail and examples - -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks = black -# black.type = console_scripts -# black.entrypoint = black -# black.options = -l 79 REVISION_SCRIPT_FILENAME - -# lint with attempts to fix using "ruff" - use the exec runner, execute a binary -# hooks = ruff -# ruff.type = exec -# ruff.executable = %(here)s/.venv/bin/ruff -# ruff.options = --fix REVISION_SCRIPT_FILENAME - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARN -handlers = console -qualname = - -[logger_sqlalchemy] -level = WARN -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py deleted file mode 100644 index b2040b5f7..000000000 --- a/alembic/env.py +++ /dev/null @@ -1,77 +0,0 @@ -from logging.config import fileConfig - -from sqlalchemy import engine_from_config, pool - -from alembic import context - -# this is the Alembic Config object, which provides -# access to the values within the .ini file in use. -config = context.config - -# Interpret the config file for Python logging. -# This line sets up loggers basically. -if config.config_file_name is not None: - fileConfig(config.config_file_name) - -# add your model's MetaData object here -# for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -target_metadata = None - -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") -# ... etc. - - -def run_migrations_offline() -> None: - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - - Calls to context.execute() here emit the given string to the - script output. - - """ - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def run_migrations_online() -> None: - """Run migrations in 'online' mode. - - In this scenario we need to create an Engine - and associate a connection with the context. - - """ - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) - - with context.begin_transaction(): - context.run_migrations() - - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako deleted file mode 100644 index fbc4b07dc..000000000 --- a/alembic/script.py.mako +++ /dev/null @@ -1,26 +0,0 @@ -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. -revision: str = ${repr(up_revision)} -down_revision: Union[str, None] = ${repr(down_revision)} -branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} -depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} - - -def upgrade() -> None: - ${upgrades if upgrades else "pass"} - - -def downgrade() -> None: - ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/7423a09421e9_add_user_id_to_logs.py b/alembic/versions/7423a09421e9_add_user_id_to_logs.py deleted file mode 100644 index bf451889e..000000000 --- a/alembic/versions/7423a09421e9_add_user_id_to_logs.py +++ /dev/null @@ -1,47 +0,0 @@ -"""add_user_id_to_logs - -Revision ID: 7423a09421e9 -Revises: -Create Date: 2024-08-05 10:49:10.714423 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from sqlalchemy.engine.reflection import Inspector - -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "7423a09421e9" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade(): - conn = op.get_bind() - inspector = Inspector.from_engine(conn) - - # Check and add user_id to logs table if it doesn't exist - if "user_id" not in [col["name"] for col in inspector.get_columns("logs")]: - op.add_column("logs", sa.Column("user_id", sa.String(), nullable=True)) - - # Check and add user_id to log_info table if it doesn't exist - if "user_id" not in [ - col["name"] for col in inspector.get_columns("log_info") - ]: - op.add_column( - "log_info", sa.Column("user_id", sa.String(), nullable=True) - ) - - -def downgrade(): - # Remove user_id column from logs table if it exists - with op.batch_alter_table("logs") as batch_op: - batch_op.drop_column("user_id") - - # Remove user_id column from log_info table if it exists - with op.batch_alter_table("log_info") as batch_op: - batch_op.drop_column("user_id") diff --git a/r2r/base/abstractions/search.py b/r2r/base/abstractions/search.py index 6218ec33b..7c94bd02d 100644 --- a/r2r/base/abstractions/search.py +++ b/r2r/base/abstractions/search.py @@ -40,7 +40,7 @@ def dict(self) -> dict: } class Config: - schema_extra = [ + json_schema_extra = [ { "fragment_id": "c68dc72e-fc23-5452-8f49-d7bd46088a96", "extraction_id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09", diff --git a/r2r/base/api/models/__init__.py b/r2r/base/api/models/__init__.py index d0afb1ec5..8ed0a0664 100644 --- a/r2r/base/api/models/__init__.py +++ b/r2r/base/api/models/__init__.py @@ -22,7 +22,6 @@ WrappedAddUserResponse, WrappedAnalyticsResponse, WrappedAppSettingsResponse, - WrappedDeleteResponse, WrappedDocumentChunkResponse, WrappedDocumentOverviewResponse, WrappedGroupListResponse, @@ -75,7 +74,6 @@ "WrappedAppSettingsResponse", "WrappedScoreCompletionResponse", "WrappedUserOverviewResponse", - "WrappedDeleteResponse", "WrappedDocumentOverviewResponse", "WrappedDocumentChunkResponse", "WrappedKnowledgeGraphResponse", diff --git a/r2r/base/api/models/ingestion/responses.py b/r2r/base/api/models/ingestion/responses.py index 76da3af35..b93d32583 100644 --- a/r2r/base/api/models/ingestion/responses.py +++ b/r2r/base/api/models/ingestion/responses.py @@ -50,7 +50,7 @@ class IngestionResponse(BaseModel): ) class Config: - schema_extra = { + json_schema_extra = { "example": { "processed_documents": [ { diff --git a/r2r/base/api/models/management/responses.py b/r2r/base/api/models/management/responses.py index de5da333e..9a04b25a1 100644 --- a/r2r/base/api/models/management/responses.py +++ b/r2r/base/api/models/management/responses.py @@ -53,13 +53,6 @@ class UserOverviewResponse(BaseModel): document_ids: List[UUID] -class DeleteResponse(BaseModel): - fragment_id: UUID - document_id: UUID - extraction_id: UUID - text: str - - class DocumentOverviewResponse(BaseModel): id: UUID title: str @@ -114,7 +107,6 @@ class AddUserResponse(BaseModel): WrappedAppSettingsResponse = ResultsWrapper[AppSettingsResponse] WrappedScoreCompletionResponse = ResultsWrapper[ScoreCompletionResponse] WrappedUserOverviewResponse = ResultsWrapper[List[UserOverviewResponse]] -WrappedDeleteResponse = ResultsWrapper[dict[str, DeleteResponse]] WrappedDocumentOverviewResponse = ResultsWrapper[ List[DocumentOverviewResponse] ] diff --git a/r2r/base/api/models/retrieval/responses.py b/r2r/base/api/models/retrieval/responses.py index b609b9a2b..36b9543e7 100644 --- a/r2r/base/api/models/retrieval/responses.py +++ b/r2r/base/api/models/retrieval/responses.py @@ -17,7 +17,7 @@ class SearchResponse(BaseModel): ) class Config: - schema_extra = { + json_schema_extra = { "example": { "vector_search_results": [ { @@ -50,7 +50,7 @@ class RAGResponse(BaseModel): ) class Config: - schema_extra = { + json_schema_extra = { "example": { "completion": { "id": "chatcmpl-example123", @@ -104,7 +104,7 @@ class RAGAgentResponse(BaseModel): ) class Config: - schema_extra = { + json_schema_extra = { "example": { "completion": { "id": "chatcmpl-example456", diff --git a/r2r/cli/cli.py b/r2r/cli/cli.py index 0e7fbdadf..c2cda4a29 100644 --- a/r2r/cli/cli.py +++ b/r2r/cli/cli.py @@ -12,7 +12,6 @@ cli.add_command(server_operations.docker_down) cli.add_command(server_operations.generate_report) cli.add_command(server_operations.health) -cli.add_command(server_operations.migrate) cli.add_command(server_operations.serve) cli.add_command(server_operations.update) cli.add_command(server_operations.version) diff --git a/r2r/cli/commands/server_operations.py b/r2r/cli/commands/server_operations.py index 0632762ca..1e0ae8933 100644 --- a/r2r/cli/commands/server_operations.py +++ b/r2r/cli/commands/server_operations.py @@ -7,8 +7,6 @@ import click from dotenv import load_dotenv -from alembic import command -from alembic.config import Config from r2r.cli.command_group import cli from r2r.cli.utils.docker_utils import ( bring_down_docker_compose, @@ -134,30 +132,6 @@ def health(obj): click.echo(response) -@cli.command() -@click.option( - "--config", - default="alembic.ini", - help="Path to the Alembic configuration file", -) -@click.pass_obj -def migrate(obj, config): - """Run database migrations.""" - click.echo("Running database migrations...") - - try: - # Create Alembic configuration - alembic_cfg = Config(config) - - # Run the migration - command.upgrade(alembic_cfg, "head") - - click.echo("Migrations completed successfully.") - except Exception as e: - click.echo(f"Error running migrations: {str(e)}") - sys.exit(1) - - @cli.command() @click.option("--host", default="0.0.0.0", help="Host to run the server on") @click.option("--port", default=8000, help="Port to run the server on") diff --git a/r2r/main/api/client/auth.py b/r2r/main/api/client/auth.py index 63b1afb9e..c697da55f 100644 --- a/r2r/main/api/client/auth.py +++ b/r2r/main/api/client/auth.py @@ -6,7 +6,6 @@ class AuthMethods: - @staticmethod async def register(client, email: str, password: str) -> UserResponse: data = {"email": email, "password": password} diff --git a/r2r/main/api/client/client.py b/r2r/main/api/client/client.py index 7f0cf3bfd..d4991e432 100644 --- a/r2r/main/api/client/client.py +++ b/r2r/main/api/client/client.py @@ -64,7 +64,7 @@ async def handle_request_error_async(response): else: message = str(error_content) except Exception: - message = response.text() + message = response.text raise R2RException( status_code=response.status_code, @@ -135,7 +135,7 @@ async def _make_request(self, method, endpoint, **kwargs): except httpx.RequestError as e: raise R2RException( status_code=500, message=f"Request failed: {str(e)}" - ) + ) from e def _get_auth_header(self) -> dict: if not self.access_token: diff --git a/r2r/main/api/client/restructure.py b/r2r/main/api/client/restructure.py index e117243f8..0fcbc9f82 100644 --- a/r2r/main/api/client/restructure.py +++ b/r2r/main/api/client/restructure.py @@ -41,7 +41,7 @@ async def query_graph(client, query: str) -> Dict[str, Any]: ) @staticmethod - async def get_graph_statistics(self) -> Dict[str, Any]: + async def get_graph_statistics(client) -> Dict[str, Any]: """ Get statistics about the knowledge graph. diff --git a/r2r/main/api/routes/management/base.py b/r2r/main/api/routes/management/base.py index fc3b3250b..f5dfb90ff 100644 --- a/r2r/main/api/routes/management/base.py +++ b/r2r/main/api/routes/management/base.py @@ -5,7 +5,7 @@ from typing import Optional import psutil -from fastapi import Body, Depends, Path, Query +from fastapi import Body, Depends, Path, Query, Response from pydantic import BaseModel from r2r.base import R2RException @@ -13,7 +13,6 @@ WrappedAddUserResponse, WrappedAnalyticsResponse, WrappedAppSettingsResponse, - WrappedDeleteResponse, WrappedDocumentChunkResponse, WrappedDocumentOverviewResponse, WrappedGroupListResponse, @@ -136,14 +135,14 @@ async def get_analytics_app( f"Invalid data in query parameters: {str(e)}", 400 ) - @self.router.delete("/delete") + @self.router.delete("/delete", status_code=204) @self.base_endpoint async def delete_app( filters: Optional[str] = Query("{}"), auth_user=Depends(self.engine.providers.auth.auth_wrapper), - ) -> WrappedDeleteResponse: + ) -> None: filters_dict = json.loads(filters) if filters else None - return await self.engine.adelete(filters=filters_dict) + await self.engine.adelete(filters=filters_dict) @self.router.get("/document_chunks") @self.base_endpoint @@ -190,7 +189,7 @@ async def documents_overview_app( auth_user=Depends(self.engine.providers.auth.auth_wrapper), ) -> WrappedDocumentOverviewResponse: request_user_ids = ( - [auth_user.id] if not auth_user.is_superuser else None + None if auth_user.is_superuser else [auth_user.id] ) return await self.engine.adocuments_overview( user_ids=request_user_ids, diff --git a/r2r/main/services/management_service.py b/r2r/main/services/management_service.py index 4ff8f4101..1833ba218 100644 --- a/r2r/main/services/management_service.py +++ b/r2r/main/services/management_service.py @@ -4,6 +4,8 @@ from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple +import toml + from r2r.base import ( AnalysisTypes, LogFilterCriteria, @@ -41,18 +43,6 @@ def __init__( logging_connection, ) - @telemetry_event("UpdatePrompt") - async def update_prompt( - self, - name: str, - template: Optional[str] = None, - input_types: Optional[dict[str, str]] = {}, - *args, - **kwargs, - ): - self.providers.prompt.update_prompt(name, template, input_types) - return f"Prompt '{name}' added successfully." - @telemetry_event("Logs") async def alogs( self, run_type_filter: Optional[str] = None, max_runs: int = 100 @@ -72,7 +62,6 @@ async def alogs( logs = await self.logging_connection.get_logs(run_ids) aggregated_logs = [] - warning_shown = False for run in run_info: run_logs = [ @@ -98,14 +87,8 @@ async def alogs( if run.timestamp: log_entry["timestamp"] = run.timestamp.isoformat() - if hasattr(run, "user_id"): - if run.user_id is not None: - log_entry["user_id"] = run.user_id - elif not warning_shown: - logger.warning( - "Logs are missing user ids. This may be due to an outdated database schema. Please run `r2r migrate` to run database migrations." - ) - warning_shown = True + if hasattr(run, "user_id") and run.user_id is not None: + log_entry["user_id"] = run.user_id aggregated_logs.append(log_entry) @@ -196,8 +179,10 @@ async def aanalytics( @telemetry_event("AppSettings") async def aapp_settings(self, *args: Any, **kwargs: Any): prompts = self.providers.prompt.get_all_prompts() + config_toml = self.config.to_toml() + config_dict = toml.loads(config_toml) return { - "config": self.config.to_toml(), + "config": config_dict, "prompts": { name: prompt.dict() for name, prompt in prompts.items() }, @@ -298,7 +283,7 @@ async def delete( self.providers.database.relational.delete_from_documents_overview( document_id ) - return results + return {} @telemetry_event("DocumentsOverview") async def adocuments_overview( @@ -382,10 +367,9 @@ async def aassign_document_to_group( self, document_id: str, group_id: uuid.UUID ): - success = self.providers.database.vector.assign_document_to_group( + if self.providers.database.vector.assign_document_to_group( document_id, group_id - ) - if success: + ): return {"message": "Document assigned to group successfully"} else: raise R2RException( @@ -397,10 +381,9 @@ async def aassign_document_to_group( async def aremove_document_from_group( self, document_id: str, group_id: uuid.UUID ): - success = self.providers.database.vector.remove_document_from_group( + if self.providers.database.vector.remove_document_from_group( document_id, group_id - ) - if success: + ): return {"message": "Document removed from group successfully"} else: raise R2RException( @@ -435,27 +418,39 @@ def generate_output( # Print grouped relationships for subject, relations in grouped_relationships.items(): - output.append(f"\n== {subject} ==") - for relation, objects in relations.items(): - output.append(f" {relation}:") - for obj in objects: - output.append(f" - {obj}") + output.extend( + [ + f"\n== {subject} ==", + *(f" {relation}:" for relation in relations), + *( + f" - {obj}" + for objects in relations.values() + for obj in objects + ), + ] + ) # Print basic graph statistics - output.append("\n== Graph Statistics ==") - output.append(f"Number of nodes: {len(graph)}") - output.append( - f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}" - ) - output.append( - f"Number of connected components: {self.count_connected_components(graph)}" + output.extend( + [ + "\n== Graph Statistics ==", + f"Number of nodes: {len(graph)}", + f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}", + f"Number of connected components: {self.count_connected_components(graph)}", + ] ) # Find central nodes central_nodes = self.get_central_nodes(graph) - output.append("\n== Most Central Nodes ==") - for node, centrality in central_nodes: - output.append(f" {node}: {centrality:.4f}") + output.extend( + [ + "\n== Most Central Nodes ==", + *( + f" {node}: {centrality:.4f}" + for node, centrality in central_nodes + ), + ] + ) return output diff --git a/r2r/providers/database/vector.py b/r2r/providers/database/vector.py index c8523a159..0b3728fef 100644 --- a/r2r/providers/database/vector.py +++ b/r2r/providers/database/vector.py @@ -41,11 +41,9 @@ def __init__(self, config: DatabaseConfig, *args, **kwargs): ) # Check if a complete Postgres URI is provided - postgres_uri = self.config.extra_fields.get( + if postgres_uri := self.config.extra_fields.get( "postgres_uri" - ) or os.getenv("POSTGRES_URI") - - if postgres_uri: + ) or os.getenv("POSTGRES_URI"): # Log loudly that Postgres URI is being used logger.warning("=" * 50) logger.warning( @@ -353,7 +351,7 @@ def hybrid_search( ) params = { - "query_text": str(query_text), + "query_text": query_text, "query_embedding": list(query_vector), "match_limit": limit, "full_text_weight": full_text_weight, From 9965e12b5e4a66e0ddd63e4a8ffd7792787fa82f Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 16 Aug 2024 15:55:13 -0700 Subject: [PATCH 2/3] fixes in sdk --- r2r/base/api/models/__init__.py | 2 -- r2r/main/api/client/auth.py | 25 +++++++++++++++---------- r2r/main/api/client/ingestion.py | 30 ------------------------------ r2r/main/api/client/management.py | 5 ++++- 4 files changed, 19 insertions(+), 43 deletions(-) diff --git a/r2r/base/api/models/__init__.py b/r2r/base/api/models/__init__.py index 8ed0a0664..9944d0fb7 100644 --- a/r2r/base/api/models/__init__.py +++ b/r2r/base/api/models/__init__.py @@ -8,7 +8,6 @@ from .management.responses import ( AnalyticsResponse, AppSettingsResponse, - DeleteResponse, DocumentChunkResponse, DocumentOverviewResponse, GroupOverviewResponse, @@ -61,7 +60,6 @@ "AppSettingsResponse", "ScoreCompletionResponse", "UserOverviewResponse", - "DeleteResponse", "DocumentOverviewResponse", "DocumentChunkResponse", "KnowledgeGraphResponse", diff --git a/r2r/main/api/client/auth.py b/r2r/main/api/client/auth.py index c697da55f..63c0aa8de 100644 --- a/r2r/main/api/client/auth.py +++ b/r2r/main/api/client/auth.py @@ -11,24 +11,26 @@ async def register(client, email: str, password: str) -> UserResponse: data = {"email": email, "password": password} return await client._make_request("POST", "register", json=data) - async def verify_email(client, verification_code: str) -> dict: + async def verify_email(self, client, verification_code: str) -> dict: return await client._make_request( "POST", "verify_email", json={"verification_code": verification_code}, ) - async def login(client, email: str, password: str) -> dict[str, Token]: + async def login( + self, client, email: str, password: str + ) -> dict[str, Token]: data = {"username": email, "password": password} response = await client._make_request("POST", "login", data=data) client.access_token = response["results"]["access_token"]["token"] client._refresh_token = response["results"]["refresh_token"]["token"] return response["results"] - async def user(client) -> UserResponse: + async def user(self, client) -> UserResponse: return await client._make_request("GET", "user") - async def refresh_access_token(client) -> dict[str, Token]: + async def refresh_access_token(self, client) -> dict[str, Token]: data = {"refresh_token": client._refresh_token} response = await client._make_request( "POST", "refresh_access_token", json=data @@ -38,7 +40,7 @@ async def refresh_access_token(client) -> dict[str, Token]: return response["results"] async def change_password( - client, current_password: str, new_password: str + self, client, current_password: str, new_password: str ) -> dict: data = { "current_password": current_password, @@ -46,27 +48,30 @@ async def change_password( } return await client._make_request("POST", "change_password", json=data) - async def request_password_reset(client, email: str) -> dict: + async def request_password_reset(self, client, email: str) -> dict: return await client._make_request( "POST", "request_password_reset", json={"email": email} ) async def confirm_password_reset( - client, reset_token: str, new_password: str + self, client, reset_token: str, new_password: str ) -> dict: data = {"reset_token": reset_token, "new_password": new_password} return await client._make_request("POST", "reset_password", json=data) - async def logout(client) -> dict: + async def logout(self, client) -> dict: response = await client._make_request("POST", "logout") client.access_token = None client._refresh_token = None return response - async def get_user_profile(client, user_id: uuid.UUID) -> UserResponse: + async def get_user_profile( + self, client, user_id: uuid.UUID + ) -> UserResponse: return await client._make_request("GET", f"user/{user_id}") async def update_user( + self, client, email: Optional[str] = None, name: Optional[str] = None, @@ -83,7 +88,7 @@ async def update_user( return await client._make_request("PUT", "user", json=data) async def delete_user( - client, user_id: uuid.UUID, password: Optional[str] = None + self, client, user_id: uuid.UUID, password: Optional[str] = None ) -> dict: data = {"user_id": str(user_id), "password": password} response = await client._make_request("DELETE", "user", json=data) diff --git a/r2r/main/api/client/ingestion.py b/r2r/main/api/client/ingestion.py index 5e85e8304..9a2514cca 100644 --- a/r2r/main/api/client/ingestion.py +++ b/r2r/main/api/client/ingestion.py @@ -163,36 +163,6 @@ async def update_files( "POST", "update_files", data=data, files=files ) - @staticmethod - async def get_document_info(client, document_id: str) -> dict: - """ - Retrieve information about a specific document. - - Args: - document_id (str): The ID of the document to retrieve information for. - - Returns: - dict: Document information including metadata, status, and version. - """ - return await client._make_request( - "GET", f"document_info/{document_id}" - ) - - @staticmethod - async def delete_document(client, document_id: str) -> dict: - """ - Delete a specific document from the system. - - Args: - document_id (str): The ID of the document to delete. - - Returns: - dict: Confirmation of document deletion. - """ - return await client._make_request( - "DELETE", f"delete_document/{document_id}" - ) - @staticmethod async def list_documents( client, diff --git a/r2r/main/api/client/management.py b/r2r/main/api/client/management.py index b056c6034..dac614bb9 100644 --- a/r2r/main/api/client/management.py +++ b/r2r/main/api/client/management.py @@ -14,8 +14,11 @@ async def update_prompt( client, name: str, template: Optional[str] = None, - input_types: Optional[dict[str, str]] = {}, + input_types: Optional[dict[str, str]] = None, ) -> dict: + if input_types is None: + input_types = {} + data = { "name": name, "template": template, From b63642911942c725efc520c51ec64ef425e75089 Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Fri, 16 Aug 2024 16:42:14 -0700 Subject: [PATCH 3/3] Add in more fixes --- r2r/main/api/client/auth.py | 36 ++++++++++++++++----------- r2r/main/api/routes/ingestion/base.py | 27 ++++++++++---------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/r2r/main/api/client/auth.py b/r2r/main/api/client/auth.py index 63c0aa8de..c1a627056 100644 --- a/r2r/main/api/client/auth.py +++ b/r2r/main/api/client/auth.py @@ -11,26 +11,28 @@ async def register(client, email: str, password: str) -> UserResponse: data = {"email": email, "password": password} return await client._make_request("POST", "register", json=data) - async def verify_email(self, client, verification_code: str) -> dict: + @staticmethod + async def verify_email(client, verification_code: str) -> dict: return await client._make_request( "POST", "verify_email", json={"verification_code": verification_code}, ) - async def login( - self, client, email: str, password: str - ) -> dict[str, Token]: + @staticmethod + async def login(client, email: str, password: str) -> dict[str, Token]: data = {"username": email, "password": password} response = await client._make_request("POST", "login", data=data) client.access_token = response["results"]["access_token"]["token"] client._refresh_token = response["results"]["refresh_token"]["token"] return response["results"] - async def user(self, client) -> UserResponse: + @staticmethod + async def user(client) -> UserResponse: return await client._make_request("GET", "user") - async def refresh_access_token(self, client) -> dict[str, Token]: + @staticmethod + async def refresh_access_token(client) -> dict[str, Token]: data = {"refresh_token": client._refresh_token} response = await client._make_request( "POST", "refresh_access_token", json=data @@ -39,8 +41,9 @@ async def refresh_access_token(self, client) -> dict[str, Token]: client._refresh_token = response["results"]["refresh_token"]["token"] return response["results"] + @staticmethod async def change_password( - self, client, current_password: str, new_password: str + client, current_password: str, new_password: str ) -> dict: data = { "current_password": current_password, @@ -48,30 +51,32 @@ async def change_password( } return await client._make_request("POST", "change_password", json=data) - async def request_password_reset(self, client, email: str) -> dict: + @staticmethod + async def request_password_reset(client, email: str) -> dict: return await client._make_request( "POST", "request_password_reset", json={"email": email} ) + @staticmethod async def confirm_password_reset( - self, client, reset_token: str, new_password: str + client, reset_token: str, new_password: str ) -> dict: data = {"reset_token": reset_token, "new_password": new_password} return await client._make_request("POST", "reset_password", json=data) - async def logout(self, client) -> dict: + @staticmethod + async def logout(client) -> dict: response = await client._make_request("POST", "logout") client.access_token = None client._refresh_token = None return response - async def get_user_profile( - self, client, user_id: uuid.UUID - ) -> UserResponse: + @staticmethod + async def get_user_profile(client, user_id: uuid.UUID) -> UserResponse: return await client._make_request("GET", f"user/{user_id}") + @staticmethod async def update_user( - self, client, email: Optional[str] = None, name: Optional[str] = None, @@ -87,8 +92,9 @@ async def update_user( data = {k: v for k, v in data.items() if v is not None} return await client._make_request("PUT", "user", json=data) + @staticmethod async def delete_user( - self, client, user_id: uuid.UUID, password: Optional[str] = None + client, user_id: uuid.UUID, password: Optional[str] = None ) -> dict: data = {"user_id": str(user_id), "password": password} response = await client._make_request("DELETE", "user", json=data) diff --git a/r2r/main/api/routes/ingestion/base.py b/r2r/main/api/routes/ingestion/base.py index 8da142dca..db23021e0 100644 --- a/r2r/main/api/routes/ingestion/base.py +++ b/r2r/main/api/routes/ingestion/base.py @@ -182,20 +182,19 @@ async def ingest_files_app( # Handle user management logic at the request level if not auth_user: for metadata in metadatas or []: - if "user_id" in metadata: - if not is_superuser and metadata["user_id"] != str( - auth_user.id - ): - raise R2RException( - status_code=403, - message="Non-superusers cannot set user_id in metadata.", - ) - if "group_ids" in metadata: - if not is_superuser: - raise R2RException( - status_code=403, - message="Non-superusers cannot set group_ids in metadata.", - ) + if "user_id" in metadata and ( + not is_superuser + and metadata["user_id"] != str(auth_user.id) + ): + raise R2RException( + status_code=403, + message="Non-superusers cannot set user_id in metadata.", + ) + if "group_ids" in metadata and not is_superuser: + raise R2RException( + status_code=403, + message="Non-superusers cannot set group_ids in metadata.", + ) # If user is not a superuser, set user_id in metadata metadata["user_id"] = str(auth_user.id)