Skip to content

Commit

Permalink
feat(assistant): cdp (#3305)
Browse files Browse the repository at this point in the history
# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---------

Co-authored-by: Zewed <[email protected]>
  • Loading branch information
StanGirard and Zewed authored Oct 3, 2024
1 parent c399139 commit b767f19
Show file tree
Hide file tree
Showing 110 changed files with 6,315 additions and 706 deletions.
3 changes: 3 additions & 0 deletions backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ RUN apt-get clean && apt-get update && apt-get install -y \
libreoffice \
libpq-dev \
gcc \
libhdf5-serial-dev \
pandoc && \
rm -rf /var/lib/apt/lists/* && apt-get clean

Expand All @@ -46,6 +47,8 @@ COPY core/pyproject.toml core/README.md ./core/
COPY core/quivr_core/__init__.py ./core/quivr_core/__init__.py
COPY worker/pyproject.toml worker/README.md ./worker/
COPY worker/quivr_worker/__init__.py ./worker/quivr_worker/__init__.py
COPY worker/diff-assistant/pyproject.toml worker/diff-assistant/README.md ./worker/diff-assistant/
COPY worker/diff-assistant/quivr_diff_assistant/__init__.py ./worker/diff-assistant/quivr_diff_assistant/__init__.py
COPY core/MegaParse/pyproject.toml core/MegaParse/README.md ./core/MegaParse/
COPY core/MegaParse/megaparse/__init__.py ./core/MegaParse/megaparse/__init__.py

Expand Down
3 changes: 3 additions & 0 deletions backend/Dockerfile.dev
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ RUN apt-get clean && apt-get update && apt-get install -y \
libreoffice \
libpq-dev \
gcc \
libhdf5-serial-dev \
pandoc && \
rm -rf /var/lib/apt/lists/* && apt-get clean

Expand All @@ -33,6 +34,8 @@ COPY core/pyproject.toml core/README.md ./core/
COPY core/quivr_core/__init__.py ./core/quivr_core/__init__.py
COPY worker/pyproject.toml worker/README.md ./worker/
COPY worker/quivr_worker/__init__.py ./worker/quivr_worker/__init__.py
COPY worker/diff-assistant/pyproject.toml worker/diff-assistant/README.md ./worker/diff-assistant/
COPY worker/diff-assistant/quivr_diff_assistant/__init__.py ./worker/diff-assistant/quivr_diff_assistant/__init__.py
COPY core/MegaParse/pyproject.toml core/MegaParse/README.md ./core/MegaParse/
COPY core/MegaParse/megaparse/__init__.py ./core/MegaParse/megaparse/__init__.py

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Annotated, List
from uuid import uuid4

from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile

from quivr_api.celery_config import celery
from quivr_api.logger import get_logger
Expand All @@ -16,6 +16,7 @@
from quivr_api.modules.assistant.entity.assistant_entity import (
AssistantSettings,
)
from quivr_api.modules.assistant.entity.task_entity import TaskMetadata
from quivr_api.modules.assistant.services.tasks_service import TasksService
from quivr_api.modules.dependencies import get_service
from quivr_api.modules.upload.service.upload_file import (
Expand Down Expand Up @@ -64,12 +65,15 @@ async def create_task(
current_user: UserIdentityDep,
tasks_service: TasksServiceDep,
request: Request,
input: InputAssistant,
input: str = File(...),
files: List[UploadFile] = None,
):
input = InputAssistant.model_validate_json(input)

assistant = next(
(assistant for assistant in assistants if assistant.id == input.id), None
)

if assistant is None:
raise HTTPException(status_code=404, detail="Assistant not found")

Expand All @@ -80,7 +84,7 @@ async def create_task(
raise HTTPException(status_code=400, detail=error)
else:
print("Assistant input is valid.")
notification_uuid = uuid4()
notification_uuid = f"{assistant.name}-{str(uuid4())[:8]}"

# Process files dynamically
for upload_file in files:
Expand All @@ -96,8 +100,14 @@ async def create_task(

task = CreateTask(
assistant_id=input.id,
pretty_id=str(notification_uuid),
assistant_name=assistant.name,
pretty_id=notification_uuid,
settings=input.model_dump(mode="json"),
task_metadata=TaskMetadata(
input_files=[file.filename for file in files]
).model_dump(mode="json")
if files
else None, # type: ignore
)

task_created = await tasks_service.create_task(task, current_user.id)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from quivr_api.modules.assistant.dto.inputs import InputAssistant
from quivr_api.modules.assistant.dto.outputs import (
AssistantOutput,
ConditionalInput,
InputBoolean,
InputFile,
Inputs,
InputSelectText,
Pricing,
)

Expand Down Expand Up @@ -166,10 +169,10 @@ def validate_assistant_input(

assistant1 = AssistantOutput(
id=1,
name="Assistant 1",
description="Assistant 1 description",
name="Compliance Check",
description="Allows analyzing the compliance of the information contained in documents against charter or regulatory requirements.",
pricing=Pricing(),
tags=["tag1", "tag2"],
tags=["Disabled"],
input_description="Input description",
output_description="Output description",
inputs=Inputs(
Expand All @@ -183,19 +186,66 @@ def validate_assistant_input(

assistant2 = AssistantOutput(
id=2,
name="Assistant 2",
description="Assistant 2 description",
name="Consistency Check",
description="Ensures that the information in one document is replicated identically in another document.",
pricing=Pricing(),
tags=["tag1", "tag2"],
tags=[],
input_description="Input description",
output_description="Output description",
icon_url="https://example.com/icon.png",
inputs=Inputs(
files=[
InputFile(key="file_1", description="File description"),
InputFile(key="file_2", description="File description"),
InputFile(key="Document 1", description="File description"),
InputFile(key="Document 2", description="File description"),
],
select_texts=[
InputSelectText(
key="DocumentsType",
description="Select Documents Type",
options=[
"Etiquettes VS Cahier des charges",
"Fiche Dev VS Cahier des charges",
],
),
],
),
)

# Catalogue entry with id 3: the "Difference Detection" assistant, which takes
# two files, one boolean flag and one select-text input.
assistant3 = AssistantOutput(
    id=3,
    name="Difference Detection",
    description="Highlights differences between one document and another after modifications.",
    pricing=Pricing(),
    tags=[],
    # The description texts and icon URL below look like placeholders.
    input_description="Input description",
    output_description="Output description",
    icon_url="https://example.com/icon.png",
    inputs=Inputs(
        # The two documents to compare.
        files=[
            InputFile(key="Document 1", description="File description"),
            InputFile(key="Document 2", description="File description"),
        ],
        booleans=[
            InputBoolean(
                key="Hard-to-Read Document?", description="Boolean description"
            ),
        ],
        select_texts=[
            InputSelectText(
                key="DocumentsType",
                description="Select Documents Type",
                options=["Etiquettes", "Cahier des charges"],
            ),
        ],
        # Rule linking the "DocumentsType" select to the "Hard-to-Read Document?"
        # boolean through an "equals" comparison with "Etiquettes".
        # NOTE(review): the rule is evaluated outside this definition; confirm
        # against the consumer which of `key` / `conditional_key` is the input
        # being compared and which is the input being shown or hidden.
        conditional_inputs=[
            ConditionalInput(
                key="DocumentsType",
                conditional_key="Hard-to-Read Document?",
                condition="equals",
                value="Etiquettes",
            ),
        ],
    ),
)

assistants = [assistant1, assistant2]
assistants = [assistant1, assistant2, assistant3]
4 changes: 3 additions & 1 deletion backend/api/quivr_api/modules/assistant/dto/inputs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional
from uuid import UUID

from pydantic import BaseModel, root_validator
Expand All @@ -7,7 +7,9 @@
class CreateTask(BaseModel):
    """Data needed to create a task for an assistant.

    The task repository copies these fields onto a new ``Task`` row together
    with the id of the requesting user.
    """

    # Display identifier; the create_task route builds it from the assistant
    # name followed by the first 8 characters of a random UUID.
    pretty_id: str
    # Numeric id of the assistant the task is created for.
    assistant_id: int
    # Name of that assistant, stored alongside its id.
    assistant_name: str
    # The create_task route passes the JSON-mode dump of the validated input here.
    settings: dict
    # Optional extra data; the create_task route passes a dumped TaskMetadata
    # when files were uploaded and None otherwise.
    task_metadata: Dict | None = None


class BrainInput(BaseModel):
Expand Down
16 changes: 16 additions & 0 deletions backend/api/quivr_api/modules/assistant/dto/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ class InputSelectNumber(BaseModel):
default: Optional[int] = None


class ConditionalInput(BaseModel):
    """
    A single rule making one input conditional on the value of another input.

    key: The key of the input that is conditional.
    conditional_key: The key that determines if the input is shown.
    condition: Name of the comparison to apply (optional).
    value: Operand the comparison is made against (optional).

    This model only carries the rule; nothing here evaluates it, and
    ``condition`` is a free-form string that is not validated.
    """

    key: str
    conditional_key: str
    # e.g. "equals", "contains", "starts_with", "ends_with", "regex", "in",
    # "not_in", "is_empty", "is_not_empty"
    condition: Optional[str] = None
    value: Optional[str] = None


class Inputs(BaseModel):
files: Optional[List[InputFile]] = None
urls: Optional[List[InputUrl]] = None
Expand All @@ -70,6 +85,7 @@ class Inputs(BaseModel):
select_texts: Optional[List[InputSelectText]] = None
select_numbers: Optional[List[InputSelectNumber]] = None
brain: Optional[BrainInput] = None
conditional_inputs: Optional[List[ConditionalInput]] = None


class Pricing(BaseModel):
Expand Down
12 changes: 8 additions & 4 deletions backend/api/quivr_api/modules/assistant/entity/task_entity.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from datetime import datetime
from typing import Dict
from typing import Dict, List, Optional
from uuid import UUID

from pydantic import BaseModel
from sqlmodel import JSON, TIMESTAMP, BigInteger, Column, Field, SQLModel, text


class TaskMetadata(BaseModel):
    """Additional information attached to a task, stored as JSON on the task row."""

    # Names of the files submitted with the task; the create_task route fills
    # this from each uploaded file's ``filename``. None when not provided.
    input_files: Optional[List[str]] = None


class Task(SQLModel, table=True):
__tablename__ = "tasks" # type: ignore

Expand All @@ -17,6 +22,7 @@ class Task(SQLModel, table=True):
),
)
assistant_id: int
assistant_name: str
pretty_id: str
user_id: UUID
status: str = Field(default="pending")
Expand All @@ -29,6 +35,4 @@ class Task(SQLModel, table=True):
)
settings: Dict = Field(default_factory=dict, sa_column=Column(JSON))
answer: str | None = Field(default=None)

class Config:
arbitrary_types_allowed = True
task_metadata: Dict | None = Field(default_factory=dict, sa_column=Column(JSON))
8 changes: 6 additions & 2 deletions backend/api/quivr_api/modules/assistant/repository/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from sqlalchemy import exc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from sqlmodel import col, select

from quivr_api.modules.assistant.dto.inputs import CreateTask
from quivr_api.modules.assistant.entity.task_entity import Task
Expand All @@ -21,9 +21,11 @@ async def create_task(self, task: CreateTask, user_id: UUID) -> Task:
try:
task_to_create = Task(
assistant_id=task.assistant_id,
assistant_name=task.assistant_name,
pretty_id=task.pretty_id,
user_id=user_id,
settings=task.settings,
task_metadata=task.task_metadata, # type: ignore
)
self.session.add(task_to_create)
await self.session.commit()
Expand All @@ -40,7 +42,9 @@ async def get_task_by_id(self, task_id: UUID, user_id: UUID) -> Task:
return response.one()

async def get_tasks_by_user_id(self, user_id: UUID) -> Sequence[Task]:
query = select(Task).where(Task.user_id == user_id)
query = (
select(Task).where(Task.user_id == user_id).order_by(col(Task.id).desc())
)
response = await self.session.exec(query)
return response.all()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Dict
from typing import Dict, Optional, Tuple
from uuid import UUID

from fastapi import HTTPException
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
import os
import time
from enum import Enum

from fastapi import HTTPException
Expand Down
4 changes: 2 additions & 2 deletions backend/api/quivr_api/modules/chat/controller/chat_routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from typing import Annotated, List, Optional
from uuid import UUID
import os

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from quivr_core.config import RetrievalConfig

from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
Expand Down Expand Up @@ -36,7 +37,6 @@
from quivr_api.modules.vector.service.vector_service import VectorService
from quivr_api.utils.telemetry import maybe_send_telemetry
from quivr_api.utils.uuid_generator import generate_uuid_from_string
from quivr_core.config import RetrievalConfig

logger = get_logger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel

from pydantic import BaseModel
from quivr_core.models import KnowledgeStatus
from sqlalchemy import JSON, TIMESTAMP, Column, text
from sqlalchemy.ext.asyncio import AsyncAttrs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,3 @@ async def remove_file(self, storage_path: str):
except Exception as e:
logger.error(e)
raise e

Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,9 @@ async def test_should_process_knowledge_prev_error(
assert new.file_sha1


@pytest.mark.skip(reason="Bug: UnboundLocalError: cannot access local variable 'response'")
@pytest.mark.skip(
reason="Bug: UnboundLocalError: cannot access local variable 'response'"
)
@pytest.mark.asyncio(loop_scope="session")
async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData):
_, [knowledge, _] = test_data
Expand Down
4 changes: 1 addition & 3 deletions backend/api/quivr_api/modules/misc/controller/misc_routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@

from fastapi import APIRouter, Depends, HTTPException
from quivr_api.logger import get_logger
from quivr_api.modules.dependencies import get_async_session
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import text
from sqlmodel.ext.asyncio.session import AsyncSession

logger = get_logger(__name__)

Expand All @@ -20,7 +19,6 @@ async def root():

@misc_router.get("/healthz", tags=["Health"])
async def healthz(session: AsyncSession = Depends(get_async_session)):

try:
result = await session.execute(text("SELECT 1"))
if not result:
Expand Down
2 changes: 1 addition & 1 deletion backend/api/quivr_api/modules/rag_service/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from uuid import UUID, uuid4

from quivr_api.utils.uuid_generator import generate_uuid_from_string
from quivr_core.brain import Brain as BrainCore
from quivr_core.chat import ChatHistory as ChatHistoryCore
from quivr_core.config import LLMEndpointConfig, RetrievalConfig
Expand All @@ -29,6 +28,7 @@
from quivr_api.modules.prompt.service.prompt_service import PromptService
from quivr_api.modules.user.entity.user_identity import UserIdentity
from quivr_api.modules.vector.service.vector_service import VectorService
from quivr_api.utils.uuid_generator import generate_uuid_from_string
from quivr_api.vectorstore.supabase import CustomSupabaseVectorStore

from .utils import generate_source
Expand Down
Loading

0 comments on commit b767f19

Please sign in to comment.