diff --git a/langchain/chains/query_constructor/ir.py b/langchain/chains/query_constructor/ir.py index f131de3a1495f..5c67efa80d2f6 100644 --- a/langchain/chains/query_constructor/ir.py +++ b/langchain/chains/query_constructor/ir.py @@ -71,6 +71,8 @@ class Comparator(str, Enum): GTE = "gte" LT = "lt" LTE = "lte" + CONTAIN = 'contain' + LIKE = "like" class FilterDirective(Expr, ABC): diff --git a/langchain/chains/query_constructor/parser.py b/langchain/chains/query_constructor/parser.py index 4c560bfe27a0d..09f51ac7cabd3 100644 --- a/langchain/chains/query_constructor/parser.py +++ b/langchain/chains/query_constructor/parser.py @@ -1,4 +1,5 @@ from typing import Any, Optional, Sequence, Union +import datetime try: import lark @@ -34,12 +35,14 @@ def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore ?value: SIGNED_INT -> int | SIGNED_FLOAT -> float + | TIMESTAMP -> timestamp | list | string | ("false" | "False" | "FALSE") -> false | ("true" | "True" | "TRUE") -> true args: expr ("," expr)* + TIMESTAMP.2: /["'](\d{4}-[01]\d-[0-3]\d)["']/ string: /'[^']*'/ | ESCAPED_STRING list: "[" [args] "]" @@ -119,6 +122,10 @@ def int(self, item: Any) -> int: def float(self, item: Any) -> float: return float(item) + + def timestamp(self, item: Any): + item = item.replace("'", '"') + return datetime.datetime.strptime(item, '"%Y-%m-%d"').date() def string(self, item: Any) -> str: # Remove escaped quotes diff --git a/langchain/chains/query_constructor/prompt.py b/langchain/chains/query_constructor/prompt.py index ae7530b70f0c3..6545ae6fe0ae4 100644 --- a/langchain/chains/query_constructor/prompt.py +++ b/langchain/chains/query_constructor/prompt.py @@ -141,6 +141,7 @@ Make sure that you only use the comparators and logical operators listed above and \ no others. Make sure that filters only refer to attributes that exist in the data source. +Make sure that filters only use the attributed names with its function names if there are functions applied on them. Make sure that filters take into account the descriptions of attributes and only make \ comparisons that are feasible given the type of data being stored. Make sure that filters are only used as needed. If there are no filters that should be \ @@ -179,6 +180,8 @@ Make sure that you only use the comparators and logical operators listed above and \ no others. Make sure that filters only refer to attributes that exist in the data source. +Make sure that filters only use the attributed names with its function names if there are functions applied on them. +Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values. Make sure that filters take into account the descriptions of attributes and only make \ comparisons that are feasible given the type of data being stored. Make sure that filters are only used as needed. If there are no filters that should be \ diff --git a/langchain/retrievers/self_query/base.py b/langchain/retrievers/self_query/base.py index 6fed65b45310b..b2fe1acb808e0 100644 --- a/langchain/retrievers/self_query/base.py +++ b/langchain/retrievers/self_query/base.py @@ -12,8 +12,9 @@ from langchain.retrievers.self_query.pinecone import PineconeTranslator from langchain.retrievers.self_query.qdrant import QdrantTranslator from langchain.retrievers.self_query.weaviate import WeaviateTranslator +from langchain.retrievers.self_query.myscale import MyScaleTranslator from langchain.schema import BaseRetriever, Document -from langchain.vectorstores import Chroma, Pinecone, Qdrant, VectorStore, Weaviate +from langchain.vectorstores import Chroma, Pinecone, Qdrant, VectorStore, Weaviate, MyScale def _get_builtin_translator(vectorstore: VectorStore) -> Visitor: @@ -24,6 +25,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor: Chroma: ChromaTranslator, Weaviate: WeaviateTranslator, Qdrant: QdrantTranslator, + MyScale: MyScaleTranslator, } if vectorstore_cls not in BUILTIN_TRANSLATORS: raise ValueError( @@ -32,6 +34,8 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor: ) if isinstance(vectorstore, Qdrant): return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key) + elif isinstance(vectorstore, MyScale): + return MyScaleTranslator(metadata_key=vectorstore.metadata_column) return BUILTIN_TRANSLATORS[vectorstore_cls]() diff --git a/langchain/retrievers/self_query/myscale.py b/langchain/retrievers/self_query/myscale.py new file mode 100644 index 0000000000000..32e67941c51af --- /dev/null +++ b/langchain/retrievers/self_query/myscale.py @@ -0,0 +1,95 @@ +import re +import datetime +from typing import Dict, Tuple, Union +from langchain.chains.query_constructor.ir import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, + Visitor, +) + +def DEFAULT_COMPOSER(op): + def f(*args): + args = map(str, args) + return f' {op} '.join(args) + return f + + +def FUNCTION_COMPOSER(op): + def f(*args): + args = map(str, args) + return f"{op}({','.join(args)})" + return f + + +class MyScaleTranslator(Visitor): + """Logic for converting internal query language elements to valid filters.""" + + allowed_operators = [Operator.AND, Operator.OR, Operator.NOT] + """Subset of allowed logical operators.""" + + allowed_comparators = [Comparator.EQ, + Comparator.GT, + Comparator.GTE, + Comparator.LT, + Comparator.LTE, + Comparator.CONTAIN, + Comparator.LIKE] + + map_dict = {Operator.AND: DEFAULT_COMPOSER("AND"), + Operator.OR: DEFAULT_COMPOSER("OR"), + Operator.NOT: DEFAULT_COMPOSER("NOT"), + Comparator.EQ: DEFAULT_COMPOSER('='), + Comparator.GT: DEFAULT_COMPOSER('>'), + Comparator.GTE: DEFAULT_COMPOSER('>='), + Comparator.LT: DEFAULT_COMPOSER('<='), + Comparator.LTE: DEFAULT_COMPOSER('<'), + Comparator.CONTAIN: FUNCTION_COMPOSER('has'), + Comparator.LIKE: DEFAULT_COMPOSER("ILIKE"), + } + + def __init__(self, metadata_key: str = 'metadata') -> None: + super().__init__() + self.metadata_key = metadata_key + + def visit_operation(self, operation: Operation) -> Dict: + args = [arg.accept(self) for arg in operation.arguments] + func = operation.operator + self._validate_func(func) + return self.map_dict[func](*args) + + def visit_comparison(self, comparison: Comparison) -> Dict: + regex = '\((.*?)\)' + matched = re.search('\(\w+\)', comparison.attribute) + + # If arbitrary function is applied to an attribute + if matched: + attr = re.sub(regex, f'({self.metadata_key}.{matched.group(0)[1:-1]})', comparison.attribute) + else: + attr = f'{self.metadata_key}.{comparison.attribute}' + value = comparison.value + comp = comparison.comparator + + value = f"'{value}'" if type(value) is str else value + + # convert timestamp for datetime objects + if type(value) is datetime.date: + attr = f"parseDateTime32BestEffort({attr})" + value = f"parseDateTime32BestEffort('{value.strftime('%Y-%m-%d')}')" + + # string pattern match + if comp is Comparator.LIKE: + value = f"'%{value[1:-1]}%'" + return self.map_dict[comp](attr, value) + + def visit_structured_query( + self, structured_query: StructuredQuery + ) -> Tuple[str, dict]: + print(structured_query) + if structured_query.filter is None: + kwargs = {} + else: + kwargs = {"where_str": structured_query.filter.accept(self)} + return structured_query.query, kwargs \ No newline at end of file