Skip to content

Commit

Permalink
Fix CLI Tests (#912)
Browse files Browse the repository at this point in the history
Fix CLI tests
  • Loading branch information
NolanTrem authored Aug 21, 2024
1 parent 2bc77c2 commit d638ca9
Show file tree
Hide file tree
Showing 18 changed files with 797 additions and 344 deletions.
8 changes: 6 additions & 2 deletions py/cli/commands/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
@click.option(
"--search-limit", default=None, help="Number of search results to return"
)
@click.option("--use-hybrid-search", is_flag=True, help="Perform hybrid search")
@click.option(
"--use-hybrid-search", is_flag=True, help="Perform hybrid search"
)
@click.option(
"--selected-group-ids", type=JSON, help="Group IDs to search for as a JSON"
)
Expand Down Expand Up @@ -122,7 +124,9 @@ def search(client, query, **kwargs):
@click.option(
"--search-limit", default=10, help="Number of search results to return"
)
@click.option("--use-hybrid-search", is_flag=True, help="Perform hybrid search")
@click.option(
"--use-hybrid-search", is_flag=True, help="Perform hybrid search"
)
@click.option(
"--selected-group-ids", type=JSON, help="Group IDs to search for as a JSON"
)
Expand Down
2 changes: 2 additions & 0 deletions py/cli/utils/param_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ class JsonParamType(click.ParamType):
name = "json"

def convert(self, value, param, ctx) -> Dict[str, Any]:
if value is None:
return None
if isinstance(value, dict):
return value
try:
Expand Down
4 changes: 2 additions & 2 deletions py/core/examples/hello_r2r.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
query="Who is john",
rag_generation_config={"model": "gpt-3.5-turbo", "temperature": 0.0},
)
results = rag_response['results']
results = rag_response["results"]
print(f"Search Results:\n{results['search_results']}")
print(f"Completion:\n{results['completion']}")

# RAG Results:
# Search Results:
# AggregateSearchResult(vector_search_results=[VectorSearchResult(id=2d71e689-0a0e-5491-a50b-4ecb9494c832, score=0.6848798582029441, metadata={'text': 'John is a person that works at Google.', 'version': 'v0', 'chunk_order': 0, 'document_id': 'ed76b6ee-dd80-5172-9263-919d493b439a', 'extraction_id': '1ba494d7-cb2f-5f0e-9f64-76c31da11381', 'associatedQuery': 'Who is john'})], kg_search_results=None)
# Completion:
# ChatCompletion(id='chatcmpl-9g0HnjGjyWDLADe7E2EvLWa35cMkB', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='John is a person that works at Google [1].', role='assistant', function_call=None, tool_calls=None))], created=1719797903, model='gpt-3.5-turbo-0125', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=11, prompt_tokens=145, total_tokens=156))
# ChatCompletion(id='chatcmpl-9g0HnjGjyWDLADe7E2EvLWa35cMkB', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='John is a person that works at Google [1].', role='assistant', function_call=None, tool_calls=None))], created=1719797903, model='gpt-3.5-turbo-0125', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=11, prompt_tokens=145, total_tokens=156))
9 changes: 6 additions & 3 deletions py/core/main/services/ingestion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@ async def update_files(
message="Number of ids does not match number of files.",
)
else:
document_ids = [generate_user_document_id(file.filename, user.id) for file in files]
print('user_id = ', user.id)
print('document_ids = ', document_ids)
document_ids = [
generate_user_document_id(file.filename, user.id)
for file in files
]
print("user_id = ", user.id)
print("document_ids = ", document_ids)
# Only superusers can modify arbitrary document ids, which this gate guarantees in conjuction with the check that follows
documents_overview = (
(
Expand Down
7 changes: 4 additions & 3 deletions py/sdk/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ async def ingest_files(
raise ValueError(
"Number of versions must match number of document IDs."
)
if chunking_settings is not None and chunking_settings is not ChunkingConfig:
if (
chunking_settings is not None
and chunking_settings is not ChunkingConfig
):
# check if the provided dict maps to a ChunkingConfig
ChunkingConfig(**chunking_settings)


all_file_paths = []
for path in file_paths:
Expand Down Expand Up @@ -121,7 +123,6 @@ async def update_files(
raise ValueError(
"Number of file paths must match number of document IDs."
)


with ExitStack() as stack:
files = [
Expand Down
Empty file added py/tests/cli/__init__.py
Empty file.
Empty file.
32 changes: 32 additions & 0 deletions py/tests/cli/commands/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
from cli.commands.auth import generate_private_key
from click.testing import CliRunner


@pytest.fixture
def runner():
return CliRunner()


def test_generate_private_key(runner):
result = runner.invoke(generate_private_key)
assert result.exit_code == 0
assert "Generated Private Key:" in result.output
assert (
"Keep this key secure and use it as your R2R_SECRET_KEY."
in result.output
)


def test_generate_private_key_output_format(runner):
result = runner.invoke(generate_private_key)
key_line = [
line
for line in result.output.split("\n")
if "Generated Private Key:" in line
][0]
key = key_line.split(":")[1].strip()
assert len(key) > 32 # The key should be reasonably long
assert (
key.isalnum() or "-" in key or "_" in key
) # The key should be URL-safe
174 changes: 174 additions & 0 deletions py/tests/cli/commands/test_ingestion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import tempfile
from unittest.mock import MagicMock, patch

import click
import pytest
from cli.cli import cli
from click.testing import CliRunner


@pytest.fixture
def runner():
return CliRunner()


@pytest.fixture
def mock_client():
return MagicMock()


@pytest.fixture
def temp_file():
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write("Test content")
f.flush()
yield f.name


@pytest.fixture(autouse=True)
def mock_cli_obj(mock_client):
with patch(
"cli.commands.ingestion.click.get_current_context"
) as mock_context:
mock_context.return_value.obj = mock_client
yield


@pytest.fixture(autouse=True)
def mock_r2r_client():
with patch(
"cli.command_group.R2RClient", new=MagicMock()
) as MockR2RClient:
mock_client = MockR2RClient.return_value
mock_client.ingest_files.return_value = {"status": "success"}
mock_client.update_files.return_value = {"status": "updated"}

original_callback = cli.callback

def new_callback(*args, **kwargs):
ctx = click.get_current_context()
ctx.obj = mock_client
return original_callback(*args, **kwargs)

cli.callback = new_callback

yield mock_client

cli.callback = original_callback


def test_ingest_files(runner, mock_r2r_client, temp_file):
result = runner.invoke(cli, ["ingest-files", temp_file])
assert result.exit_code == 0
assert '"status": "success"' in result.output
mock_r2r_client.ingest_files.assert_called_once_with(
[temp_file], None, None, None
)


def test_ingest_files_with_options(runner, mock_r2r_client, temp_file):
result = runner.invoke(
cli,
[
"ingest-files",
temp_file,
"--document-ids",
"doc1",
"--metadatas",
'{"key": "value"}',
"--versions",
"v1",
],
)
assert result.exit_code == 0
assert '"status": "success"' in result.output
assert mock_r2r_client.ingest_files.called, "ingest_files was not called"
mock_r2r_client.ingest_files.assert_called_once_with(
[temp_file], {"key": "value"}, ["doc1"], ["v1"]
)


def test_update_files(runner, mock_r2r_client, temp_file):
result = runner.invoke(
cli,
[
"update-files",
temp_file,
"--document-ids",
"doc1",
"--metadatas",
'{"key": "new_value"}',
],
)
assert result.exit_code == 0
assert '"status": "updated"' in result.output
assert mock_r2r_client.update_files.called, "update_files was not called"
mock_r2r_client.update_files.assert_called_once_with(
[temp_file], ["doc1"], [{"key": "new_value"}]
)


@patch("cli.commands.ingestion.ingest_files_from_urls")
def test_ingest_sample_file(mock_ingest, runner, mock_r2r_client):
mock_ingest.return_value = ["aristotle.txt"]
result = runner.invoke(cli, ["ingest-sample-file"])
assert result.exit_code == 0
assert "Sample file ingestion completed" in result.output
assert "aristotle.txt" in result.output
mock_ingest.assert_called_once()


@patch("cli.commands.ingestion.ingest_files_from_urls")
def test_ingest_sample_files(mock_ingest, runner, mock_r2r_client):
mock_ingest.return_value = ["aristotle.txt", "got.txt"]
result = runner.invoke(cli, ["ingest-sample-files"])
assert result.exit_code == 0
assert "Sample files ingestion completed" in result.output
assert "aristotle.txt" in result.output
assert "got.txt" in result.output
mock_ingest.assert_called_once()


@patch("cli.commands.ingestion.requests.get")
@patch("cli.commands.ingestion.tempfile.NamedTemporaryFile")
def test_ingest_files_from_urls(mock_temp_file, mock_get, mock_r2r_client):
mock_get.return_value.text = "File content"
mock_temp_file.return_value.__enter__.return_value.name = "/tmp/test_file"
mock_r2r_client.ingest_files.return_value = {"status": "success"}

with patch("cli.commands.ingestion.os.unlink") as mock_unlink:
from cli.commands.ingestion import ingest_files_from_urls

result = ingest_files_from_urls(
mock_r2r_client, ["http://example.com/file.txt"]
)

assert result == ["file.txt"]
mock_r2r_client.ingest_files.assert_called_once_with(["/tmp/test_file"])
mock_unlink.assert_called_once_with("/tmp/test_file")


def test_ingest_files_with_invalid_file(runner, mock_r2r_client):
result = runner.invoke(cli, ["ingest-files", "nonexistent_file.txt"])
assert result.exit_code != 0
assert "Error" in result.output
assert not mock_r2r_client.ingest_files.called


def test_update_files_with_invalid_metadata(
runner, mock_r2r_client, temp_file
):
result = runner.invoke(
cli,
[
"update-files",
temp_file,
"--document-ids",
"doc1",
"--metadatas",
"invalid_json",
],
)
assert result.exit_code != 0
assert "Error" in result.output
assert not mock_r2r_client.update_files.called
Loading

0 comments on commit d638ca9

Please sign in to comment.