Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CLI Tests #912

Merged
merged 3 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions py/cli/commands/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
@click.option(
"--search-limit", default=None, help="Number of search results to return"
)
@click.option("--use-hybrid-search", is_flag=True, help="Perform hybrid search")
@click.option(
"--use-hybrid-search", is_flag=True, help="Perform hybrid search"
)
@click.option(
"--selected-group-ids", type=JSON, help="Group IDs to search for as a JSON"
)
Expand Down Expand Up @@ -122,7 +124,9 @@ def search(client, query, **kwargs):
@click.option(
"--search-limit", default=10, help="Number of search results to return"
)
@click.option("--use-hybrid-search", is_flag=True, help="Perform hybrid search")
@click.option(
"--use-hybrid-search", is_flag=True, help="Perform hybrid search"
)
@click.option(
"--selected-group-ids", type=JSON, help="Group IDs to search for as a JSON"
)
Expand Down
2 changes: 2 additions & 0 deletions py/cli/utils/param_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ class JsonParamType(click.ParamType):
name = "json"

def convert(self, value, param, ctx) -> Dict[str, Any]:
if value is None:
return None
if isinstance(value, dict):
return value
try:
Expand Down
4 changes: 2 additions & 2 deletions py/core/examples/hello_r2r.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
query="Who is john",
rag_generation_config={"model": "gpt-3.5-turbo", "temperature": 0.0},
)
results = rag_response['results']
results = rag_response["results"]
print(f"Search Results:\n{results['search_results']}")
print(f"Completion:\n{results['completion']}")

# RAG Results:
# Search Results:
# AggregateSearchResult(vector_search_results=[VectorSearchResult(id=2d71e689-0a0e-5491-a50b-4ecb9494c832, score=0.6848798582029441, metadata={'text': 'John is a person that works at Google.', 'version': 'v0', 'chunk_order': 0, 'document_id': 'ed76b6ee-dd80-5172-9263-919d493b439a', 'extraction_id': '1ba494d7-cb2f-5f0e-9f64-76c31da11381', 'associatedQuery': 'Who is john'})], kg_search_results=None)
# Completion:
# ChatCompletion(id='chatcmpl-9g0HnjGjyWDLADe7E2EvLWa35cMkB', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='John is a person that works at Google [1].', role='assistant', function_call=None, tool_calls=None))], created=1719797903, model='gpt-3.5-turbo-0125', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=11, prompt_tokens=145, total_tokens=156))
# ChatCompletion(id='chatcmpl-9g0HnjGjyWDLADe7E2EvLWa35cMkB', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='John is a person that works at Google [1].', role='assistant', function_call=None, tool_calls=None))], created=1719797903, model='gpt-3.5-turbo-0125', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=11, prompt_tokens=145, total_tokens=156))
9 changes: 6 additions & 3 deletions py/core/main/services/ingestion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@ async def update_files(
message="Number of ids does not match number of files.",
)
else:
document_ids = [generate_user_document_id(file.filename, user.id) for file in files]
print('user_id = ', user.id)
print('document_ids = ', document_ids)
document_ids = [
generate_user_document_id(file.filename, user.id)
for file in files
]
print("user_id = ", user.id)
print("document_ids = ", document_ids)
# Only superusers can modify arbitrary document ids, which this gate guarantees in conjuction with the check that follows
documents_overview = (
(
Expand Down
7 changes: 4 additions & 3 deletions py/sdk/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ async def ingest_files(
raise ValueError(
"Number of versions must match number of document IDs."
)
if chunking_settings is not None and chunking_settings is not ChunkingConfig:
if (
chunking_settings is not None
and chunking_settings is not ChunkingConfig
):
# check if the provided dict maps to a ChunkingConfig
ChunkingConfig(**chunking_settings)


all_file_paths = []
for path in file_paths:
Expand Down Expand Up @@ -121,7 +123,6 @@ async def update_files(
raise ValueError(
"Number of file paths must match number of document IDs."
)


with ExitStack() as stack:
files = [
Expand Down
Empty file added py/tests/cli/__init__.py
Empty file.
Empty file.
32 changes: 32 additions & 0 deletions py/tests/cli/commands/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
from cli.commands.auth import generate_private_key
from click.testing import CliRunner


@pytest.fixture
def runner():
return CliRunner()


def test_generate_private_key(runner):
result = runner.invoke(generate_private_key)
assert result.exit_code == 0
assert "Generated Private Key:" in result.output
assert (
"Keep this key secure and use it as your R2R_SECRET_KEY."
in result.output
)


def test_generate_private_key_output_format(runner):
result = runner.invoke(generate_private_key)
key_line = [
line
for line in result.output.split("\n")
if "Generated Private Key:" in line
][0]
key = key_line.split(":")[1].strip()
assert len(key) > 32 # The key should be reasonably long
assert (
key.isalnum() or "-" in key or "_" in key
) # The key should be URL-safe
174 changes: 174 additions & 0 deletions py/tests/cli/commands/test_ingestion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import tempfile
from unittest.mock import MagicMock, patch

import click
import pytest
from cli.cli import cli
from click.testing import CliRunner


@pytest.fixture
def runner():
return CliRunner()


@pytest.fixture
def mock_client():
return MagicMock()


@pytest.fixture
def temp_file():
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write("Test content")
f.flush()
yield f.name


@pytest.fixture(autouse=True)
def mock_cli_obj(mock_client):
with patch(
"cli.commands.ingestion.click.get_current_context"
) as mock_context:
mock_context.return_value.obj = mock_client
yield


@pytest.fixture(autouse=True)
def mock_r2r_client():
with patch(
"cli.command_group.R2RClient", new=MagicMock()
) as MockR2RClient:
mock_client = MockR2RClient.return_value
mock_client.ingest_files.return_value = {"status": "success"}
mock_client.update_files.return_value = {"status": "updated"}

original_callback = cli.callback

def new_callback(*args, **kwargs):
ctx = click.get_current_context()
ctx.obj = mock_client
return original_callback(*args, **kwargs)

cli.callback = new_callback

yield mock_client

cli.callback = original_callback


def test_ingest_files(runner, mock_r2r_client, temp_file):
result = runner.invoke(cli, ["ingest-files", temp_file])
assert result.exit_code == 0
assert '"status": "success"' in result.output
mock_r2r_client.ingest_files.assert_called_once_with(
[temp_file], None, None, None
)


def test_ingest_files_with_options(runner, mock_r2r_client, temp_file):
result = runner.invoke(
cli,
[
"ingest-files",
temp_file,
"--document-ids",
"doc1",
"--metadatas",
'{"key": "value"}',
"--versions",
"v1",
],
)
assert result.exit_code == 0
assert '"status": "success"' in result.output
assert mock_r2r_client.ingest_files.called, "ingest_files was not called"
mock_r2r_client.ingest_files.assert_called_once_with(
[temp_file], {"key": "value"}, ["doc1"], ["v1"]
)


def test_update_files(runner, mock_r2r_client, temp_file):
result = runner.invoke(
cli,
[
"update-files",
temp_file,
"--document-ids",
"doc1",
"--metadatas",
'{"key": "new_value"}',
],
)
assert result.exit_code == 0
assert '"status": "updated"' in result.output
assert mock_r2r_client.update_files.called, "update_files was not called"
mock_r2r_client.update_files.assert_called_once_with(
[temp_file], ["doc1"], [{"key": "new_value"}]
)


@patch("cli.commands.ingestion.ingest_files_from_urls")
def test_ingest_sample_file(mock_ingest, runner, mock_r2r_client):
mock_ingest.return_value = ["aristotle.txt"]
result = runner.invoke(cli, ["ingest-sample-file"])
assert result.exit_code == 0
assert "Sample file ingestion completed" in result.output
assert "aristotle.txt" in result.output
mock_ingest.assert_called_once()


@patch("cli.commands.ingestion.ingest_files_from_urls")
def test_ingest_sample_files(mock_ingest, runner, mock_r2r_client):
mock_ingest.return_value = ["aristotle.txt", "got.txt"]
result = runner.invoke(cli, ["ingest-sample-files"])
assert result.exit_code == 0
assert "Sample files ingestion completed" in result.output
assert "aristotle.txt" in result.output
assert "got.txt" in result.output
mock_ingest.assert_called_once()


@patch("cli.commands.ingestion.requests.get")
@patch("cli.commands.ingestion.tempfile.NamedTemporaryFile")
def test_ingest_files_from_urls(mock_temp_file, mock_get, mock_r2r_client):
mock_get.return_value.text = "File content"
mock_temp_file.return_value.__enter__.return_value.name = "/tmp/test_file"
mock_r2r_client.ingest_files.return_value = {"status": "success"}

with patch("cli.commands.ingestion.os.unlink") as mock_unlink:
from cli.commands.ingestion import ingest_files_from_urls

result = ingest_files_from_urls(
mock_r2r_client, ["http://example.com/file.txt"]
)

assert result == ["file.txt"]
mock_r2r_client.ingest_files.assert_called_once_with(["/tmp/test_file"])
mock_unlink.assert_called_once_with("/tmp/test_file")


def test_ingest_files_with_invalid_file(runner, mock_r2r_client):
result = runner.invoke(cli, ["ingest-files", "nonexistent_file.txt"])
assert result.exit_code != 0
assert "Error" in result.output
assert not mock_r2r_client.ingest_files.called


def test_update_files_with_invalid_metadata(
runner, mock_r2r_client, temp_file
):
result = runner.invoke(
cli,
[
"update-files",
temp_file,
"--document-ids",
"doc1",
"--metadatas",
"invalid_json",
],
)
assert result.exit_code != 0
assert "Error" in result.output
assert not mock_r2r_client.update_files.called
Loading
Loading