
Commit

Merge pull request #1165 from Shivam-19agg/add-azure-storage
Added the Azure Storage option for document sources.
ElishaKay authored Feb 19, 2025
2 parents b2d29c8 + 6d16238 commit 1051f0b
Showing 7 changed files with 65 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .env.example
@@ -1,7 +1,8 @@
OPENAI_API_KEY=
TAVILY_API_KEY=
DOC_PATH=./my-docs
-
+AZURE_CONNECTION_STRING=
+AZURE_CONTAINER_NAME=
## OPTIONAL CONFIGS FOR DOCKERIZED HOSTS
# NEXT_PUBLIC_GA_MEASUREMENT_ID='' # Can be left empty if not using Google Analytics
# NEXT_PUBLIC_GPTR_API_URL=http://0.0.0.0:8000 # Your server IP with backend port
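For reference, filled-in values for these two variables follow the standard Azure Storage connection-string format; the account name, key, and container name below are placeholders, not real credentials:

AZURE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=mystorageaccount;AccountKey=<your-base64-account-key>;EndpointSuffix=core.windows.net
AZURE_CONTAINER_NAME=my-research-docs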
1 change: 1 addition & 0 deletions frontend/index.html
@@ -124,6 +124,7 @@ <h1 class="text-4xl font-extrabold mx-auto lg:text-7xl">
<option value="web">The Web</option>
<option value="local">My Documents</option>
<option value="hybrid">Hybrid</option>
<option value="azure">Azure storage</option>
</select>
</div>
<input type="submit" value="Research" class="btn btn-primary button-padding">
22 changes: 22 additions & 0 deletions gpt_researcher/document/azure_document_loader.py
@@ -0,0 +1,22 @@
from azure.storage.blob import BlobServiceClient
import os
import tempfile


class AzureDocumentLoader:
    def __init__(self, container_name, connection_string):
        self.client = BlobServiceClient.from_connection_string(connection_string)
        self.container = self.client.get_container_client(container_name)

    async def load(self):
        """Download all blobs to temp files and return their paths.

        Note: the blob SDK calls below are synchronous; the async signature
        exists so callers can await this loader like the other loaders.
        """
        temp_dir = tempfile.mkdtemp()
        blobs = self.container.list_blobs()
        file_paths = []
        for blob in blobs:
            blob_client = self.container.get_blob_client(blob.name)
            local_path = os.path.join(temp_dir, blob.name)
            # Blob names may contain "/" separators; create any intermediate
            # directories so nested blobs don't fail on open().
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            with open(local_path, "wb") as f:
                blob_data = blob_client.download_blob()
                f.write(blob_data.readall())
            file_paths.append(local_path)
        return file_paths  # Pass to existing DocumentLoader
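To see how the new loader is meant to be driven on its own, here is a minimal standalone sketch. It assumes the two AZURE_* variables from .env are set; the main() wrapper is only for illustration, and DocumentLoader is the existing class from gpt_researcher.document shown in the next file:

import asyncio
import os

from gpt_researcher.document import DocumentLoader
from gpt_researcher.document.azure_document_loader import AzureDocumentLoader


async def main():
    # Download every blob in the container to temp files...
    azure_loader = AzureDocumentLoader(
        container_name=os.getenv("AZURE_CONTAINER_NAME"),
        connection_string=os.getenv("AZURE_CONNECTION_STRING"),
    )
    local_paths = await azure_loader.load()
    # ...then parse them with the existing loader, which now accepts a list of paths.
    docs = await DocumentLoader(local_paths).load()
    print(f"Loaded {len(docs)} document pages from Azure")


asyncio.run(main())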
35 changes: 27 additions & 8 deletions gpt_researcher/document/document.py
@@ -1,6 +1,6 @@
import asyncio
import os
-
+from typing import List, Union
from langchain_community.document_loaders import (
    PyMuPDFLoader,
    TextLoader,
@@ -15,17 +15,36 @@

class DocumentLoader:

-    def __init__(self, path):
+    def __init__(self, path: Union[str, List[str]]):
        self.path = path

    async def load(self) -> list:
        tasks = []
-        for root, dirs, files in os.walk(self.path):
-            for file in files:
-                file_path = os.path.join(root, file)
-                file_name, file_extension_with_dot = os.path.splitext(file_path)
-                file_extension = file_extension_with_dot.strip(".")
-                tasks.append(self._load_document(file_path, file_extension))
+        if isinstance(self.path, list):
+            for file_path in self.path:
+                if os.path.isfile(file_path):  # Ensure it's a valid file
+                    filename = os.path.basename(file_path)
+                    file_name, file_extension_with_dot = os.path.splitext(filename)
+                    file_extension = file_extension_with_dot.strip(".").lower()
+                    tasks.append(self._load_document(file_path, file_extension))
+
+        elif isinstance(self.path, (str, bytes, os.PathLike)):
+            for root, dirs, files in os.walk(self.path):
+                for file in files:
+                    file_path = os.path.join(root, file)
+                    file_name, file_extension_with_dot = os.path.splitext(file)
+                    file_extension = file_extension_with_dot.strip(".").lower()
+                    tasks.append(self._load_document(file_path, file_extension))
+
+        else:
+            raise ValueError("Invalid type for path. Expected str, bytes, os.PathLike, or list thereof.")
+
+        # for root, dirs, files in os.walk(self.path):
+        #     for file in files:
+        #         file_path = os.path.join(root, file)
+        #         file_name, file_extension_with_dot = os.path.splitext(file_path)
+        #         file_extension = file_extension_with_dot.strip(".")
+        #         tasks.append(self._load_document(file_path, file_extension))

        docs = []
        for pages in await asyncio.gather(*tasks):
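The upshot of this change: DocumentLoader keeps its original directory-walking behavior and gains a list mode, so both call styles in the sketch below should work (the paths are placeholders):

import asyncio

from gpt_researcher.document import DocumentLoader

# Directory mode: recursively walk a folder, as before.
dir_docs = asyncio.run(DocumentLoader("./my-docs").load())

# List mode: load an explicit set of files, e.g. the temp paths
# returned by AzureDocumentLoader.load(); non-file entries are skipped.
list_docs = asyncio.run(DocumentLoader(["/tmp/report.pdf", "/tmp/notes.txt"]).load())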
12 changes: 11 additions & 1 deletion gpt_researcher/skills/researcher.py
@@ -3,7 +3,7 @@
import json
from typing import Dict, Optional
import logging
-
+import os
from ..actions.utils import stream_output
from ..actions.query_processing import plan_research_outline, get_search_results
from ..document import DocumentLoader, OnlineDocumentLoader, LangChainDocumentLoader
@@ -115,6 +115,16 @@ async def conduct_research(self):
            web_context = await self._get_context_by_web_search(self.researcher.query)
            research_data = f"Context from local documents: {docs_context}\n\nContext from web sources: {web_context}"

+        elif self.researcher.report_source == ReportSource.Azure.value:
+            from ..document.azure_document_loader import AzureDocumentLoader
+            azure_loader = AzureDocumentLoader(
+                container_name=os.getenv("AZURE_CONTAINER_NAME"),
+                connection_string=os.getenv("AZURE_CONNECTION_STRING")
+            )
+            azure_files = await azure_loader.load()
+            document_data = await DocumentLoader(azure_files).load()  # Reuse existing loader
+            research_data = await self._get_context_by_web_search(self.researcher.query, document_data)
+
        elif self.researcher.report_source == ReportSource.LangChainDocuments.value:
            langchain_documents_data = await LangChainDocumentLoader(
                self.researcher.documents
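End to end, selecting the new source from Python should look roughly like the sketch below. It assumes the top-level GPTResearcher entry point accepts a report_source argument (the query and report_type values here are placeholders) and that AZURE_CONNECTION_STRING and AZURE_CONTAINER_NAME are set in the environment:

import asyncio

from gpt_researcher import GPTResearcher


async def main():
    researcher = GPTResearcher(
        query="Summarize the key findings in our uploaded reports",
        report_type="research_report",
        report_source="azure",  # Matches ReportSource.Azure.value from the enum below
    )
    await researcher.conduct_research()
    report = await researcher.write_report()
    print(report)


asyncio.run(main())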
1 change: 1 addition & 0 deletions gpt_researcher/utils/enum.py
@@ -13,6 +13,7 @@ class ReportType(Enum):
class ReportSource(Enum):
    Web = "web"
    Local = "local"
+    Azure = "azure"
    LangChainDocuments = "langchain_documents"
    LangChainVectorStore = "langchain_vectorstore"
    Static = "static"
2 changes: 1 addition & 1 deletion requirements.txt
@@ -30,7 +30,7 @@ unstructured
json_repair
json5
loguru
-
+azure-storage-blob
# uncomment for testing
# pytest
# pytest-asyncio
