Skip to content

Commit

Permalink
Merge branch 'cohere-ai:main' into integration
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaytonSmith authored Jun 13, 2024
2 parents a9a7639 + 60f25d9 commit 9b3858f
Show file tree
Hide file tree
Showing 40 changed files with 751 additions and 316 deletions.
7 changes: 6 additions & 1 deletion .env-template
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,13 @@ USE_AGENTS_VIEW=False
USE_COMMUNITY_FEATURES='True'

# For setting up authentication, see: docs/auth_guide.md
JWT_SECRET_KEY=
JWT_SECRET_KEY=<See auth.guide.md on how to generate a secure one>

# Google OAuth
GOOGLE_CLIENT_ID=<GOOGLE_CLIENT_ID>
GOOGLE_CLIENT_SECRET=<GOOGLE_CLIENT_SECRET>

# OpenID Connect
OIDC_CLIENT_ID=<OIDC_CLIENT_ID>
OIDC_CLIENT_SECRET=<OIDC_CLIENT_SECRET>
OIDC_CONFIG_ENDPOINT=<OIDC_CONFIG_ENDPOINT>
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ first-run:
win-first-run:
make win-setup
make migrate
make dev
make dev
349 changes: 184 additions & 165 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ itsdangerous = "^2.2.0"
bcrypt = "^4.1.2"
pypdf = "^4.2.0"
pyjwt = "^2.8.0"
pydantic-settings = "^2.3.1"

[tool.poetry.group.dev]
optional = true
Expand Down
39 changes: 39 additions & 0 deletions src/backend/alembic/versions/28763d200b29_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""empty message
Revision ID: 28763d200b29
Revises: a9b07acef4e8
Create Date: 2024-06-10 20:32:41.903400
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "28763d200b29"
down_revision: Union[str, None] = "a9b07acef4e8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"blacklist",
sa.Column("token_id", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=True),
sa.Column("updated_at", sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("blacklist_token_id", "blacklist", ["token_id"], unique=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("blacklist_token_id", table_name="blacklist")
op.drop_table("blacklist")
# ### end Alembic commands ###
49 changes: 49 additions & 0 deletions src/backend/alembic/versions/922e874930bf_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""empty message
Revision ID: 922e874930bf
Revises: 28763d200b29
Create Date: 2024-06-12 21:19:12.204875
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "922e874930bf"
down_revision: Union[str, None] = "28763d200b29"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"agents",
sa.Column(
"tools",
postgresql.ARRAY(
sa.Enum(
"Wiki_Retriever_LangChain",
"Search_File",
"Read_File",
"Python_Interpreter",
"Calculator",
"Tavily_Internet_Search",
name="toolname",
native_enum=False,
)
),
nullable=False,
),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("agents", "tools")
# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion src/backend/config/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from backend.services.auth import BasicAuthentication, GoogleOAuth
from backend.services.auth import BasicAuthentication, GoogleOAuth, OpenIDConnect

# Add Auth strategy classes here to enable them
# Ex: [BasicAuthentication]
Expand Down
34 changes: 34 additions & 0 deletions src/backend/crud/blacklist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from sqlalchemy.orm import Session

from backend.database_models.blacklist import Blacklist


def create_blacklist(db: Session, blacklist: Blacklist) -> Blacklist:
""" "
Create a blacklist token.
Args:
db (Session): Database session.
blacklist (Blacklist): Blacklist data to be created.
Returns:
Blacklist: Created blacklist.
"""
db.add(blacklist)
db.commit()
db.refresh(blacklist)
return blacklist


def get_blacklist(db: Session, token_id: str) -> Blacklist:
"""
Get a blacklist token by token_id column.
Args:
db (Session): Database session.
token_id (str): Token ID.
Returns:
Blacklist: Blacklist with the given token_id.
"""
return db.query(Blacklist).filter(Blacklist.token_id == token_id).first()
1 change: 1 addition & 0 deletions src/backend/database_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from backend.database_models.agent import *
from backend.database_models.base import *
from backend.database_models.blacklist import *
from backend.database_models.citation import *
from backend.database_models.conversation import *
from backend.database_models.database import *
Expand Down
8 changes: 6 additions & 2 deletions src/backend/database_models/agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from enum import StrEnum

from sqlalchemy import Enum, Float, Integer, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.orm import Mapped, mapped_column

from backend.config.tools import ToolName
from backend.database_models.base import Base


Expand All @@ -28,7 +30,9 @@ class Agent(Base):
description: Mapped[str] = mapped_column(Text, default="", nullable=False)
preamble: Mapped[str] = mapped_column(Text, default="", nullable=False)
temperature: Mapped[float] = mapped_column(Float, default=0.3, nullable=False)
# tool: Mapped[List["Tool"]] = relationship()
tools: Mapped[list[ToolName]] = mapped_column(
ARRAY(Enum(ToolName, native_enum=False)), default=[], nullable=False
)

# TODO @scott-cohere: eventually switch to Fkey when new deployment tables are implemented
# TODO @scott-cohere: deployments have different names for models, need to implement mapping later
Expand Down
16 changes: 16 additions & 0 deletions src/backend/database_models/blacklist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from sqlalchemy import Index, String
from sqlalchemy.orm import Mapped, mapped_column

from backend.database_models.base import Base


class Blacklist(Base):
"""
Table that contains the list of JWT access tokens that are blacklisted during logout.
"""

__tablename__ = "blacklist"

token_id: Mapped[str] = mapped_column(String)

__table_args__ = (Index("blacklist_token_id", token_id),)
2 changes: 1 addition & 1 deletion src/backend/database_models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List

from sqlalchemy import Boolean, Enum, ForeignKey, Index, String
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column, relationship

from backend.database_models.base import Base
from backend.database_models.citation import Citation
Expand Down
Empty file.
2 changes: 1 addition & 1 deletion src/backend/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def create_agent(session: DBSessionDep, agent: CreateAgent, request: Request):
user_id=user_id,
model=agent.model,
deployment=agent.deployment,
# tools=request.json().get("tools"),
tools=agent.tools,
)

return agent_crud.create_agent(session, agent_data)
Expand Down
87 changes: 57 additions & 30 deletions src/backend/routers/auth.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Union

from authlib.integrations.starlette_client import OAuthError
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, HTTPException
from starlette.requests import Request
from starlette.responses import RedirectResponse

from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING
from backend.config.routers import RouterName
from backend.database_models import get_session
from backend.database_models.database import DBSessionDep
from backend.schemas.auth import Auth, Login
from backend.schemas.auth import JWTResponse, ListAuthStrategy, Login, Logout
from backend.services.auth import GoogleOAuth, OpenIDConnect
from backend.services.auth.jwt import JWTService
from backend.services.auth.utils import (
get_or_create_user,
Expand All @@ -18,8 +19,8 @@
router.name = RouterName.AUTH


@router.get("/auth_strategies")
def get_strategies():
@router.get("/auth_strategies", response_model=list[ListAuthStrategy])
def get_strategies() -> list[ListAuthStrategy]:
"""
Retrieves the currently enabled list of Authentication strategies.
Expand All @@ -34,7 +35,7 @@ def get_strategies():
return strategies


@router.post("/login")
@router.post("/login", response_model=Union[JWTResponse, None])
async def login(request: Request, login: Login, session: DBSessionDep):
"""
Logs user in and either:
Expand Down Expand Up @@ -75,7 +76,7 @@ async def login(request: Request, login: Login, session: DBSessionDep):
# Login with redirect to /auth
if strategy.SHOULD_AUTH_REDIRECT:
# Fetch endpoint with method name
redirect_uri = request.url_for("authenticate")
redirect_uri = request.url_for(strategy.REDIRECT_METHOD_NAME)
return await strategy.login(request, redirect_uri)
# Login with email/password and set session directly
else:
Expand All @@ -91,23 +92,65 @@ async def login(request: Request, login: Login, session: DBSessionDep):
return {"token": token}


@router.post("/auth")
async def authenticate(request: Request, auth: Auth, session: DBSessionDep):
@router.get("/google/auth", response_model=JWTResponse)
async def google_authenticate(request: Request, session: DBSessionDep):
"""
Authentication endpoint used for OAuth strategies. Logs the user in the redirect environment and then
sets the current session with the user returned from the auth token.
Callback authentication endpoint used for Google OAuth after redirecting to
the service's login screen.
Args:
request (Request): current Request object.
login (Login): Login payload.
Returns:
RedirectResponse: On success.
Raises:
HTTPException: If authentication fails, or strategy is invalid.
"""
strategy_name = auth.strategy
strategy_name = GoogleOAuth.NAME

return await authenticate(request, session, strategy_name)


@router.get("/oidc/auth", response_model=JWTResponse)
async def oidc_authenticate(request: Request, session: DBSessionDep):
"""
Callback authentication endpoint used for OIDC after redirecting to
the service's login screen.
Args:
request (Request): current Request object.
Returns:
RedirectResponse: On success.
Raises:
HTTPException: If authentication fails, or strategy is invalid.
"""
strategy_name = OpenIDConnect.NAME

return await authenticate(request, session, strategy_name)


@router.get("/logout", response_model=Logout)
async def logout(request: Request):
"""
Logs out the current user.
Args:
request (Request): current Request object.
Returns:
dict: Empty on success
"""
# TODO: Design blacklist

return {}


async def authenticate(
request: Request, session: DBSessionDep, strategy_name: str
) -> JWTResponse:
if not is_enabled_authentication_strategy(strategy_name):
raise HTTPException(
status_code=404, detail=f"Invalid Authentication strategy: {strategy_name}."
Expand Down Expand Up @@ -136,19 +179,3 @@ async def authenticate(request: Request, auth: Auth, session: DBSessionDep):
token = JWTService().create_and_encode_jwt(user)

return {"token": token}


@router.get("/logout")
async def logout(request: Request):
"""
Logs out the current user.
Args:
request (Request): current Request object.
Returns:
dict: Empty on success
"""
# TODO: Design blacklist

return {}
Loading

0 comments on commit 9b3858f

Please sign in to comment.