Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/revive advanced rag #932

Merged
merged 3 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ services:
- PYTHONUNBUFFERED=1
- PORT=${PORT:-8000}
- HOST=${HOST:-0.0.0.0}
- BASE_URL=${BASE_URL:-http://localhost}

# R2R
- CONFIG_NAME=${CONFIG_NAME:-}
Expand Down
25 changes: 2 additions & 23 deletions py/cli/command_group.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,10 @@
import os

import click
from sdk.client import R2RClient


# TODO: refactor this to remove config path and config name
@click.group()
@click.option(
"--config-path", default=None, help="Path to the configuration file"
)
@click.option(
"--config-name", default=None, help="Name of the configuration to use"
)
@click.option(
"--base-url",
default="http://localhost:8000",
help="Base URL for client mode",
)
@click.pass_context
def cli(ctx, config_path, config_name, base_url):
def cli(ctx):
"""R2R CLI for all core operations."""
if config_path and config_name:
raise click.UsageError(
"Cannot specify both config_path and config_name"
)

if config_path:
config_path = os.path.abspath(config_path)

ctx.obj = R2RClient(base_url)
ctx.obj = R2RClient()
12 changes: 10 additions & 2 deletions py/cli/commands/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,14 @@ def generate_report():
)
@click.option("--project-name", default="r2r", help="Project name for Docker")
@click.option("--image", help="Docker image to use")
@click.option("--config-path", help="Path to the configuration file")
@click.option(
"--config-name", default=None, help="Name of the R2R configuration to use"
)
@click.option(
"--config-path",
default=None,
help="Path to a custom R2R configuration file",
)
@click.pass_obj
def serve(
client,
Expand All @@ -194,6 +201,7 @@ def serve(
exclude_postgres,
project_name,
image,
config_name,
config_path,
):
"""Start the R2R server."""
Expand Down Expand Up @@ -239,7 +247,7 @@ def serve(
click.echo(f"Opening browser to {url}")
webbrowser.open(url)
else:
run_local_serve(client, host, port)
run_local_serve(host, port, config_name, config_path)


@cli.command()
Expand Down
26 changes: 15 additions & 11 deletions py/cli/utils/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ def remove_r2r_network():


def run_local_serve(
obj: R2RClient, host: str, port: int, config_path: Optional[str] = None
):
host: str,
port: int,
config_name: Optional[str] = None,
config_path: Optional[str] = None,
) -> None:
try:
from r2r import R2R
except ImportError:
Expand All @@ -83,14 +86,17 @@ def run_local_serve(
)
sys.exit(1)

r2r_instance = R2R()
llm_provider = r2r_instance.config.completion.provider
llm_model = r2r_instance.config.completion.generation_config.model
model_provider = llm_model.split("/")[0]
r2r_instance = R2R(config_name=config_name, config_path=config_path)

if config_name or config_path:
completion_config = r2r_instance.config.completion
llm_provider = completion_config.provider
llm_model = completion_config.generation_config.model
model_provider = llm_model.split("/")[0]
check_llm_reqs(llm_provider, model_provider, include_ollama=True)

available_port = find_available_port(port)

check_llm_reqs(llm_provider, model_provider, include_ollama=True)
r2r_instance.serve(host, available_port)


Expand All @@ -115,10 +121,8 @@ def run_docker_serve(
config_name = client.config_name
else:
config_name = "default"

config = R2RConfig.from_toml(
R2RBuilder.CONFIG_OPTIONS[config_name]
)

config = R2RConfig.from_toml(R2RBuilder.CONFIG_OPTIONS[config_name])

completion_provider = config.completion.provider
completion_model = config.completion.generation_config.model
Expand Down
3 changes: 3 additions & 0 deletions py/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@
"RelationshipType",
"format_entity_types",
"format_relations",
## INTEGRATIONS
# Serper
"SerperClient",
## MAIN
## R2R ABSTRACTIONS
"R2RProviders",
Expand Down
2 changes: 1 addition & 1 deletion py/core/base/abstractions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class VectorSearchResult(BaseModel):
fragment_id: UUID
extraction_id: UUID
document_id: UUID
user_id: UUID
user_id: Optional[UUID]
group_ids: list[UUID]
score: float
text: str
Expand Down
32 changes: 0 additions & 32 deletions py/core/examples/scripts/run_hyde.py

This file was deleted.

20 changes: 20 additions & 0 deletions py/core/examples/scripts/serve_with_hyde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from r2r import R2RBuilder, R2RConfig, R2RPipeFactoryWithMultiSearch
import fire


def main(task_prompt_name="hyde"):
# Load the default configuration file
config = R2RConfig.from_toml()

app = (
R2RBuilder(config)
.with_pipe_factory(R2RPipeFactoryWithMultiSearch)
.build(
# Add optional override arguments which propagate to the pipe factory
task_prompt_name=task_prompt_name,
)
)
app.serve()

if __name__ == "__main__":
fire.Fire(main)
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
import fire
from core.base.abstractions.llm import GenerationConfig
from r2r import R2RBuilder, SerperClient, WebSearchPipe


def run_rag_pipeline(query="Who was Aristotle?"):
def run_rag_pipeline():
# Create search pipe override and pipes
web_search_pipe = WebSearchPipe(
serper_client=SerperClient() # TODO - Develop a `WebSearchProvider` for configurability
)

app = R2RBuilder().with_vector_search_pipe(web_search_pipe).build()

# Run the RAG pipeline through the R2R application
result = app.rag(
query,
rag_generation_config=GenerationConfig(model="gpt-4o"),
)

print(f"Search Results:\n\n{result.search_results}")
print(f"RAG Results:\n\n{result.completion}")
app.serve()


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from core.main.assembly.factory_extensions import R2RPipeFactoryWithMultiSearch


def run_rag_pipeline(query="Who was Aristotle?"):
def run_rag_pipeline():
# Initialize a web search pipe
web_search_pipe = WebSearchPipe(serper_client=SerperClient())

Expand Down Expand Up @@ -37,17 +37,7 @@ def run_rag_pipeline(query="Who was Aristotle?"):
multi_inner_search_pipe_override=web_search_pipe,
query_generation_template_override=synthetic_query_generation_template,
)
)

# Run the RAG pipeline through the R2R application
result = app.rag(
query,
rag_generation_config=GenerationConfig(model="gpt-4o"),
)

print(f"Search Results:\n\n{result.search_results}")
print(f"RAG Results:\n\n{result.completion}")

).serve()

if __name__ == "__main__":
fire.Fire(run_rag_pipeline)
4 changes: 3 additions & 1 deletion py/core/main/api/routes/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@


class AuthRouter(BaseRouter):
def __init__(self, engine: "R2REngine", run_type: RunType = RunType.INGESTION):
def __init__(
self, engine: "R2REngine", run_type: RunType = RunType.INGESTION
):
super().__init__(engine, run_type)
self.setup_routes()

Expand Down
4 changes: 3 additions & 1 deletion py/core/main/api/routes/ingestion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


class IngestionRouter(BaseRouter):
def __init__(self, engine: R2REngine, run_type: RunType = RunType.INGESTION):
def __init__(
self, engine: R2REngine, run_type: RunType = RunType.INGESTION
):
super().__init__(engine, run_type)
self.openapi_extras = self.load_openapi_extras()
self.setup_routes()
Expand Down
10 changes: 7 additions & 3 deletions py/core/main/api/routes/management/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@


class ManagementRouter(BaseRouter):
def __init__(self, engine: R2REngine, run_type: RunType = RunType.MANAGEMENT):
def __init__(
self, engine: R2REngine, run_type: RunType = RunType.MANAGEMENT
):
super().__init__(engine, run_type)
self.start_time = datetime.now(timezone.utc)
self.setup_routes()
Expand Down Expand Up @@ -99,8 +101,10 @@ async def get_analytics_app(

try:
result = await self.engine.aanalytics(
filter_criteria=LogFilterCriteria(**filter_criteria),
analysis_types=AnalysisTypes(**analysis_types),
filter_criteria=LogFilterCriteria(filters=filter_criteria),
analysis_types=AnalysisTypes(
analysis_types=analysis_types
),
)
return result
except json.JSONDecodeError as e:
Expand Down
4 changes: 3 additions & 1 deletion py/core/main/api/routes/restructure/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


class RestructureRouter(BaseRouter):
def __init__(self, engine: R2REngine, run_type: RunType = RunType.RESTRUCTURE):
def __init__(
self, engine: R2REngine, run_type: RunType = RunType.RESTRUCTURE
):
super().__init__(engine, run_type)
self.setup_routes()

Expand Down
4 changes: 3 additions & 1 deletion py/core/main/api/routes/retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@


class RetrievalRouter(BaseRouter):
def __init__(self, engine: R2REngine, run_type: RunType = RunType.RETRIEVAL):
def __init__(
self, engine: R2REngine, run_type: RunType = RunType.RETRIEVAL
):
super().__init__(engine, run_type)
self.openapi_extras = self.load_openapi_extras()
self.setup_routes()
Expand Down
4 changes: 0 additions & 4 deletions py/core/main/app_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class PipelineType(Enum):
def r2r_app(
config_name: Optional[str] = "default",
config_path: Optional[str] = None,
base_url: Optional[str] = None,
pipeline_type: PipelineType = PipelineType.QNA,
) -> FastAPI:
if pipeline_type != PipelineType.QNA:
Expand Down Expand Up @@ -58,19 +57,16 @@ def r2r_app(
config_path = os.getenv("CONFIG_PATH", None)
if not config_path and not config_name:
config_name = "default"
base_url = os.getenv("BASE_URL")
host = os.getenv("HOST", "0.0.0.0")
port = int(os.getenv("PORT", "8000"))
pipeline_type = os.getenv("PIPELINE_TYPE", "qna")

logger.info(f"Environment CONFIG_NAME: {config_name}")
logger.info(f"Environment CONFIG_PATH: {config_path}")
logger.info(f"Environment BASE_URL: {base_url}")
logger.info(f"Environment PIPELINE_TYPE: {pipeline_type}")

app = r2r_app(
config_name=config_name,
config_path=config_path,
base_url=base_url,
pipeline_type=PipelineType(pipeline_type),
)
15 changes: 11 additions & 4 deletions py/core/main/assembly/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ class R2RBuilder:
CONFIG_OPTIONS["default"] = None

@staticmethod
def _get_config(config_name):
def _get_config(config_name, config_path=None):
if config_path:
return R2RConfig.from_toml(config_path)
if config_name is None:
return R2RConfig.from_toml()
if config_name in R2RBuilder.CONFIG_OPTIONS:
Expand All @@ -55,10 +57,15 @@ def __init__(
self,
config: Optional[R2RConfig] = None,
config_name: Optional[str] = None,
config_path: Optional[str] = None,
):
if config and config_name:
raise ValueError("Cannot specify both config and config_name")
self.config = config or R2RBuilder._get_config(config_name)
if sum(x is not None for x in [config, config_name, config_path]) > 1:
raise ValueError(
"Specify only one of config, config_name, or config_path"
)
self.config = config or R2RBuilder._get_config(
config_name, config_path
)
self.r2r_app_override: Optional[Type[R2REngine]] = None
self.provider_factory_override: Optional[Type[R2RProviderFactory]] = (
None
Expand Down
1 change: 1 addition & 0 deletions py/core/main/assembly/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def create_prompt_provider(
prompt_provider = None
if prompt_config.provider == "r2r":
from core.providers import R2RPromptProvider

prompt_provider = R2RPromptProvider(prompt_config)
else:
raise ValueError(
Expand Down
Loading
Loading