Skip to content

Commit

Permalink
remove sample_row_in_table_info and simplify set operations in SQLDB (l…
Browse files Browse the repository at this point in the history
…angchain-ai#932)

-Address TODO: deprecate for sample_row_in_table_info
-Simplify set operations by casting to sets to not need multiple set
casts + .difference() calls
  • Loading branch information
kwhuo68 authored and zachschillaci27 committed Mar 8, 2023
1 parent 257095d commit 8393cf4
Showing 1 changed file with 6 additions and 16 deletions.
22 changes: 6 additions & 16 deletions langchain/sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,40 +17,30 @@ def __init__(
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 0,
# TODO: deprecate.
sample_row_in_table_info: bool = False,
):
"""Create engine from database URI."""
if sample_row_in_table_info and sample_rows_in_table_info > 0:
raise ValueError(
"Only one of `sample_row_in_table_info` "
"and `sample_rows_in_table_info` should be set"
)
self._engine = engine
self._schema = schema
if include_tables and ignore_tables:
raise ValueError("Cannot specify both include_tables and ignore_tables")

self._inspector = inspect(self._engine)
self._all_tables = self._inspector.get_table_names(schema=schema)
self._include_tables = include_tables or []
self._all_tables = set(self._inspector.get_table_names(schema=schema))
self._include_tables = set(include_tables) if include_tables else set()
if self._include_tables:
missing_tables = set(self._include_tables).difference(self._all_tables)
missing_tables = self._include_tables - self._all_tables
if missing_tables:
raise ValueError(
f"include_tables {missing_tables} not found in database"
)
self._ignore_tables = ignore_tables or []
self._ignore_tables = set(ignore_tables) if ignore_tables else set()
if self._ignore_tables:
missing_tables = set(self._ignore_tables).difference(self._all_tables)
missing_tables = self._ignore_tables - self._all_tables
if missing_tables:
raise ValueError(
f"ignore_tables {missing_tables} not found in database"
)
self._sample_rows_in_table_info = sample_rows_in_table_info
# TODO: deprecate
if sample_row_in_table_info:
self._sample_rows_in_table_info = 1

@classmethod
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
Expand All @@ -66,7 +56,7 @@ def get_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
if self._include_tables:
return self._include_tables
return set(self._all_tables) - set(self._ignore_tables)
return self._all_tables - self._ignore_tables

@property
def table_info(self) -> str:
Expand Down

0 comments on commit 8393cf4

Please sign in to comment.