From 403344d459cf62a4eafdedb54c7a317cf39e495b Mon Sep 17 00:00:00 2001 From: Igor Magalhaes Date: Sun, 31 Dec 2023 00:41:10 -0300 Subject: [PATCH] create_tables_on_startd created, docs added --- README.md | 23 +++++++- src/app/api/paginated.py | 10 ++-- src/app/api/v1/login.py | 34 +++-------- src/app/api/v1/logout.py | 7 +-- src/app/api/v1/rate_limits.py | 67 ++++++--------------- src/app/api/v1/tasks.py | 3 +- src/app/api/v1/tiers.py | 44 ++++---------- src/app/api/v1/users.py | 83 ++++++++++----------------- src/app/core/config.py | 9 +-- src/app/core/db/database.py | 16 ++---- src/app/core/db/token_blacklist.py | 4 +- src/app/core/logger.py | 8 +-- src/app/core/schemas.py | 10 ++-- src/app/core/setup.py | 24 +++++--- src/app/core/utils/cache.py | 3 +- src/app/core/utils/rate_limit.py | 9 +-- src/app/models/post.py | 12 +--- src/app/models/rate_limit.py | 10 +--- src/app/models/tier.py | 8 +-- src/app/schemas/post.py | 78 ++++++------------------- src/app/schemas/rate_limit.py | 43 ++++---------- src/app/schemas/tier.py | 5 +- src/app/schemas/user.py | 81 ++++++-------------------- src/app/worker.py | 5 +- src/scripts/create_first_superuser.py | 23 ++++---- src/scripts/create_first_tier.py | 12 ++-- tests/helper.py | 7 +-- tests/test_user.py | 53 ++++++----------- 28 files changed, 233 insertions(+), 458 deletions(-) diff --git a/README.md b/README.md index 95a4214..e96ae5f 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,8 @@ 10. [ARQ Job Queues](#510-arq-job-queues) 11. [Rate Limiting](#511-rate-limiting) 12. [JWT Authentication](#512-jwt-authentication) - 13. [Running](#512-running) + 13. [Running](#513-running) + 14. [Create Application](#514-create-application) 6. [Running in Production](#6-running-in-production) 1. [Uvicorn Workers with Gunicorn](#61-uvicorn-workers-with-gunicorn) 2. [Running With NGINX](#62-running-with-nginx) @@ -1393,6 +1394,26 @@ CMD ["gunicorn", "app.main:app", "-w", "4", "-k", "uvicorn.workers.UvicornWorker > [!CAUTION] > Do not forget to set the `ENVIRONMENT` in `.env` to `production` unless you want the API docs to be public. +### 5.14 Create Application +If you want to stop tables from being created every time you run the api, you should disable this here: +```python +# app/main.py + +from .api import router +from .core.config import settings +from .core.setup import create_application + +# create_tables_on_start defaults to True +app = create_application(router=router, settings=settings, create_tables_on_start=False) +``` + +This `create_application` function is defined in `app/core/setup.py`, and it's a flexible way to configure the behavior of your application. + +A few examples: +- Deactivate or password protect /docs +- Add client-side cache middleware +- Add Startup and Shutdown event handlers for cache, queue and rate limit + ### 6.2 Running with NGINX NGINX is a high-performance web server, known for its stability, rich feature set, simple configuration, and low resource consumption. NGINX acts as a reverse proxy, that is, it receives client requests, forwards them to the FastAPI server (running via Uvicorn or Gunicorn), and then passes the responses back to the clients. diff --git a/src/app/api/paginated.py b/src/app/api/paginated.py index 4649731..4e0c28b 100644 --- a/src/app/api/paginated.py +++ b/src/app/api/paginated.py @@ -4,6 +4,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel) + class ListResponse(BaseModel, Generic[SchemaType]): data: List[SchemaType] @@ -15,11 +16,7 @@ class PaginatedListResponse(ListResponse[SchemaType]): items_per_page: int | None = None -def paginated_response( - crud_data: dict, - page: int, - items_per_page: int -) -> Dict[str, Any]: +def paginated_response(crud_data: dict, page: int, items_per_page: int) -> Dict[str, Any]: """ Create a paginated response based on the provided data and pagination parameters. @@ -46,9 +43,10 @@ def paginated_response( "total_count": crud_data["total_count"], "has_more": (page * items_per_page) < crud_data["total_count"], "page": page, - "items_per_page": items_per_page + "items_per_page": items_per_page, } + def compute_offset(page: int, items_per_page: int) -> int: """ Calculate the offset for pagination based on the given page number and items per page. diff --git a/src/app/api/v1/login.py b/src/app/api/v1/login.py index e983cac..85bf219 100644 --- a/src/app/api/v1/login.py +++ b/src/app/api/v1/login.py @@ -20,48 +20,32 @@ router = fastapi.APIRouter(tags=["login"]) + @router.post("/login", response_model=Token) async def login_for_access_token( response: Response, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], - db: Annotated[AsyncSession, Depends(async_get_db)] + db: Annotated[AsyncSession, Depends(async_get_db)], ) -> Dict[str, str]: - user = await authenticate_user( - username_or_email=form_data.username, - password=form_data.password, - db=db - ) + user = await authenticate_user(username_or_email=form_data.username, password=form_data.password, db=db) if not user: raise UnauthorizedException("Wrong username, email or password.") - + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = await create_access_token( - data={"sub": user["username"]}, expires_delta=access_token_expires - ) + access_token = await create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires) refresh_token = await create_refresh_token(data={"sub": user["username"]}) max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60 response.set_cookie( - key="refresh_token", - value=refresh_token, - httponly=True, - secure=True, - samesite='Lax', - max_age=max_age + key="refresh_token", value=refresh_token, httponly=True, secure=True, samesite="Lax", max_age=max_age ) - - return { - "access_token": access_token, - "token_type": "bearer" - } + + return {"access_token": access_token, "token_type": "bearer"} @router.post("/refresh") -async def refresh_access_token( - request: Request, - db: AsyncSession = Depends(async_get_db) -) -> Dict[str, str]: +async def refresh_access_token(request: Request, db: AsyncSession = Depends(async_get_db)) -> Dict[str, str]: refresh_token = request.cookies.get("refresh_token") if not refresh_token: raise UnauthorizedException("Refresh token missing.") diff --git a/src/app/api/v1/logout.py b/src/app/api/v1/logout.py index 59125f1..be29462 100644 --- a/src/app/api/v1/logout.py +++ b/src/app/api/v1/logout.py @@ -10,17 +10,16 @@ router = APIRouter(tags=["login"]) + @router.post("/logout") async def logout( - response: Response, - access_token: str = Depends(oauth2_scheme), - db: AsyncSession = Depends(async_get_db) + response: Response, access_token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(async_get_db) ) -> Dict[str, str]: try: await blacklist_token(token=access_token, db=db) response.delete_cookie(key="refresh_token") return {"message": "Logged out successfully"} - + except JWTError: raise UnauthorizedException("Invalid token.") diff --git a/src/app/api/v1/rate_limits.py b/src/app/api/v1/rate_limits.py index f2c54bb..58192df 100644 --- a/src/app/api/v1/rate_limits.py +++ b/src/app/api/v1/rate_limits.py @@ -14,12 +14,10 @@ router = fastapi.APIRouter(tags=["rate_limits"]) + @router.post("/tier/{tier_name}/rate_limit", dependencies=[Depends(get_current_superuser)], status_code=201) async def write_rate_limit( - request: Request, - tier_name: str, - rate_limit: RateLimitCreate, - db: Annotated[AsyncSession, Depends(async_get_db)] + request: Request, tier_name: str, rate_limit: RateLimitCreate, db: Annotated[AsyncSession, Depends(async_get_db)] ) -> RateLimitRead: db_tier = await crud_tiers.get(db=db, name=tier_name) if not db_tier: @@ -31,7 +29,7 @@ async def write_rate_limit( db_rate_limit = await crud_rate_limits.exists(db=db, name=rate_limit_internal_dict["name"]) if db_rate_limit: raise DuplicateValueException("Rate Limit Name not available") - + rate_limit_internal = RateLimitCreateInternal(**rate_limit_internal_dict) return await crud_rate_limits.create(db=db, object=rate_limit_internal) @@ -42,7 +40,7 @@ async def read_rate_limits( tier_name: str, db: Annotated[AsyncSession, Depends(async_get_db)], page: int = 1, - items_per_page: int = 10 + items_per_page: int = 10, ) -> dict: db_tier = await crud_tiers.get(db=db, name=tier_name) if not db_tier: @@ -53,33 +51,21 @@ async def read_rate_limits( offset=compute_offset(page, items_per_page), limit=items_per_page, schema_to_select=RateLimitRead, - tier_id=db_tier["id"] + tier_id=db_tier["id"], ) - return paginated_response( - crud_data=rate_limits_data, - page=page, - items_per_page=items_per_page - ) + return paginated_response(crud_data=rate_limits_data, page=page, items_per_page=items_per_page) @router.get("/tier/{tier_name}/rate_limit/{id}", response_model=RateLimitRead) async def read_rate_limit( - request: Request, - tier_name: str, - id: int, - db: Annotated[AsyncSession, Depends(async_get_db)] + request: Request, tier_name: str, id: int, db: Annotated[AsyncSession, Depends(async_get_db)] ) -> dict: db_tier = await crud_tiers.get(db=db, name=tier_name) if not db_tier: raise NotFoundException("Tier not found") - - db_rate_limit = await crud_rate_limits.get( - db=db, - schema_to_select=RateLimitRead, - tier_id=db_tier["id"], - id=id - ) + + db_rate_limit = await crud_rate_limits.get(db=db, schema_to_select=RateLimitRead, tier_id=db_tier["id"], id=id) if db_rate_limit is None: raise NotFoundException("Rate Limit not found") @@ -92,26 +78,17 @@ async def patch_rate_limit( tier_name: str, id: int, values: RateLimitUpdate, - db: Annotated[AsyncSession, Depends(async_get_db)] + db: Annotated[AsyncSession, Depends(async_get_db)], ) -> Dict[str, str]: db_tier = await crud_tiers.get(db=db, name=tier_name) if db_tier is None: raise NotFoundException("Tier not found") - - db_rate_limit = await crud_rate_limits.get( - db=db, - schema_to_select=RateLimitRead, - tier_id=db_tier["id"], - id=id - ) + + db_rate_limit = await crud_rate_limits.get(db=db, schema_to_select=RateLimitRead, tier_id=db_tier["id"], id=id) if db_rate_limit is None: raise NotFoundException("Rate Limit not found") - - db_rate_limit_path = await crud_rate_limits.exists( - db=db, - tier_id=db_tier["id"], - path=values.path - ) + + db_rate_limit_path = await crud_rate_limits.exists(db=db, tier_id=db_tier["id"], path=values.path) if db_rate_limit_path is not None: raise DuplicateValueException("There is already a rate limit for this path") @@ -125,23 +102,15 @@ async def patch_rate_limit( @router.delete("/tier/{tier_name}/rate_limit/{id}", dependencies=[Depends(get_current_superuser)]) async def erase_rate_limit( - request: Request, - tier_name: str, - id: int, - db: Annotated[AsyncSession, Depends(async_get_db)] + request: Request, tier_name: str, id: int, db: Annotated[AsyncSession, Depends(async_get_db)] ) -> Dict[str, str]: db_tier = await crud_tiers.get(db=db, name=tier_name) if not db_tier: raise NotFoundException("Tier not found") - - db_rate_limit = await crud_rate_limits.get( - db=db, - schema_to_select=RateLimitRead, - tier_id=db_tier["id"], - id=id - ) + + db_rate_limit = await crud_rate_limits.get(db=db, schema_to_select=RateLimitRead, tier_id=db_tier["id"], id=id) if db_rate_limit is None: raise RateLimitException("Rate Limit not found") - + await crud_rate_limits.delete(db=db, db_row=db_rate_limit, id=db_rate_limit["id"]) return {"message": "Rate Limit deleted"} diff --git a/src/app/api/v1/tasks.py b/src/app/api/v1/tasks.py index 84501a7..2a39f0f 100644 --- a/src/app/api/v1/tasks.py +++ b/src/app/api/v1/tasks.py @@ -9,6 +9,7 @@ router = APIRouter(prefix="/tasks", tags=["tasks"]) + @router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter)]) async def create_task(message: str) -> Dict[str, str]: """ @@ -24,7 +25,7 @@ async def create_task(message: str) -> Dict[str, str]: Dict[str, str] A dictionary containing the ID of the created task. """ - job = await queue.pool.enqueue_job("sample_background_task", message) # type: ignore + job = await queue.pool.enqueue_job("sample_background_task", message) # type: ignore return {"id": job.job_id} diff --git a/src/app/api/v1/tiers.py b/src/app/api/v1/tiers.py index 9d6a120..82d8be7 100644 --- a/src/app/api/v1/tiers.py +++ b/src/app/api/v1/tiers.py @@ -13,48 +13,33 @@ router = fastapi.APIRouter(tags=["tiers"]) + @router.post("/tier", dependencies=[Depends(get_current_superuser)], status_code=201) async def write_tier( - request: Request, - tier: TierCreate, - db: Annotated[AsyncSession, Depends(async_get_db)] + request: Request, tier: TierCreate, db: Annotated[AsyncSession, Depends(async_get_db)] ) -> TierRead: tier_internal_dict = tier.model_dump() db_tier = await crud_tiers.exists(db=db, name=tier_internal_dict["name"]) if db_tier: raise DuplicateValueException("Tier Name not available") - + tier_internal = TierCreateInternal(**tier_internal_dict) return await crud_tiers.create(db=db, object=tier_internal) @router.get("/tiers", response_model=PaginatedListResponse[TierRead]) async def read_tiers( - request: Request, - db: Annotated[AsyncSession, Depends(async_get_db)], - page: int = 1, - items_per_page: int = 10 + request: Request, db: Annotated[AsyncSession, Depends(async_get_db)], page: int = 1, items_per_page: int = 10 ) -> dict: tiers_data = await crud_tiers.get_multi( - db=db, - offset=compute_offset(page, items_per_page), - limit=items_per_page, - schema_to_select=TierRead + db=db, offset=compute_offset(page, items_per_page), limit=items_per_page, schema_to_select=TierRead ) - return paginated_response( - crud_data=tiers_data, - page=page, - items_per_page=items_per_page - ) + return paginated_response(crud_data=tiers_data, page=page, items_per_page=items_per_page) @router.get("/tier/{name}", response_model=TierRead) -async def read_tier( - request: Request, - name: str, - db: Annotated[AsyncSession, Depends(async_get_db)] -) -> dict: +async def read_tier(request: Request, name: str, db: Annotated[AsyncSession, Depends(async_get_db)]) -> dict: db_tier = await crud_tiers.get(db=db, schema_to_select=TierRead, name=name) if db_tier is None: raise NotFoundException("Tier not found") @@ -64,28 +49,21 @@ async def read_tier( @router.patch("/tier/{name}", dependencies=[Depends(get_current_superuser)]) async def patch_tier( - request: Request, - values: TierUpdate, - name: str, - db: Annotated[AsyncSession, Depends(async_get_db)] + request: Request, values: TierUpdate, name: str, db: Annotated[AsyncSession, Depends(async_get_db)] ) -> Dict[str, str]: db_tier = await crud_tiers.get(db=db, schema_to_select=TierRead, name=name) if db_tier is None: raise NotFoundException("Tier not found") - + await crud_tiers.update(db=db, object=values, name=name) return {"message": "Tier updated"} @router.delete("/tier/{name}", dependencies=[Depends(get_current_superuser)]) -async def erase_tier( - request: Request, - name: str, - db: Annotated[AsyncSession, Depends(async_get_db)] -) -> Dict[str, str]: +async def erase_tier(request: Request, name: str, db: Annotated[AsyncSession, Depends(async_get_db)]) -> Dict[str, str]: db_tier = await crud_tiers.get(db=db, schema_to_select=TierRead, name=name) if db_tier is None: raise NotFoundException("Tier not found") - + await crud_tiers.delete(db=db, db_row=db_tier, name=name) return {"message": "Tier deleted"} diff --git a/src/app/api/v1/users.py b/src/app/api/v1/users.py index da8d203..8d94f6b 100644 --- a/src/app/api/v1/users.py +++ b/src/app/api/v1/users.py @@ -18,11 +18,10 @@ router = fastapi.APIRouter(tags=["users"]) + @router.post("/user", response_model=UserRead, status_code=201) async def write_user( - request: Request, - user: UserCreate, - db: Annotated[AsyncSession, Depends(async_get_db)] + request: Request, user: UserCreate, db: Annotated[AsyncSession, Depends(async_get_db)] ) -> UserRead: email_row = await crud_users.exists(db=db, email=user.email) if email_row: @@ -31,7 +30,7 @@ async def write_user( username_row = await crud_users.exists(db=db, username=user.username) if username_row: raise DuplicateValueException("Username not available") - + user_internal_dict = user.model_dump() user_internal_dict["hashed_password"] = get_password_hash(password=user_internal_dict["password"]) del user_internal_dict["password"] @@ -42,40 +41,26 @@ async def write_user( @router.get("/users", response_model=PaginatedListResponse[UserRead]) async def read_users( - request: Request, - db: Annotated[AsyncSession, Depends(async_get_db)], - page: int = 1, - items_per_page: int = 10 + request: Request, db: Annotated[AsyncSession, Depends(async_get_db)], page: int = 1, items_per_page: int = 10 ) -> dict: users_data = await crud_users.get_multi( db=db, offset=compute_offset(page, items_per_page), limit=items_per_page, schema_to_select=UserRead, - is_deleted=False + is_deleted=False, ) - return paginated_response( - crud_data=users_data, - page=page, - items_per_page=items_per_page - ) + return paginated_response(crud_data=users_data, page=page, items_per_page=items_per_page) @router.get("/user/me/", response_model=UserRead) -async def read_users_me( - request: Request, - current_user: Annotated[UserRead, Depends(get_current_user)] -) -> UserRead: +async def read_users_me(request: Request, current_user: Annotated[UserRead, Depends(get_current_user)]) -> UserRead: return current_user @router.get("/user/{username}", response_model=UserRead) -async def read_user( - request: Request, - username: str, - db: Annotated[AsyncSession, Depends(async_get_db)] -) -> dict: +async def read_user(request: Request, username: str, db: Annotated[AsyncSession, Depends(async_get_db)]) -> dict: db_user = await crud_users.get(db=db, schema_to_select=UserRead, username=username, is_deleted=False) if db_user is None: raise NotFoundException("User not found") @@ -85,19 +70,19 @@ async def read_user( @router.patch("/user/{username}") async def patch_user( - request: Request, + request: Request, values: UserUpdate, username: str, current_user: Annotated[UserRead, Depends(get_current_user)], - db: Annotated[AsyncSession, Depends(async_get_db)] + db: Annotated[AsyncSession, Depends(async_get_db)], ) -> Dict[str, str]: db_user = await crud_users.get(db=db, schema_to_select=UserRead, username=username) if db_user is None: raise NotFoundException("User not found") - + if db_user["username"] != current_user["username"]: raise ForbiddenException() - + if values.username != db_user["username"]: existing_username = await crud_users.exists(db=db, username=values.username) if existing_username: @@ -114,16 +99,16 @@ async def patch_user( @router.delete("/user/{username}") async def erase_user( - request: Request, + request: Request, username: str, current_user: Annotated[UserRead, Depends(get_current_user)], db: Annotated[AsyncSession, Depends(async_get_db)], - token: str = Depends(oauth2_scheme) + token: str = Depends(oauth2_scheme), ) -> Dict[str, str]: db_user = await crud_users.get(db=db, schema_to_select=UserRead, username=username) if not db_user: raise NotFoundException("User not found") - + if username != current_user["username"]: raise ForbiddenException() @@ -134,15 +119,15 @@ async def erase_user( @router.delete("/db_user/{username}", dependencies=[Depends(get_current_superuser)]) async def erase_db_user( - request: Request, + request: Request, username: str, db: Annotated[AsyncSession, Depends(async_get_db)], - token: str = Depends(oauth2_scheme) + token: str = Depends(oauth2_scheme), ) -> Dict[str, str]: db_user = await crud_users.exists(db=db, username=username) if not db_user: raise NotFoundException("User not found") - + await crud_users.db_delete(db=db, username=username) await blacklist_token(token=token, db=db) return {"message": "User deleted from the database"} @@ -150,9 +135,7 @@ async def erase_db_user( @router.get("/user/{username}/rate_limits", dependencies=[Depends(get_current_superuser)]) async def read_user_rate_limits( - request: Request, - username: str, - db: Annotated[AsyncSession, Depends(async_get_db)] + request: Request, username: str, db: Annotated[AsyncSession, Depends(async_get_db)] ) -> Dict[str, Any]: db_user: dict | None = await crud_users.get(db=db, username=username, schema_to_select=UserRead) if db_user is None: @@ -161,15 +144,12 @@ async def read_user_rate_limits( if db_user["tier_id"] is None: db_user["tier_rate_limits"] = [] return db_user - + db_tier = await crud_tiers.get(db=db, id=db_user["tier_id"]) if db_tier is None: raise NotFoundException("Tier not found") - - db_rate_limits = await crud_rate_limits.get_multi( - db=db, - tier_id=db_tier["id"] - ) + + db_rate_limits = await crud_rate_limits.get_multi(db=db, tier_id=db_tier["id"]) db_user["tier_rate_limits"] = db_rate_limits["data"] @@ -178,25 +158,23 @@ async def read_user_rate_limits( @router.get("/user/{username}/tier") async def read_user_tier( - request: Request, - username: str, - db: Annotated[AsyncSession, Depends(async_get_db)] + request: Request, username: str, db: Annotated[AsyncSession, Depends(async_get_db)] ) -> dict | None: db_user = await crud_users.get(db=db, username=username, schema_to_select=UserRead) if db_user is None: raise NotFoundException("User not found") - + db_tier = await crud_tiers.exists(db=db, id=db_user["tier_id"]) if not db_tier: raise NotFoundException("Tier not found") joined = await crud_users.get_joined( - db=db, - join_model=Tier, - join_prefix="tier_", + db=db, + join_model=Tier, + join_prefix="tier_", schema_to_select=UserRead, join_schema_to_select=TierRead, - username=username + username=username, ) return joined @@ -204,10 +182,7 @@ async def read_user_tier( @router.patch("/user/{username}/tier", dependencies=[Depends(get_current_superuser)]) async def patch_user_tier( - request: Request, - username: str, - values: UserTierUpdate, - db: Annotated[AsyncSession, Depends(async_get_db)] + request: Request, username: str, values: UserTierUpdate, db: Annotated[AsyncSession, Depends(async_get_db)] ) -> Dict[str, str]: db_user = await crud_users.get(db=db, username=username, schema_to_select=UserRead) if db_user is None: diff --git a/src/app/core/config.py b/src/app/core/config.py index e488a5d..e8b85bf 100644 --- a/src/app/core/config.py +++ b/src/app/core/config.py @@ -8,6 +8,7 @@ env_path = os.path.join(current_file_dir, "..", "..", ".env") config = Config(env_path) + class AppSettings(BaseSettings): APP_NAME: str = config("APP_NAME", default="FastAPI app") APP_DESCRIPTION: str | None = config("APP_DESCRIPTION", default=None) @@ -109,9 +110,9 @@ class EnvironmentSettings(BaseSettings): class Settings( - AppSettings, - PostgresSettings, - CryptSettings, + AppSettings, + PostgresSettings, + CryptSettings, FirstUserSettings, TestSettings, RedisCacheSettings, @@ -119,7 +120,7 @@ class Settings( RedisQueueSettings, RedisRateLimiterSettings, DefaultRateLimitSettings, - EnvironmentSettings + EnvironmentSettings, ): pass diff --git a/src/app/core/db/database.py b/src/app/core/db/database.py index 453e509..076efa5 100644 --- a/src/app/core/db/database.py +++ b/src/app/core/db/database.py @@ -8,25 +8,19 @@ class Base(DeclarativeBase, MappedAsDataclass): pass + DATABASE_URI = settings.POSTGRES_URI DATABASE_PREFIX = settings.POSTGRES_ASYNC_PREFIX DATABASE_URL = f"{DATABASE_PREFIX}{DATABASE_URI}" -async_engine = create_async_engine( - DATABASE_URL, - echo=False, - future=True -) +async_engine = create_async_engine(DATABASE_URL, echo=False, future=True) + +local_session = sessionmaker(bind=async_engine, class_=AsyncSession, expire_on_commit=False) -local_session = sessionmaker( - bind=async_engine, - class_=AsyncSession, - expire_on_commit=False -) async def async_get_db() -> AsyncSession: async_session = local_session - + async with async_session() as db: yield db await db.commit() diff --git a/src/app/core/db/token_blacklist.py b/src/app/core/db/token_blacklist.py index 7f069dc..6a387ae 100644 --- a/src/app/core/db/token_blacklist.py +++ b/src/app/core/db/token_blacklist.py @@ -9,8 +9,6 @@ class TokenBlacklist(Base): __tablename__ = "token_blacklist" - id: Mapped[int] = mapped_column( - "id", autoincrement=True, nullable=False, unique=True, primary_key=True, init=False - ) + id: Mapped[int] = mapped_column("id", autoincrement=True, nullable=False, unique=True, primary_key=True, init=False) token: Mapped[str] = mapped_column(String, unique=True, index=True) expires_at: Mapped[datetime] = mapped_column(DateTime) diff --git a/src/app/core/logger.py b/src/app/core/logger.py index dc71c2d..91b35a1 100644 --- a/src/app/core/logger.py +++ b/src/app/core/logger.py @@ -2,14 +2,14 @@ import os from logging.handlers import RotatingFileHandler -LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs') +LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs") if not os.path.exists(LOG_DIR): os.makedirs(LOG_DIR) -LOG_FILE_PATH = os.path.join(LOG_DIR, 'app.log') +LOG_FILE_PATH = os.path.join(LOG_DIR, "app.log") LOGGING_LEVEL = logging.INFO -LOGGING_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' +LOGGING_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" logging.basicConfig(level=LOGGING_LEVEL, format=LOGGING_FORMAT) @@ -17,4 +17,4 @@ file_handler.setLevel(LOGGING_LEVEL) file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT)) -logging.getLogger('').addHandler(file_handler) +logging.getLogger("").addHandler(file_handler) diff --git a/src/app/core/schemas.py b/src/app/core/schemas.py index 605be29..e0a4472 100644 --- a/src/app/core/schemas.py +++ b/src/app/core/schemas.py @@ -10,20 +10,21 @@ class HealthCheck(BaseModel): version: str description: str + # -------------- mixins -------------- class UUIDSchema(BaseModel): uuid: uuid_pkg.UUID = Field(default_factory=uuid_pkg.uuid4) class TimestampSchema(BaseModel): - created_at: datetime = Field(default_factory=lambda: datetime.now(UTC).replace(tzinfo=None)) + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC).replace(tzinfo=None)) updated_at: datetime = Field(default=None) @field_serializer("created_at") def serialize_dt(self, created_at: datetime | None, _info: Any) -> str | None: if created_at is not None: return created_at.isoformat() - + return None @field_serializer("updated_at") @@ -33,15 +34,16 @@ def serialize_updated_at(self, updated_at: datetime | None, _info: Any) -> str | return None + class PersistentDeletion(BaseModel): deleted_at: datetime | None = Field(default=None) is_deleted: bool = False - @field_serializer('deleted_at') + @field_serializer("deleted_at") def serialize_dates(self, deleted_at: datetime | None, _info: Any) -> str | None: if deleted_at is not None: return deleted_at.isoformat() - + return None diff --git a/src/app/core/setup.py b/src/app/core/setup.py index d731a5e..aa08415 100644 --- a/src/app/core/setup.py +++ b/src/app/core/setup.py @@ -80,19 +80,19 @@ def create_application( RedisRateLimiterSettings, EnvironmentSettings, ], + create_tables_on_start: bool = True, **kwargs: Any, ) -> FastAPI: """ Creates and configures a FastAPI application based on the provided settings. - This function initializes a FastAPI application, then conditionally configures - it with various settings and handlers. The specific configuration is determined - by the type of the `settings` object provided. + This function initializes a FastAPI application and configures it with various settings + and handlers based on the type of the `settings` object provided. Parameters ---------- router : APIRouter - The APIRouter object that contains the routes to be included in the FastAPI application. + The APIRouter object containing the routes to be included in the FastAPI application. settings An instance representing the settings for configuring the FastAPI application. @@ -103,19 +103,27 @@ def create_application( - RedisCacheSettings: Sets up event handlers for creating and closing a Redis cache pool. - ClientSideCacheSettings: Integrates middleware for client-side caching. - RedisQueueSettings: Sets up event handlers for creating and closing a Redis queue pool. + - RedisRateLimiterSettings: Sets up event handlers for creating and closing a Redis rate limiter pool. - EnvironmentSettings: Conditionally sets documentation URLs and integrates custom routes for API documentation - based on environment type. + based on the environment type. + + create_tables_on_start : bool + A flag to indicate whether to create database tables on application startup. + Defaults to True. **kwargs - Extra keyword arguments passed directly to the FastAPI constructor. + Additional keyword arguments passed directly to the FastAPI constructor. Returns ------- FastAPI A fully configured FastAPI application instance. + The function configures the FastAPI application with different features and behaviors + based on the provided settings. It includes setting up database connections, Redis pools + for caching, queue, and rate limiting, client-side caching, and customizing the API documentation + based on the environment settings. """ - # --- before creating application --- if isinstance(settings, AppSettings): to_update = { @@ -135,7 +143,7 @@ def create_application( application.include_router(router) application.add_event_handler("startup", set_threadpool_tokens) - if isinstance(settings, DatabaseSettings): + if isinstance(settings, DatabaseSettings) and create_tables_on_start: application.add_event_handler("startup", create_tables) if isinstance(settings, RedisCacheSettings): diff --git a/src/app/core/utils/cache.py b/src/app/core/utils/cache.py index 9742c96..7a1d4f4 100644 --- a/src/app/core/utils/cache.py +++ b/src/app/core/utils/cache.py @@ -1,7 +1,8 @@ import functools import json import re -from typing import Any, Callable, Dict, List, Tuple, Union +from collections.abc import Callable +from typing import Any, Dict, List, Tuple, Union from fastapi import Request, Response from fastapi.encoders import jsonable_encoder diff --git a/src/app/core/utils/rate_limit.py b/src/app/core/utils/rate_limit.py index 80dba8b..1c1ba7c 100644 --- a/src/app/core/utils/rate_limit.py +++ b/src/app/core/utils/rate_limit.py @@ -11,13 +11,8 @@ pool: ConnectionPool | None = None client: Redis | None = None -async def is_rate_limited( - db: AsyncSession, - user_id: int, - path: str, - limit: int, - period: int -) -> bool: + +async def is_rate_limited(db: AsyncSession, user_id: int, path: str, limit: int, period: int) -> bool: if client is None: logger.error("Redis client is not initialized.") raise Exception("Redis client is not initialized.") diff --git a/src/app/models/post.py b/src/app/models/post.py index 94bb514..d183bf5 100644 --- a/src/app/models/post.py +++ b/src/app/models/post.py @@ -11,20 +11,14 @@ class Post(Base): __tablename__ = "post" - id: Mapped[int] = mapped_column( - "id", autoincrement=True, nullable=False, unique=True, primary_key=True, init=False - ) + id: Mapped[int] = mapped_column("id", autoincrement=True, nullable=False, unique=True, primary_key=True, init=False) created_by_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), index=True) title: Mapped[str] = mapped_column(String(30)) text: Mapped[str] = mapped_column(String(63206)) - uuid: Mapped[uuid_pkg.UUID] = mapped_column( - default_factory=uuid_pkg.uuid4, primary_key=True, unique=True - ) + uuid: Mapped[uuid_pkg.UUID] = mapped_column(default_factory=uuid_pkg.uuid4, primary_key=True, unique=True) media_url: Mapped[str | None] = mapped_column(String, default=None) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default_factory=lambda: datetime.now(UTC) - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default_factory=lambda: datetime.now(UTC)) updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), default=None) deleted_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), default=None) is_deleted: Mapped[bool] = mapped_column(default=False, index=True) diff --git a/src/app/models/rate_limit.py b/src/app/models/rate_limit.py index b30ad2c..5b7ce50 100644 --- a/src/app/models/rate_limit.py +++ b/src/app/models/rate_limit.py @@ -9,17 +9,13 @@ class RateLimit(Base): __tablename__ = "rate_limit" - - id: Mapped[int] = mapped_column( - "id", autoincrement=True, nullable=False, unique=True, primary_key=True, init=False - ) + + id: Mapped[int] = mapped_column("id", autoincrement=True, nullable=False, unique=True, primary_key=True, init=False) tier_id: Mapped[int] = mapped_column(ForeignKey("tier.id"), index=True) name: Mapped[str] = mapped_column(String, nullable=False, unique=True) path: Mapped[str] = mapped_column(String, nullable=False) limit: Mapped[int] = mapped_column(Integer, nullable=False) period: Mapped[int] = mapped_column(Integer, nullable=False) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default_factory=lambda: datetime.now(UTC) - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default_factory=lambda: datetime.now(UTC)) updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), default=None) diff --git a/src/app/models/tier.py b/src/app/models/tier.py index 2230462..72aa0af 100644 --- a/src/app/models/tier.py +++ b/src/app/models/tier.py @@ -10,12 +10,8 @@ class Tier(Base): __tablename__ = "tier" - id: Mapped[int] = mapped_column( - "id", autoincrement=True, nullable=False, unique=True, primary_key=True, init=False - ) + id: Mapped[int] = mapped_column("id", autoincrement=True, nullable=False, unique=True, primary_key=True, init=False) name: Mapped[str] = mapped_column(String, nullable=False, unique=True) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default_factory=lambda: datetime.now(UTC) - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default_factory=lambda: datetime.now(UTC)) updated_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), default=None) diff --git a/src/app/schemas/post.py b/src/app/schemas/post.py index 2cae10a..d9eb509 100644 --- a/src/app/schemas/post.py +++ b/src/app/schemas/post.py @@ -7,59 +7,36 @@ class PostBase(BaseModel): - title: Annotated[ - str, - Field(min_length=2, max_length=30, examples=["This is my post"]) - ] - text: Annotated[ - str, - Field(min_length=1, max_length=63206, examples=["This is the content of my post."]) - ] - + title: Annotated[str, Field(min_length=2, max_length=30, examples=["This is my post"])] + text: Annotated[str, Field(min_length=1, max_length=63206, examples=["This is the content of my post."])] + class Post(TimestampSchema, PostBase, UUIDSchema, PersistentDeletion): media_url: Annotated[ - str | None, - Field( - pattern=r"^(https?|ftp)://[^\s/$.?#].[^\s]*$", - examples=["https://www.postimageurl.com"], - default=None - ), + str | None, + Field(pattern=r"^(https?|ftp)://[^\s/$.?#].[^\s]*$", examples=["https://www.postimageurl.com"], default=None), ] created_by_user_id: int class PostRead(BaseModel): id: int - title: Annotated[ - str, - Field(min_length=2, max_length=30, examples=["This is my post"]) - ] - text: Annotated[ - str, - Field(min_length=1, max_length=63206, examples=["This is the content of my post."]) - ] + title: Annotated[str, Field(min_length=2, max_length=30, examples=["This is my post"])] + text: Annotated[str, Field(min_length=1, max_length=63206, examples=["This is the content of my post."])] media_url: Annotated[ - str | None, - Field( - examples=["https://www.postimageurl.com"], - default=None - ), + str | None, + Field(examples=["https://www.postimageurl.com"], default=None), ] created_by_user_id: int created_at: datetime class PostCreate(PostBase): - model_config = ConfigDict(extra='forbid') - + model_config = ConfigDict(extra="forbid") + media_url: Annotated[ - str | None, - Field( - pattern=r"^(https?|ftp)://[^\s/$.?#].[^\s]*$", - examples=["https://www.postimageurl.com"], - default=None - ), + str | None, + Field(pattern=r"^(https?|ftp)://[^\s/$.?#].[^\s]*$", examples=["https://www.postimageurl.com"], default=None), ] @@ -68,33 +45,16 @@ class PostCreateInternal(PostCreate): class PostUpdate(BaseModel): - model_config = ConfigDict(extra='forbid') - - title: Annotated[ - str | None, - Field( - min_length=2, - max_length=30, - examples=["This is my updated post"], - default=None - ) - ] + model_config = ConfigDict(extra="forbid") + + title: Annotated[str | None, Field(min_length=2, max_length=30, examples=["This is my updated post"], default=None)] text: Annotated[ str | None, - Field( - min_length=1, - max_length=63206, - examples=["This is the updated content of my post."], - default=None - ) + Field(min_length=1, max_length=63206, examples=["This is the updated content of my post."], default=None), ] media_url: Annotated[ str | None, - Field( - pattern=r"^(https?|ftp)://[^\s/$.?#].[^\s]*$", - examples=["https://www.postimageurl.com"], - default=None - ) + Field(pattern=r"^(https?|ftp)://[^\s/$.?#].[^\s]*$", examples=["https://www.postimageurl.com"], default=None), ] @@ -103,7 +63,7 @@ class PostUpdateInternal(PostUpdate): class PostDelete(BaseModel): - model_config = ConfigDict(extra='forbid') + model_config = ConfigDict(extra="forbid") is_deleted: bool deleted_at: datetime diff --git a/src/app/schemas/rate_limit.py b/src/app/schemas/rate_limit.py index 5dc7335..5f3e4d1 100644 --- a/src/app/schemas/rate_limit.py +++ b/src/app/schemas/rate_limit.py @@ -11,33 +11,18 @@ def sanitize_path(path: str) -> str: class RateLimitBase(BaseModel): - path: Annotated[ - str, - Field(examples=["users"]) - ] - limit: Annotated[ - int, - Field(examples=[5]) - ] - period: Annotated[ - int, - Field(examples=[60]) - ] - - @field_validator('path') + path: Annotated[str, Field(examples=["users"])] + limit: Annotated[int, Field(examples=[5])] + period: Annotated[int, Field(examples=[60])] + + @field_validator("path") def validate_and_sanitize_path(cls, v: str) -> str: return sanitize_path(v) class RateLimit(TimestampSchema, RateLimitBase): tier_id: int - name: Annotated[ - str | None, - Field( - default=None, - examples=["users:5:60"] - ) - ] + name: Annotated[str | None, Field(default=None, examples=["users:5:60"])] class RateLimitRead(RateLimitBase): @@ -47,15 +32,9 @@ class RateLimitRead(RateLimitBase): class RateLimitCreate(RateLimitBase): - model_config = ConfigDict(extra='forbid') - - name: Annotated[ - str | None, - Field( - default=None, - examples=["api_v1_users:5:60"] - ) - ] + model_config = ConfigDict(extra="forbid") + + name: Annotated[str | None, Field(default=None, examples=["api_v1_users:5:60"])] class RateLimitCreateInternal(RateLimitCreate): @@ -68,9 +47,9 @@ class RateLimitUpdate(BaseModel): period: int | None = None name: str | None = None - @field_validator('path') + @field_validator("path") def validate_and_sanitize_path(cls, v: str) -> str: - return sanitize_path(v) if v is not None else None + return sanitize_path(v) if v is not None else None class RateLimitUpdateInternal(RateLimitUpdate): diff --git a/src/app/schemas/tier.py b/src/app/schemas/tier.py index a24ab71..2e6f81f 100644 --- a/src/app/schemas/tier.py +++ b/src/app/schemas/tier.py @@ -7,10 +7,7 @@ class TierBase(BaseModel): - name: Annotated[ - str, - Field(examples=["free"]) - ] + name: Annotated[str, Field(examples=["free"])] class Tier(TimestampSchema, TierBase): diff --git a/src/app/schemas/user.py b/src/app/schemas/user.py index f1e1dff..7592795 100644 --- a/src/app/schemas/user.py +++ b/src/app/schemas/user.py @@ -7,25 +7,13 @@ class UserBase(BaseModel): - name: Annotated[ - str, - Field(min_length=2, max_length=30, examples=["User Userson"]) - ] - username: Annotated[ - str, - Field(min_length=2, max_length=20, pattern=r"^[a-z0-9]+$", examples=["userson"]) - ] - email: Annotated[ - EmailStr, - Field(examples=["user.userson@example.com"]) - ] + name: Annotated[str, Field(min_length=2, max_length=30, examples=["User Userson"])] + username: Annotated[str, Field(min_length=2, max_length=20, pattern=r"^[a-z0-9]+$", examples=["userson"])] + email: Annotated[EmailStr, Field(examples=["user.userson@example.com"])] class User(TimestampSchema, UserBase, UUIDSchema, PersistentDeletion): - profile_image_url: Annotated[ - str, - Field(default="https://www.profileimageurl.com") - ] + profile_image_url: Annotated[str, Field(default="https://www.profileimageurl.com")] hashed_password: str is_superuser: bool = False tier_id: int | None = None @@ -33,30 +21,18 @@ class User(TimestampSchema, UserBase, UUIDSchema, PersistentDeletion): class UserRead(BaseModel): id: int - - name: Annotated[ - str, - Field(min_length=2, max_length=30, examples=["User Userson"]) - ] - username: Annotated[ - str, - Field(min_length=2, max_length=20, pattern=r"^[a-z0-9]+$", examples=["userson"]) - ] - email: Annotated[ - EmailStr, - Field(examples=["user.userson@example.com"]) - ] + + name: Annotated[str, Field(min_length=2, max_length=30, examples=["User Userson"])] + username: Annotated[str, Field(min_length=2, max_length=20, pattern=r"^[a-z0-9]+$", examples=["userson"])] + email: Annotated[EmailStr, Field(examples=["user.userson@example.com"])] profile_image_url: str tier_id: int | None class UserCreate(UserBase): - model_config = ConfigDict(extra='forbid') + model_config = ConfigDict(extra="forbid") - password: Annotated[ - str, - Field(pattern=r"^.{8,}|[0-9]+|[A-Z]+|[a-z]+|[^a-zA-Z0-9]+$", examples=["Str1ngst!"]) - ] + password: Annotated[str, Field(pattern=r"^.{8,}|[0-9]+|[A-Z]+|[a-z]+|[^a-zA-Z0-9]+$", examples=["Str1ngst!"])] class UserCreateInternal(UserBase): @@ -64,41 +40,18 @@ class UserCreateInternal(UserBase): class UserUpdate(BaseModel): - model_config = ConfigDict(extra='forbid') + model_config = ConfigDict(extra="forbid") - name: Annotated[ - Optional[str], - Field( - min_length=2, - max_length=30, - examples=["User Userberg"], - default=None - ) - ] + name: Annotated[Optional[str], Field(min_length=2, max_length=30, examples=["User Userberg"], default=None)] username: Annotated[ - Optional[str], - Field( - min_length=2, - max_length=20, - pattern=r"^[a-z0-9]+$", - examples=["userberg"], - default=None - ) - ] - email: Annotated[ - Optional[EmailStr], - Field( - examples=["user.userberg@example.com"], - default=None - ) + Optional[str], Field(min_length=2, max_length=20, pattern=r"^[a-z0-9]+$", examples=["userberg"], default=None) ] + email: Annotated[Optional[EmailStr], Field(examples=["user.userberg@example.com"], default=None)] profile_image_url: Annotated[ Optional[str], Field( - pattern=r"^(https?|ftp)://[^\s/$.?#].[^\s]*$", - examples=["https://www.profileimageurl.com"], - default=None - ) + pattern=r"^(https?|ftp)://[^\s/$.?#].[^\s]*$", examples=["https://www.profileimageurl.com"], default=None + ), ] @@ -111,7 +64,7 @@ class UserTierUpdate(BaseModel): class UserDelete(BaseModel): - model_config = ConfigDict(extra='forbid') + model_config = ConfigDict(extra="forbid") is_deleted: bool deleted_at: datetime diff --git a/src/app/worker.py b/src/app/worker.py index 2e65716..1254284 100644 --- a/src/app/worker.py +++ b/src/app/worker.py @@ -30,10 +30,7 @@ async def shutdown(ctx: Worker) -> None: # -------- class -------- class WorkerSettings: functions = [sample_background_task] - redis_settings = RedisSettings( - host=REDIS_QUEUE_HOST, - port=REDIS_QUEUE_PORT - ) + redis_settings = RedisSettings(host=REDIS_QUEUE_HOST, port=REDIS_QUEUE_PORT) on_startup = startup on_shutdown = shutdown handle_signals = False diff --git a/src/scripts/create_first_superuser.py b/src/scripts/create_first_superuser.py index f0944d5..6d35df3 100644 --- a/src/scripts/create_first_superuser.py +++ b/src/scripts/create_first_superuser.py @@ -20,11 +20,12 @@ async def create_first_user(session: AsyncSession) -> None: query = select(User).filter_by(email=email) result = await session.execute(query) user = result.scalar_one_or_none() - + if user is None: metadata = MetaData() user_table = Table( - "user", metadata, + "user", + metadata, Column("id", Integer, primary_key=True, autoincrement=True, nullable=False), Column("name", String(30), nullable=False), Column("username", String(20), nullable=False, unique=True, index=True), @@ -32,33 +33,33 @@ async def create_first_user(session: AsyncSession) -> None: Column("hashed_password", String, nullable=False), Column("profile_image_url", String, default="https://profileimageurl.com"), Column("uuid", UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True), - Column("created_at", DateTime(timezone=True), default=lambda: datetime.now(UTC), nullable=False), + Column("created_at", DateTime(timezone=True), default=lambda: datetime.now(UTC), nullable=False), Column("updated_at", DateTime), Column("deleted_at", DateTime), Column("is_deleted", Boolean, default=False, index=True), Column("is_superuser", Boolean, default=False), - Column("tier_id", Integer, ForeignKey("tier.id"), index=True) + Column("tier_id", Integer, ForeignKey("tier.id"), index=True), ) - data = { - 'name': name, - 'email': email, - 'username': username, - 'hashed_password': hashed_password, - 'is_superuser': True + "name": name, + "email": email, + "username": username, + "hashed_password": hashed_password, + "is_superuser": True, } - stmt = insert(user_table).values(data) async with async_engine.connect() as conn: await conn.execute(stmt) await conn.commit() + async def main(): async with local_session() as session: await create_first_user(session) + if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(main()) diff --git a/src/scripts/create_first_tier.py b/src/scripts/create_first_tier.py index 5fd0c28..28c0334 100644 --- a/src/scripts/create_first_tier.py +++ b/src/scripts/create_first_tier.py @@ -9,22 +9,22 @@ async def create_first_tier(session: AsyncSession) -> None: tier_name = config("TIER_NAME", default="free") - + query = select(Tier).where(Tier.name == tier_name) result = await session.execute(query) tier = result.scalar_one_or_none() - + if tier is None: - session.add( - Tier(name=tier_name) - ) - + session.add(Tier(name=tier_name)) + await session.commit() + async def main(): async with local_session() as session: await create_first_tier(session) + if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(main()) diff --git a/tests/helper.py b/tests/helper.py index 112196a..2163593 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -4,9 +4,6 @@ def _get_token(username: str, password: str, client: TestClient): return client.post( "/api/v1/login", - data={ - "username": username, - "password": password - }, - headers={"content-type": "application/x-www-form-urlencoded"} + data={"username": username, "password": password}, + headers={"content-type": "application/x-www-form-urlencoded"}, ) diff --git a/tests/test_user.py b/tests/test_user.py index edfdcf6..954829f 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -15,68 +15,49 @@ client = TestClient(app) + def test_post_user(client: TestClient) -> None: response = client.post( "/api/v1/user", - json = { - "name": test_name, - "username": test_username, - "email": test_email, - "password": test_password - } + json={"name": test_name, "username": test_username, "email": test_email, "password": test_password}, ) assert response.status_code == 201 + def test_get_user(client: TestClient) -> None: - response = client.get( - f"/api/v1/user/{test_username}" - ) + response = client.get(f"/api/v1/user/{test_username}") assert response.status_code == 200 + def test_get_multiple_users(client: TestClient) -> None: - response = client.get( - "/api/v1/users" - ) + response = client.get("/api/v1/users") assert response.status_code == 200 + def test_update_user(client: TestClient) -> None: - token = _get_token( - username=test_username, - password=test_password, - client=client - ) - + token = _get_token(username=test_username, password=test_password, client=client) + response = client.patch( f"/api/v1/user/{test_username}", - json={ - "name": f"Updated {test_name}" - }, - headers={"Authorization": f'Bearer {token.json()["access_token"]}'} + json={"name": f"Updated {test_name}"}, + headers={"Authorization": f'Bearer {token.json()["access_token"]}'}, ) assert response.status_code == 200 + def test_delete_user(client: TestClient) -> None: - token = _get_token( - username=test_username, - password=test_password, - client=client - ) + token = _get_token(username=test_username, password=test_password, client=client) response = client.delete( - f"/api/v1/user/{test_username}", - headers={"Authorization": f'Bearer {token.json()["access_token"]}'} + f"/api/v1/user/{test_username}", headers={"Authorization": f'Bearer {token.json()["access_token"]}'} ) assert response.status_code == 200 + def test_delete_db_user(client: TestClient) -> None: - token = _get_token( - username=admin_username, - password=admin_password, - client=client - ) + token = _get_token(username=admin_username, password=admin_password, client=client) response = client.delete( - f"/api/v1/db_user/{test_username}", - headers={"Authorization": f'Bearer {token.json()["access_token"]}'} + f"/api/v1/db_user/{test_username}", headers={"Authorization": f'Bearer {token.json()["access_token"]}'} ) assert response.status_code == 200