Skip to content

Commit

Permalink
Remove test_cli.py
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Aug 21, 2024
2 parents f84da73 + 2bc77c2 commit e639c8a
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 40 deletions.
8 changes: 4 additions & 4 deletions go/cli/rag_operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ var ragCmd = &cobra.Command{
json.Unmarshal([]byte(filtersStr), &filters)
}
searchLimit, _ := cmd.Flags().GetInt("search-limit")
doHybridSearch, _ := cmd.Flags().GetBool("do-hybrid-search")
doHybridSearch, _ := cmd.Flags().GetBool("use-hybrid-search")
useKgSearch, _ := cmd.Flags().GetBool("use-kg-search")
kgSearchModel, _ := cmd.Flags().GetString("kg-search-model")
stream, _ := cmd.Flags().GetBool("stream")
Expand Down Expand Up @@ -108,7 +108,7 @@ var searchCmd = &cobra.Command{
json.Unmarshal([]byte(filtersStr), &filters)
}
searchLimit, _ := cmd.Flags().GetInt("search-limit")
doHybridSearch, _ := cmd.Flags().GetBool("do-hybrid-search")
doHybridSearch, _ := cmd.Flags().GetBool("use-hybrid-search")
useKgSearch, _ := cmd.Flags().GetBool("use-kg-search")
kgSearchModel, _ := cmd.Flags().GetString("kg-search-model")
kgSearchType, _ := cmd.Flags().GetString("kg-search-type")
Expand Down Expand Up @@ -161,7 +161,7 @@ func init() {
ragCmd.Flags().Bool("use-vector-search", true, "Use vector search")
ragCmd.Flags().String("filters", "", "Search filters as JSON")
ragCmd.Flags().Int("search-limit", 10, "Number of search results to return")
ragCmd.Flags().Bool("do-hybrid-search", false, "Perform hybrid search")
ragCmd.Flags().Bool("use-hybrid-search", false, "Perform hybrid search")
ragCmd.Flags().Bool("use-kg-search", false, "Use knowledge graph search")
ragCmd.Flags().String("kg-search-model", "", "Model for KG agent")
ragCmd.Flags().Bool("stream", false, "Stream the RAG response")
Expand All @@ -173,7 +173,7 @@ func init() {
searchCmd.Flags().Bool("use-vector-search", true, "Use vector search")
searchCmd.Flags().String("filters", "", "Search filters as JSON")
searchCmd.Flags().Int("search-limit", 10, "Number of search results to return")
searchCmd.Flags().Bool("do-hybrid-search", false, "Perform hybrid search")
searchCmd.Flags().Bool("use-hybrid-search", false, "Perform hybrid search")
searchCmd.Flags().Bool("use-kg-search", false, "Use knowledge graph search")
searchCmd.Flags().String("kg-search-model", "", "Model for KG agent")
searchCmd.Flags().String("kg-search-type", "global", "Local or Global")
Expand Down
2 changes: 1 addition & 1 deletion go/sdk/pkg/sdk/retrieval.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type VectorSearchSettings struct {
UseVectorSearch bool `json:"use_vector_search"`
Filters map[string]interface{} `json:"filters"`
SearchLimit int `json:"search_limit"`
DoHybridSearch bool `json:"do_hybrid_search"`
DoHybridSearch bool `json:"use_hybrid_search"`
SelectedGroupIDs []string `json:"selected_group_ids"`
}

Expand Down
2 changes: 1 addition & 1 deletion js/sdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ const ragResult = await client.rag({
use_vector_search: true,
search_filters: {},
search_limit: 10,
do_hybrid_search: false,
use_hybrid_search: false,
use_kg_search: false,
kg_generation_config: {},
rag_generation_config: {
Expand Down
2 changes: 1 addition & 1 deletion js/sdk/src/models.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export interface VectorSearchSettings {
use_vector_search?: boolean;
filters?: Record<string, any>;
search_limit?: number;
do_hybrid_search?: boolean;
use_hybrid_search?: boolean;
}

export interface KGSearchSettings {
Expand Down
8 changes: 4 additions & 4 deletions py/cli/commands/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
@click.option(
"--search-limit", default=None, help="Number of search results to return"
)
@click.option("--do-hybrid-search", is_flag=True, help="Perform hybrid search")
@click.option("--use-hybrid-search", is_flag=True, help="Perform hybrid search")
@click.option(
"--selected-group-ids", type=JSON, help="Group IDs to search for as a JSON"
)
Expand Down Expand Up @@ -64,7 +64,7 @@ def search(client, query, **kwargs):
"use_vector_search",
"filters",
"search_limit",
"do_hybrid_search",
"use_hybrid_search",
"selected_group_ids",
]
and v is not None
Expand Down Expand Up @@ -122,7 +122,7 @@ def search(client, query, **kwargs):
@click.option(
"--search-limit", default=10, help="Number of search results to return"
)
@click.option("--do-hybrid-search", is_flag=True, help="Perform hybrid search")
@click.option("--use-hybrid-search", is_flag=True, help="Perform hybrid search")
@click.option(
"--selected-group-ids", type=JSON, help="Group IDs to search for as a JSON"
)
Expand Down Expand Up @@ -169,7 +169,7 @@ def rag(client, query, **kwargs):
"use_vector_search",
"filters",
"search_limit",
"do_hybrid_search",
"use_hybrid_search",
"selected_group_ids",
]
and v is not None
Expand Down
30 changes: 14 additions & 16 deletions py/core/examples/hello_r2r.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
from core import R2R, Document, GenerationConfig
from r2r import R2RClient

app = R2R() # You may pass a custom configuration to `R2R` with config=...
client = R2RClient("http://localhost:8000")

app.ingest_documents(
[
Document(
type="txt",
data="John is a person that works at Google.",
metadata={},
)
]
)
with open("test.txt", "w") as file:
file.write("John is a person that works at Google.")

client.ingest_files(file_paths=["test.txt"])

rag_results = app.rag(
"Who is john", GenerationConfig(model="gpt-3.5-turbo", temperature=0.0)
# Call RAG directly on an R2R object
rag_response = client.rag(
query="Who is john",
rag_generation_config={"model": "gpt-3.5-turbo", "temperature": 0.0},
)
print(f"Search Results:\n{rag_results.search_results}")
print(f"Completion:\n{rag_results.completion}")
results = rag_response['results']
print(f"Search Results:\n{results['search_results']}")
print(f"Completion:\n{results['completion']}")

# RAG Results:
# Search Results:
# AggregateSearchResult(vector_search_results=[VectorSearchResult(id=2d71e689-0a0e-5491-a50b-4ecb9494c832, score=0.6848798582029441, metadata={'text': 'John is a person that works at Google.', 'version': 'v0', 'chunk_order': 0, 'document_id': 'ed76b6ee-dd80-5172-9263-919d493b439a', 'extraction_id': '1ba494d7-cb2f-5f0e-9f64-76c31da11381', 'associatedQuery': 'Who is john'})], kg_search_results=None)
# Completion:
# ChatCompletion(id='chatcmpl-9g0HnjGjyWDLADe7E2EvLWa35cMkB', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='John is a person that works at Google [1].', role='agent', function_call=None, tool_calls=None))], created=1719797903, model='gpt-3.5-turbo-0125', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=11, prompt_tokens=145, total_tokens=156))
# ChatCompletion(id='chatcmpl-9g0HnjGjyWDLADe7E2EvLWa35cMkB', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='John is a person that works at Google [1].', role='assistant', function_call=None, tool_calls=None))], created=1719797903, model='gpt-3.5-turbo-0125', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=11, prompt_tokens=145, total_tokens=156))
1 change: 0 additions & 1 deletion py/core/main/api/routes/ingestion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ async def update_files_app(
chunking_provider = (
R2RProviderFactory.create_chunking_provider(config)
)
print("input metadatas = ", metadatas)

return await self.engine.aupdate_files(
files=files,
Expand Down
19 changes: 11 additions & 8 deletions py/core/main/services/ingestion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def update_files(
self,
files: list[UploadFile],
user: UserResponse,
document_ids: list[UUID],
document_ids: Optional[list[UUID]],
metadatas: Optional[list[dict]] = None,
chunking_provider: Optional[ChunkingProvider] = None,
*args: Any,
Expand All @@ -120,12 +120,16 @@ async def update_files(
message="Database provider is not available for updating documents.",
)
try:
if len(document_ids) != len(files):
raise R2RException(
status_code=400,
message="Number of ids does not match number of files.",
)

if document_ids:
if len(document_ids) != len(files):
raise R2RException(
status_code=400,
message="Number of ids does not match number of files.",
)
else:
document_ids = [generate_user_document_id(file.filename, user.id) for file in files]
print('user_id = ', user.id)
print('document_ids = ', document_ids)
# Only superusers can modify arbitrary document ids, which this gate guarantees in conjunction with the check that follows
documents_overview = (
(
Expand Down Expand Up @@ -171,7 +175,6 @@ async def update_files(
document = self._file_to_document(
file, user, doc_id, updated_metadata
)
print("document = ", document)
documents.append(document)

ingestion_results = await self.ingest_documents(
Expand Down
28 changes: 25 additions & 3 deletions py/sdk/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,31 @@ async def ingest_files(
Args:
file_paths (List[str]): List of file paths to ingest.
metadatas (Optional[List[dict]]): List of metadata dictionaries for each file.
document_ids (Optional[List[str]]): List of document IDs.
versions (Optional[List[str]]): List of version strings for each file.
chunking_settings (Optional[Union[dict, ChunkingConfig]]): Custom chunking configuration.
Returns:
dict: Ingestion results containing processed, failed, and skipped documents.
"""
if document_ids is not None and len(file_paths) != len(document_ids):
raise ValueError(
"Number of file paths must match number of document IDs."
)
if metadatas is not None and len(file_paths) != len(metadatas):
raise ValueError(
"Number of metadatas must match number of file paths."
)
if versions is not None and len(file_paths) != len(versions):
raise ValueError(
"Number of versions must match number of file paths."
)
if chunking_settings is not None and chunking_settings is not ChunkingConfig:
# check if the provided dict maps to a ChunkingConfig
ChunkingConfig(**chunking_settings)


all_file_paths = []
for path in file_paths:
if os.path.isdir(path):
Expand Down Expand Up @@ -80,7 +97,7 @@ async def ingest_files(
async def update_files(
client,
file_paths: list[str],
document_ids: Optional[list[str]],
document_ids: Optional[list[str]] = None,
metadatas: Optional[list[dict]] = None,
chunking_settings: Optional[Union[dict, ChunkingConfig]] = None,
) -> dict:
Expand All @@ -96,10 +113,15 @@ async def update_files(
Returns:
dict: Update results containing processed, failed, and skipped documents.
"""
if len(file_paths) != len(document_ids):
if document_ids is not None and len(file_paths) != len(document_ids):
raise ValueError(
"Number of file paths must match number of document IDs."
)
if metadatas is not None and len(file_paths) != len(metadatas):
raise ValueError(
"Number of metadatas must match number of file paths."
)


with ExitStack() as stack:
files = [
Expand Down
2 changes: 1 addition & 1 deletion py/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class VectorSearchSettings(BaseModel):
ge=1,
le=100,
)
do_hybrid_search: bool = Field(
use_hybrid_search: bool = Field(
default=False,
description="Whether to perform a hybrid search (combining vector and keyword search)",
)
Expand Down

0 comments on commit e639c8a

Please sign in to comment.