Commit
add model metadata (#96)
ccurme authored Mar 22, 2024
1 parent 54402b7 commit 0da6b2b
Showing 3 changed files with 39 additions and 44 deletions.
21 changes: 0 additions & 21 deletions backend/extraction/parsing.py
@@ -1,7 +1,6 @@
 """Convert binary input to blobs and parse them using the appropriate parser."""
 from __future__ import annotations

-import io
 from typing import BinaryIO, List

 from fastapi import HTTPException
@@ -10,7 +9,6 @@
 from langchain.document_loaders.parsers.txt import TextParser
 from langchain_community.document_loaders import Blob
 from langchain_core.documents import Document
-from pdfminer.pdfpage import PDFPage

 HANDLERS = {
     "application/pdf": PDFMinerParser(),
@@ -28,7 +26,6 @@
 SUPPORTED_MIMETYPES = sorted(HANDLERS.keys())

 MAX_FILE_SIZE_MB = 10  # in MB
-MAX_PAGES = 50  # for PDFs


 def _guess_mimetype(file_bytes: bytes) -> str:
@@ -54,13 +51,6 @@ def _get_file_size_in_mb(data: BinaryIO) -> float:
     return file_size_in_mb
-
-
-def _get_pdf_page_count(file_bytes: bytes) -> int:
-    """Get the number of pages in a PDF file."""
-    file_stream = io.BytesIO(file_bytes)
-    pages = PDFPage.get_pages(file_stream)
-    return sum(1 for _ in pages)


 # PUBLIC API

 MIMETYPE_BASED_PARSER = MimeTypeBasedParser(
@@ -83,17 +73,6 @@ def convert_binary_input_to_blob(data: BinaryIO) -> Blob:
     mimetype = _guess_mimetype(file_data)
     file_name = data.name

-    if mimetype == "application/pdf":
-        number_of_pages = _get_pdf_page_count(file_data)
-        if number_of_pages > MAX_PAGES:
-            raise HTTPException(
-                status_code=413,
-                detail=(
-                    f"PDF has too many pages: {number_of_pages}, "
-                    f"exceeding the maximum of {MAX_PAGES}."
-                ),
-            )
-
     return Blob.from_data(
         data=file_data,
         path=file_name,
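With this change the PDF page-count limit is gone from the parsing layer entirely: the pdfminer import, the MAX_PAGES constant, the _get_pdf_page_count helper, and the 413 guard in convert_binary_input_to_blob are all removed. The file-size guard remains; its body is collapsed in this view, but a sketch of the seek/tell pattern such a check typically follows (an assumption, not the actual body from this file):

    import io
    from typing import BinaryIO


    def _get_file_size_in_mb(data: BinaryIO) -> float:
        """Sketch of a stream size check: measure, then rewind."""
        data.seek(0, io.SEEK_END)         # jump to the end of the stream
        file_size_in_bytes = data.tell()
        data.seek(0)                      # rewind so the caller can still read it
        return file_size_in_bytes / (1024 * 1024)

The page-based limit is superseded by a chunk-count limit enforced downstream in server.extraction_runnable, which the updated test at the bottom of this commit exercises.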
42 changes: 26 additions & 16 deletions backend/server/models.py
@@ -10,22 +10,32 @@
 def get_supported_models():
     """Get models according to environment secrets."""
     models = {}
     if "OPENAI_API_KEY" in os.environ:
-        models["gpt-3.5-turbo"] = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
-        models["gpt-4-0125-preview"] = ChatOpenAI(
-            model="gpt-4-0125-preview", temperature=0
-        )
+        models["gpt-3.5-turbo"] = {
+            "chat_model": ChatOpenAI(model="gpt-3.5-turbo", temperature=0),
+            "description": "GPT-3.5 Turbo",
+        }
+        models["gpt-4-0125-preview"] = {
+            "chat_model": ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
+            "description": "GPT-4 0125 Preview",
+        }
     if "FIREWORKS_API_KEY" in os.environ:
-        models["fireworks"] = ChatFireworks(
-            model="accounts/fireworks/models/firefunction-v1",
-            temperature=0,
-        )
+        models["fireworks"] = {
+            "chat_model": ChatFireworks(
+                model="accounts/fireworks/models/firefunction-v1",
+                temperature=0,
+            ),
+            "description": "Fireworks Firefunction-v1",
+        }
     if "TOGETHER_API_KEY" in os.environ:
-        models["together-ai-mistral-8x7b-instruct-v0.1"] = ChatOpenAI(
-            base_url="https://api.together.xyz/v1",
-            api_key=os.environ["TOGETHER_API_KEY"],
-            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
-            temperature=0,
-        )
+        models["together-ai-mistral-8x7b-instruct-v0.1"] = {
+            "chat_model": ChatOpenAI(
+                base_url="https://api.together.xyz/v1",
+                api_key=os.environ["TOGETHER_API_KEY"],
+                model="mistralai/Mixtral-8x7B-Instruct-v0.1",
+                temperature=0,
+            ),
+            "description": "Mixtral 8x7B Instruct v0.1 (Together AI)",
+        }

     return models
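Each registry entry now pairs the instantiated chat model with a human-readable description rather than being the model instance itself, so consumers of the old flat mapping must unwrap the entry. An illustrative sketch of reading the new shape (the loop and variable names are examples, not part of this commit):

    models = get_supported_models()
    for name, info in models.items():
        print(f"{name}: {info['description']}")  # display metadata, e.g. for a model picker
        llm = info["chat_model"]                 # the chat model instance itself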

@@ -47,7 +57,7 @@ def get_chunk_size(model_name: str) -> int:
 def get_model(model_name: Optional[str] = None) -> BaseChatModel:
     """Get the model."""
     if model_name is None:
-        return SUPPORTED_MODELS[DEFAULT_MODEL]
+        return SUPPORTED_MODELS[DEFAULT_MODEL]["chat_model"]
     else:
         supported_model_names = list(SUPPORTED_MODELS.keys())
         if model_name not in supported_model_names:
@@ -56,4 +66,4 @@ def get_model(model_name: Optional[str] = None) -> BaseChatModel:
                 f"Supported models: {supported_model_names}"
             )
         else:
-            return SUPPORTED_MODELS[model_name]
+            return SUPPORTED_MODELS[model_name]["chat_model"]
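get_model keeps its contract of returning a BaseChatModel by unwrapping the "chat_model" key, so existing call sites are unaffected by the richer registry entries. Assumed usage, not from this diff:

    llm = get_model()                      # unwraps SUPPORTED_MODELS[DEFAULT_MODEL]["chat_model"]
    llm = get_model("gpt-4-0125-preview")  # unwraps the named entry's "chat_model"
    # An unrecognized name falls into the error branch above, which reports supported_model_names.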
20 changes: 13 additions & 7 deletions backend/tests/unit_tests/api/test_api_extract.py
@@ -132,6 +132,11 @@ async def test_extract_from_file() -> None:
         assert response.json() == {"data": ["This is a "]}


+@patch(
+    "server.extraction_runnable.extraction_runnable",
+    new=RunnableLambda(mock_extraction_runnable),
+)
+@patch("server.extraction_runnable.TokenTextSplitter", mock_text_splitter)
 async def test_extract_from_large_file() -> None:
     user_id = str(uuid4())
     headers = {"x-key": user_id}
@@ -167,22 +172,23 @@ async def test_extract_from_large_file() -> None:
         )
         assert response.status_code == 413

-    # Test page number constraint
+    # Test chunk count constraint
    with tempfile.NamedTemporaryFile(mode="w+t", delete=True) as f:
         f.write("This is a named temporary file.")
         f.seek(0)
         f.flush()
-        with patch(
-            "extraction.parsing._guess_mimetype", return_value="application/pdf"
-        ):
-            with patch("extraction.parsing._get_pdf_page_count", return_value=100):
+        with patch("server.extraction_runnable.settings.MAX_CHUNKS", 1):
+            with patch.object(
+                CharacterTextSplitter, "split_text", return_value=["a", "b"]
+            ):
                 response = await client.post(
                     "/extract",
                     data={
                         "extractor_id": extractor_id,
                         "mode": "entire_document",
                     },
-                    files={"file": f.name},
+                    files={"file": f},
                     headers=headers,
                 )
-                assert response.status_code == 413
+                assert response.status_code == 200
+                assert response.json() == {"data": ["a"]}
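The rewritten test patches settings.MAX_CHUNKS to 1, forces CharacterTextSplitter.split_text to return ["a", "b"], and now expects a 200 response containing only "a" where the old test expected a 413. That implies the server truncates to the first MAX_CHUNKS chunks instead of rejecting the request; a minimal sketch of that behavior under this assumption (the helper name is hypothetical):

    from typing import List


    def cap_chunks(chunks: List[str], max_chunks: int) -> List[str]:
        """Hypothetical helper mirroring the asserted behavior: keep at most
        max_chunks chunks rather than raising an HTTP 413."""
        return chunks[:max_chunks]


    # Mirrors the patched test: split_text -> ["a", "b"], MAX_CHUNKS -> 1,
    # so the endpoint responds 200 with {"data": ["a"]}.
    assert cap_chunks(["a", "b"], 1) == ["a"]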
