From 1087d64efbc8b7bd56d19c361e345b96e1025e6c Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 22 Jan 2025 15:53:49 -0800
Subject: [PATCH 001/131] added skeletons for pipeline expressions

---
 .../firestore_v1/pipeline_expressions.py     | 175 ++++++++++++++++++
 1 file changed, 175 insertions(+)
 create mode 100644 google/cloud/firestore_v1/pipeline_expressions.py

diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
new file mode 100644
index 000000000..55b5ee459
--- /dev/null
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -0,0 +1,175 @@
+
+class Expr:
+    """
+    Represents an expression that can be evaluated to a value within the execution of a pipeline
+    """
+
+class Constant(Expr):
+
+class ListOfExprs(Expr):
+
+class Function(Expr):
+    """
+    A type of Expression that takes in inputs and gives outputs
+    """
+
+class Divide(Function):
+
+class DotProduct(Function):
+
+class EuclideanDistance(Function):
+
+
+class LogicalMax(Function):
+
+class LogicalMin(Function):
+
+class MapGet(Function):
+
+class Mod(Function):
+
+class Multiply(Function):
+
+class Parent(Function):
+
+class ReplaceAll(Function):
+
+class ReplaceFirst(Function):
+
+class Reverse(Function):
+
+class StrConcat(Function):
+
+class Subtract(Function):
+
+class TimestampAdd(Function):
+
+class TimestampSub(Function):
+
+class TimestampToUnixMicros(Function):
+
+class TimestampToUnixMillis(Function):
+
+class TimestampToUnixSeconds(Function):
+
+class ToLower(Function):
+
+class ToUpper(Function):
+
+class Trim(Function):
+
+class UnixMicrosToTimestamp(Function):
+
+class UnixMillisToTimestamp(Function):
+
+class UnixSecondsToTimestamp(Function):
+
+class VectorLength(Function):
+
+class Add(Function):
+
+class ArrayConcat(Function):
+
+class ArrayElement(Function):
+
+class ArrayFilter(Function):
+
+class ArrayLength(Function):
+
+class ArrayReverse(Function):
+
+class ArrayTransform(Function):
+
+class ByteLength(Function):
+
+class CharLength(Function):
+
+class CollectionId(Function):
+
+class CosineDistance(Function):
+
+
+class Accumulator(Function):
+    """
+    A type of expression that takes in many, and results in one value
+    """
+
+
+class Max(Accumulator):
+
+class Min(Accumulator):
+
+class Sum(Accumulator):
+
+
+
+
+
+class Avg(Function, Accumulator):
+
+class Count(Function, Accumulator):
+class CountIf(Function, Accumulator):
+
+class Selectable:
+    """
+    Points at something in the database?
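+    (e.g. a document field or aliased expression that a stage can output)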
+ """ + +class AccumulatorTarget(Selectable): + +class ExprWithAlies(Expr, Selectable): + +class Field(Expr, Selectable): + + +class FilterCondition(Function): + """ + filters the given data in some way + """ + +class And(FilterCondition): + +class ArrayContains(FilterCondition) + +class ArrayContainsAll(FilterCondition) + +class ArrayContainsAny(FilterCondition) + +class EndsWith(FilterCondition) + +class Eq(FilterCondition) + +class Exists(FilterCondition) + +class Gt(FilterCondition) + +class Gte(FilterCondition) + +class If(FilterCondition) + +class In(FilterCondition) + +class IsNan(FilterCondition) + +class Like(FilterCondition) + + +class Lt(FilterCondition) + +class Lte(FilterCondition) + +class Neq(FilterCondition) + +class Not(FilterCondition) + +class Or(FilterCondition) + +class RegexContains(FilterCondition) + +class RegexMatch(FilterCondition) + +class StartsWith(FilterCondition) + +class StrContains(FilterCondition): + +class Xor(FilterCondition): From 053f55bf22ec318cbe473cbcac5a25961cb9e7ad Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 22 Jan 2025 16:03:07 -0800 Subject: [PATCH 002/131] added quick implemtation for expressions --- .../firestore_v1/pipeline_expressions.py | 223 ++++++++++++++---- 1 file changed, 180 insertions(+), 43 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 55b5ee459..0aac59879 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1,175 +1,312 @@ +from typing import Any, Iterable, List, Mapping + class Expr: - """ - Represents an expression that can be evaluated to a value within the execution of a pipeline + """Represents an expression that can be evaluated to a value within the + execution of a pipeline. 
""" class Constant(Expr): + def __init__(self, value: Any): + self.value = value class ListOfExprs(Expr): + def __init__(self, exprs: List[Expr]): + self.exprs = exprs class Function(Expr): - """ - A type of Expression that takes in inputs and gives outputs - """ + """A type of Expression that takes in inputs and gives outputs.""" + + def __init__(self, name: str, params: List[Expr]): + self.name = name + self.params = params class Divide(Function): + def __init__(self, left: Expr, right: Expr): + super().__init__("divide", [left, right]) class DotProduct(Function): + def __init__(self, vector1: Expr, vector2: Expr): + super().__init__("dot_product", [vector1, vector2]) class EuclideanDistance(Function): - + def __init__(self, vector1: Expr, vector2: Expr): + super().__init__("euclidean_distance", [vector1, vector2]) class LogicalMax(Function): + def __init__(self, left: Expr, right: Expr): + super().__init__("logical_max", [left, right]) class LogicalMin(Function): + def __init__(self, left: Expr, right: Expr): + super().__init__("logical_min", [left, right]) class MapGet(Function): + def __init__(self, map: Expr, name: str): + super().__init__("map_get", [map, Constant(name)]) class Mod(Function): + def __init__(self, left: Expr, right: Expr): + super().__init__("mod", [left, right]) class Multiply(Function): + def __init__(self, left: Expr, right: Expr): + super().__init__("multiply", [left, right]) class Parent(Function): + def __init__(self, value: Expr): + super().__init__("parent", [value]) class ReplaceAll(Function): + def __init__(self, value: Expr, find: Expr, replacement: Expr): + super().__init__("replace_all", [value, find, replacement]) class ReplaceFirst(Function): + def __init__(self, value: Expr, find: Expr, replacement: Expr): + super().__init__("replace_first", [value, find, replacement]) class Reverse(Function): + def __init__(self, expr: Expr): + super().__init__("reverse", [expr]) class StrConcat(Function): + def __init__(self, first: Expr, exprs: List[Expr]): + super().__init__("str_concat", [first] + exprs) class Subtract(Function): + def __init__(self, left: Expr, right: Expr): + super().__init__("subtract", [left, right]) class TimestampAdd(Function): + def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): + super().__init__("timestamp_add", [timestamp, unit, amount]) class TimestampSub(Function): + def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): + super().__init__("timestamp_sub", [timestamp, unit, amount]) class TimestampToUnixMicros(Function): + def __init__(self, input: Expr): + super().__init__("timestamp_to_unix_micros", [input]) class TimestampToUnixMillis(Function): + def __init__(self, input: Expr): + super().__init__("timestamp_to_unix_millis", [input]) class TimestampToUnixSeconds(Function): + def __init__(self, input: Expr): + super().__init__("timestamp_to_unix_seconds", [input]) class ToLower(Function): + def __init__(self, expr: Expr): + super().__init__("to_lower", [expr]) class ToUpper(Function): + def __init__(self, expr: Expr): + super().__init__("to_upper", [expr]) class Trim(Function): + def __init__(self, expr: Expr): + super().__init__("trim", [expr]) class UnixMicrosToTimestamp(Function): + def __init__(self, input: Expr): + super().__init__("unix_micros_to_timestamp", [input]) class UnixMillisToTimestamp(Function): + def __init__(self, input: Expr): + super().__init__("unix_millis_to_timestamp", [input]) class UnixSecondsToTimestamp(Function): + def __init__(self, input: Expr): + 
super().__init__("unix_seconds_to_timestamp", [input]) class VectorLength(Function): + def __init__(self, array: Expr): + super().__init__("vector_length", [array]) class Add(Function): + def __init__(self, left: Expr, right: Expr): + super().__init__("add", [left, right]) class ArrayConcat(Function): + def __init__(self, array: Expr, rest: List[Expr]): + super().__init__("array_concat", [array] + rest) class ArrayElement(Function): + def __init__(self): + super().__init__("array_element", []) class ArrayFilter(Function): + def __init__(self, array: Expr, filter: "FilterCondition"): + super().__init__("array_filter", [array, filter]) class ArrayLength(Function): + def __init__(self, array: Expr): + super().__init__("array_length", [array]) class ArrayReverse(Function): + def __init__(self, array: Expr): + super().__init__("array_reverse", [array]) class ArrayTransform(Function): + def __init__(self, array: Expr, transform: Function): + super().__init__("array_transform", [array, transform]) class ByteLength(Function): + def __init__(self, expr: Expr): + super().__init__("byte_length", [expr]) class CharLength(Function): + def __init__(self, expr: Expr): + super().__init__("char_length", [expr]) class CollectionId(Function): + def __init__(self, value: Expr): + super().__init__("collection_id", [value]) class CosineDistance(Function): - + def __init__(self, vector1: Expr, vector2: Expr): + super().__init__("cosine_distance", [vector1, vector2]) class Accumulator(Function): - """ - A type of expression that takes in many, and results in one value - """ - + """A type of expression that takes in many, and results in one value.""" class Max(Accumulator): + def __init__(self, value: Expr, distinct: bool): + super().__init__("max", [value]) class Min(Accumulator): + def __init__(self, value: Expr, distinct: bool): + super().__init__("min", [value]) class Sum(Accumulator): - - - - + def __init__(self, value: Expr, distinct: bool): + super().__init__("sum", [value]) class Avg(Function, Accumulator): + def __init__(self, value: Expr, distinct: bool): + super(Function, self).__init__("avg", [value]) class Count(Function, Accumulator): + def __init__(self, value: Expr = None): + super(Function, self).__init__("count", [value] if value else []) + class CountIf(Function, Accumulator): + def __init__(self, value: Expr, distinct: bool): + super(Function, self).__init__("countif", [value] if value else []) class Selectable: - """ - Points at something in the database? 
- """ + """Points at something in the database?""" class AccumulatorTarget(Selectable): + def __init__(self, accumulator: Accumulator, field_name: str, distinct: bool): + self.accumulator = accumulator + self.field_name = field_name + self.distinct = distinct -class ExprWithAlies(Expr, Selectable): +class ExprWithAlias(Expr, Selectable): + def __init__(self, expr: Expr, alias: str): + self.expr = expr + self.alias = alias class Field(Expr, Selectable): + DOCUMENT_ID = "__name__" + def __init__(self, path: str): + self.path = path class FilterCondition(Function): - """ - filters the given data in some way - """ + """Filters the given data in some way.""" class And(FilterCondition): + def __init__(self, conditions: List["FilterCondition"]): + super().__init__("and", conditions) -class ArrayContains(FilterCondition) - -class ArrayContainsAll(FilterCondition) +class ArrayContains(FilterCondition): + def __init__(self, array: Expr, element: Expr): + super().__init__("array_contains", [array, element if element else Constant(None)]) -class ArrayContainsAny(FilterCondition) +class ArrayContainsAll(FilterCondition): + def __init__(self, array: Expr, elements: List[Expr]): + super().__init__("array_contains_all", [array, ListOfExprs(elements)]) -class EndsWith(FilterCondition) +class ArrayContainsAny(FilterCondition): + def __init__(self, array: Expr, elements: List[Expr]): + super().__init__("array_contains_any", [array, ListOfExprs(elements)]) -class Eq(FilterCondition) +class EndsWith(FilterCondition): + def __init__(self, expr: Expr, postfix: Expr): + super().__init__("ends_with", [expr, postfix]) -class Exists(FilterCondition) +class Eq(FilterCondition): + def __init__(self, left: Expr, right: Expr): + super().__init__("eq", [left, right if right else Constant(None)]) -class Gt(FilterCondition) +class Exists(FilterCondition): + def __init__(self, expr: Expr): + super().__init__("exists", [expr]) -class Gte(FilterCondition) +class Gt(FilterCondition): + def __init__(self, left: Expr, right: Expr): + super().__init__("gt", [left, right if right else Constant(None)]) -class If(FilterCondition) +class Gte(FilterCondition): + def __init__(self, left: Expr, right: Expr): + super().__init__("gte", [left, right if right else Constant(None)]) -class In(FilterCondition) +class If(FilterCondition): + def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Expr): + super().__init__("if", [condition, true_expr, false_expr if false_expr else Constant(None)]) -class IsNan(FilterCondition) +class In(FilterCondition): + def __init__(self, left: Expr, others: List[Expr]): + super().__init__("in", [left, ListOfExprs(others)]) -class Like(FilterCondition) +class IsNan(FilterCondition): + def __init__(self, value: Expr): + super().__init__("is_nan", [value]) +class Like(FilterCondition): + def __init__(self, expr: Expr, pattern: Expr): + super().__init__("like", [expr, pattern]) -class Lt(FilterCondition) +class Lt(FilterCondition): + def __init__(self, left: Expr, right: Expr): + super().__init__("lt", [left, right if right else Constant(None)]) -class Lte(FilterCondition) +class Lte(FilterCondition): + def __init__(self, left: Expr, right: Expr): + super().__init__("lte", [left, right if right else Constant(None)]) -class Neq(FilterCondition) +class Neq(FilterCondition): + def __init__(self, left: Expr, right: Expr): + super().__init__("neq", [left, right if right else Constant(None)]) -class Not(FilterCondition) +class Not(FilterCondition): + def __init__(self, condition: Expr): + 
super().__init__("not", [condition]) -class Or(FilterCondition) +class Or(FilterCondition): + def __init__(self, conditions: List["FilterCondition"]): + super().__init__("or", conditions) -class RegexContains(FilterCondition) +class RegexContains(FilterCondition): + def __init__(self, expr: Expr, regex: Expr): + super().__init__("regex_contains", [expr, regex]) -class RegexMatch(FilterCondition) +class RegexMatch(FilterCondition): + def __init__(self, expr: Expr, regex: Expr): + super().__init__("regex_match", [expr, regex]) -class StartsWith(FilterCondition) +class StartsWith(FilterCondition): + def __init__(self, expr: Expr, prefix: Expr): + super().__init__("starts_with", [expr, prefix]) class StrContains(FilterCondition): + def __init__(self, expr: Expr, substring: Expr): + super().__init__("str_contains", [expr, substring]) class Xor(FilterCondition): + def __init__(self, conditions: List["FilterCondition"]): + super().__init__("xor", conditions) From e56459cb2bbdf232858986546835bce89331153f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 22 Jan 2025 16:20:06 -0800 Subject: [PATCH 003/131] added default implementation for pipeline stages --- google/cloud/firestore_v1/pipeline_stages.py | 209 +++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 google/cloud/firestore_v1/pipeline_stages.py diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py new file mode 100644 index 000000000..3a67cfb84 --- /dev/null +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -0,0 +1,209 @@ +from typing import Any, Dict, Iterable, List, Optional, Union + +from google.cloud.firestore_v1.types import value +from google.cloud.firestore_v1.types.pipeline import Stage as GrpcStage +from google.cloud.firestore_v1.types.query import StructuredQuery + +from google.cloud.firestore_v1.base_document import DocumentReference +from google.cloud.firestore_v1.field_path import FieldPath +from google.cloud.firestore_v1.pipeline import Pipeline +from google.cloud.firestore_v1.pipeline_expressions import ( + Accumulator, + CompositeFilter, + Direction, + DistanceMeasure, + Expr, + ExprWithAlias, + Field, + FieldFilter, + Filter, + FilterCondition, + Ordering, + Scalar, + Selectable, + UnaryFilter, +) + + +class Stage: + def __init__(self, custom_name: Optional[str] = None): + self.name = custom_name or type(self).__name__.lower() + + +class AddFields(Stage): + def __init__(self, fields: Dict[str, Expr]): + super().__init__("add_fields") + self.fields = fields + +class Aggregate(Stage): + def __init__( + self, + groups: Optional[Dict[str, Expr]] = None, + accumulators: Optional[Dict[str, Accumulator]] = None, + ): + super().__init__() + self.groups = groups or {} + self.accumulators = accumulators or {} + + + +class Collection(Stage): + def __init__(self, path: str): + super().__init__() + if not path.startswith("/"): + path = "/" + path + self.path = path + + +class CollectionGroup(Stage): + def __init__(self, collection_id: str): + super().__init__("collection_group") + self.collection_id = collection_id + + + +class Database(Stage): + def __init__(self): + super().__init__() + + +class Distinct(Stage): + def __init__(self, groups: Dict[str, Expr]): + super().__init__() + self.groups = groups + + +class Documents(Stage): + def __init__(self, documents: List[str]): + super().__init__() + self.documents = documents + + @staticmethod + def of(*documents: DocumentReference) -> "Documents": + doc_paths = ["/" + doc.path for doc in documents] + return 
Documents(doc_paths) + + +class FindNearest(Stage): + def __init__( + self, + property: Expr, + vector: List[float], + distance_measure: DistanceMeasure, + options: Optional["FindNearestOptions"] = None, + ): + super().__init__("find_nearest") + self.property = property + self.vector = vector + self.distance_measure = distance_measure + self.options = options or FindNearestOptions() + + + +class GenericStage(Stage): + def __init__(self, name: str, params: List[Any]): + super().__init__(name) + self.params = params + + + +class Limit(Stage): + def __init__(self, limit: int): + super().__init__() + self.limit = limit + + + +class Offset(Stage): + def __init__(self, offset: int): + super().__init__() + self.offset = offset + + + +class RemoveFields(Stage): + def __init__(self, fields: List[Field]): + super().__init__("remove_fields") + self.fields = fields + + + +class Replace(Stage): + class Mode: + FULL_REPLACE = value.Value(string_value="full_replace") + MERGE_PREFER_NEXT = value.Value(string_value="merge_prefer_nest") + MERGE_PREFER_PARENT = value.Value(string_value="merge_prefer_parent") + + def __init__(self, field: Selectable, mode: Mode = Mode.FULL_REPLACE): + super().__init__() + self.field = field + self.mode = mode + + + +class Sample(Stage): + def __init__(self, options: "SampleOptions"): + super().__init__() + self.options = options + + + +class Select(Stage): + def __init__(self, projections: Dict[str, Expr]): + super().__init__() + self.projections = projections + + + +class Sort(Stage): + def __init__(self, orders: List[Ordering]): + super().__init__() + self.orders = orders + + + +class Union(Stage): + def __init__(self, other: Pipeline): + super().__init__() + self.other = other + + + +class Unnest(Stage): + def __init__(self, field: Field, options: Optional["UnnestOptions"] = None): + super().__init__() + self.field = field + self.options = options + + + +class Where(Stage): + def __init__(self, condition: FilterCondition): + super().__init__() + self.condition = condition + + + +class FindNearestOptions: + def __init__( + self, + limit: Optional[int] = None, + distance_field: Optional[Field] = None, + ): + self.limit = limit + self.distance_field = distance_field + + +class SampleOptions: + class Mode: + DOCUMENTS = value.Value(string_value="documents") + PERCENT = value.Value(string_value="percent") + + def __init__(self, n: Union[int, float], mode: Mode): + self.n = n + self.mode = mode + + +class UnnestOptions: + def __init__(self, index_field: str): + self.index_field = index_field From babafb20e7882dfc043b0ee232cd850ea6bfeeea Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 22 Jan 2025 16:21:52 -0800 Subject: [PATCH 004/131] ran black --- .../firestore_v1/pipeline_expressions.py | 85 ++++++++++++++++++- google/cloud/firestore_v1/pipeline_stages.py | 15 +--- 2 files changed, 83 insertions(+), 17 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 0aac59879..7e1c691ba 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1,19 +1,22 @@ - from typing import Any, Iterable, List, Mapping + class Expr: """Represents an expression that can be evaluated to a value within the execution of a pipeline. 
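
    Concrete subclasses include constants (Constant), field references
    (Field), functions (Function), and aggregations (Accumulator). As an
    illustrative sketch of how these compose (not a final API), a filter
    could be built as Eq(Field("genre"), Constant("Science Fiction")).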
""" + class Constant(Expr): def __init__(self, value: Any): self.value = value + class ListOfExprs(Expr): def __init__(self, exprs: List[Expr]): self.exprs = exprs + class Function(Expr): """A type of Expression that takes in inputs and gives outputs.""" @@ -21,292 +24,368 @@ def __init__(self, name: str, params: List[Expr]): self.name = name self.params = params + class Divide(Function): def __init__(self, left: Expr, right: Expr): super().__init__("divide", [left, right]) + class DotProduct(Function): def __init__(self, vector1: Expr, vector2: Expr): super().__init__("dot_product", [vector1, vector2]) + class EuclideanDistance(Function): def __init__(self, vector1: Expr, vector2: Expr): super().__init__("euclidean_distance", [vector1, vector2]) + class LogicalMax(Function): def __init__(self, left: Expr, right: Expr): super().__init__("logical_max", [left, right]) + class LogicalMin(Function): def __init__(self, left: Expr, right: Expr): super().__init__("logical_min", [left, right]) + class MapGet(Function): def __init__(self, map: Expr, name: str): super().__init__("map_get", [map, Constant(name)]) + class Mod(Function): def __init__(self, left: Expr, right: Expr): super().__init__("mod", [left, right]) + class Multiply(Function): def __init__(self, left: Expr, right: Expr): super().__init__("multiply", [left, right]) + class Parent(Function): def __init__(self, value: Expr): super().__init__("parent", [value]) + class ReplaceAll(Function): def __init__(self, value: Expr, find: Expr, replacement: Expr): super().__init__("replace_all", [value, find, replacement]) + class ReplaceFirst(Function): def __init__(self, value: Expr, find: Expr, replacement: Expr): super().__init__("replace_first", [value, find, replacement]) + class Reverse(Function): def __init__(self, expr: Expr): super().__init__("reverse", [expr]) + class StrConcat(Function): def __init__(self, first: Expr, exprs: List[Expr]): super().__init__("str_concat", [first] + exprs) + class Subtract(Function): def __init__(self, left: Expr, right: Expr): super().__init__("subtract", [left, right]) + class TimestampAdd(Function): def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): super().__init__("timestamp_add", [timestamp, unit, amount]) + class TimestampSub(Function): def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): super().__init__("timestamp_sub", [timestamp, unit, amount]) + class TimestampToUnixMicros(Function): def __init__(self, input: Expr): super().__init__("timestamp_to_unix_micros", [input]) + class TimestampToUnixMillis(Function): def __init__(self, input: Expr): super().__init__("timestamp_to_unix_millis", [input]) + class TimestampToUnixSeconds(Function): def __init__(self, input: Expr): super().__init__("timestamp_to_unix_seconds", [input]) + class ToLower(Function): def __init__(self, expr: Expr): super().__init__("to_lower", [expr]) + class ToUpper(Function): def __init__(self, expr: Expr): super().__init__("to_upper", [expr]) + class Trim(Function): def __init__(self, expr: Expr): super().__init__("trim", [expr]) + class UnixMicrosToTimestamp(Function): def __init__(self, input: Expr): super().__init__("unix_micros_to_timestamp", [input]) + class UnixMillisToTimestamp(Function): def __init__(self, input: Expr): super().__init__("unix_millis_to_timestamp", [input]) + class UnixSecondsToTimestamp(Function): def __init__(self, input: Expr): super().__init__("unix_seconds_to_timestamp", [input]) + class VectorLength(Function): def __init__(self, array: Expr): 
super().__init__("vector_length", [array]) + class Add(Function): def __init__(self, left: Expr, right: Expr): super().__init__("add", [left, right]) + class ArrayConcat(Function): def __init__(self, array: Expr, rest: List[Expr]): super().__init__("array_concat", [array] + rest) + class ArrayElement(Function): def __init__(self): super().__init__("array_element", []) + class ArrayFilter(Function): def __init__(self, array: Expr, filter: "FilterCondition"): super().__init__("array_filter", [array, filter]) + class ArrayLength(Function): def __init__(self, array: Expr): super().__init__("array_length", [array]) + class ArrayReverse(Function): def __init__(self, array: Expr): super().__init__("array_reverse", [array]) + class ArrayTransform(Function): def __init__(self, array: Expr, transform: Function): super().__init__("array_transform", [array, transform]) + class ByteLength(Function): def __init__(self, expr: Expr): super().__init__("byte_length", [expr]) + class CharLength(Function): def __init__(self, expr: Expr): super().__init__("char_length", [expr]) + class CollectionId(Function): def __init__(self, value: Expr): super().__init__("collection_id", [value]) + class CosineDistance(Function): def __init__(self, vector1: Expr, vector2: Expr): super().__init__("cosine_distance", [vector1, vector2]) + class Accumulator(Function): """A type of expression that takes in many, and results in one value.""" + class Max(Accumulator): def __init__(self, value: Expr, distinct: bool): super().__init__("max", [value]) + class Min(Accumulator): def __init__(self, value: Expr, distinct: bool): super().__init__("min", [value]) + class Sum(Accumulator): def __init__(self, value: Expr, distinct: bool): super().__init__("sum", [value]) + class Avg(Function, Accumulator): def __init__(self, value: Expr, distinct: bool): super(Function, self).__init__("avg", [value]) + class Count(Function, Accumulator): def __init__(self, value: Expr = None): super(Function, self).__init__("count", [value] if value else []) + class CountIf(Function, Accumulator): def __init__(self, value: Expr, distinct: bool): super(Function, self).__init__("countif", [value] if value else []) + class Selectable: """Points at something in the database?""" + class AccumulatorTarget(Selectable): def __init__(self, accumulator: Accumulator, field_name: str, distinct: bool): self.accumulator = accumulator self.field_name = field_name self.distinct = distinct + class ExprWithAlias(Expr, Selectable): def __init__(self, expr: Expr, alias: str): self.expr = expr self.alias = alias + class Field(Expr, Selectable): DOCUMENT_ID = "__name__" def __init__(self, path: str): self.path = path + class FilterCondition(Function): """Filters the given data in some way.""" + class And(FilterCondition): def __init__(self, conditions: List["FilterCondition"]): super().__init__("and", conditions) + class ArrayContains(FilterCondition): def __init__(self, array: Expr, element: Expr): - super().__init__("array_contains", [array, element if element else Constant(None)]) + super().__init__( + "array_contains", [array, element if element else Constant(None)] + ) + class ArrayContainsAll(FilterCondition): def __init__(self, array: Expr, elements: List[Expr]): super().__init__("array_contains_all", [array, ListOfExprs(elements)]) + class ArrayContainsAny(FilterCondition): def __init__(self, array: Expr, elements: List[Expr]): super().__init__("array_contains_any", [array, ListOfExprs(elements)]) + class EndsWith(FilterCondition): def __init__(self, expr: Expr, postfix: 
Expr): super().__init__("ends_with", [expr, postfix]) + class Eq(FilterCondition): def __init__(self, left: Expr, right: Expr): super().__init__("eq", [left, right if right else Constant(None)]) + class Exists(FilterCondition): def __init__(self, expr: Expr): super().__init__("exists", [expr]) + class Gt(FilterCondition): def __init__(self, left: Expr, right: Expr): super().__init__("gt", [left, right if right else Constant(None)]) + class Gte(FilterCondition): def __init__(self, left: Expr, right: Expr): super().__init__("gte", [left, right if right else Constant(None)]) + class If(FilterCondition): def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Expr): - super().__init__("if", [condition, true_expr, false_expr if false_expr else Constant(None)]) + super().__init__( + "if", [condition, true_expr, false_expr if false_expr else Constant(None)] + ) + class In(FilterCondition): def __init__(self, left: Expr, others: List[Expr]): super().__init__("in", [left, ListOfExprs(others)]) + class IsNan(FilterCondition): def __init__(self, value: Expr): super().__init__("is_nan", [value]) + class Like(FilterCondition): def __init__(self, expr: Expr, pattern: Expr): super().__init__("like", [expr, pattern]) + class Lt(FilterCondition): def __init__(self, left: Expr, right: Expr): super().__init__("lt", [left, right if right else Constant(None)]) + class Lte(FilterCondition): def __init__(self, left: Expr, right: Expr): super().__init__("lte", [left, right if right else Constant(None)]) + class Neq(FilterCondition): def __init__(self, left: Expr, right: Expr): super().__init__("neq", [left, right if right else Constant(None)]) + class Not(FilterCondition): def __init__(self, condition: Expr): super().__init__("not", [condition]) + class Or(FilterCondition): def __init__(self, conditions: List["FilterCondition"]): super().__init__("or", conditions) + class RegexContains(FilterCondition): def __init__(self, expr: Expr, regex: Expr): super().__init__("regex_contains", [expr, regex]) + class RegexMatch(FilterCondition): def __init__(self, expr: Expr, regex: Expr): super().__init__("regex_match", [expr, regex]) + class StartsWith(FilterCondition): def __init__(self, expr: Expr, prefix: Expr): super().__init__("starts_with", [expr, prefix]) + class StrContains(FilterCondition): def __init__(self, expr: Expr, substring: Expr): super().__init__("str_contains", [expr, substring]) + class Xor(FilterCondition): def __init__(self, conditions: List["FilterCondition"]): super().__init__("xor", conditions) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 3a67cfb84..806f9588b 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -35,6 +35,7 @@ def __init__(self, fields: Dict[str, Expr]): super().__init__("add_fields") self.fields = fields + class Aggregate(Stage): def __init__( self, @@ -46,7 +47,6 @@ def __init__( self.accumulators = accumulators or {} - class Collection(Stage): def __init__(self, path: str): super().__init__() @@ -61,7 +61,6 @@ def __init__(self, collection_id: str): self.collection_id = collection_id - class Database(Stage): def __init__(self): super().__init__() @@ -99,35 +98,30 @@ def __init__( self.options = options or FindNearestOptions() - class GenericStage(Stage): def __init__(self, name: str, params: List[Any]): super().__init__(name) self.params = params - class Limit(Stage): def __init__(self, limit: int): super().__init__() self.limit = limit 
- class Offset(Stage): def __init__(self, offset: int): super().__init__() self.offset = offset - class RemoveFields(Stage): def __init__(self, fields: List[Field]): super().__init__("remove_fields") self.fields = fields - class Replace(Stage): class Mode: FULL_REPLACE = value.Value(string_value="full_replace") @@ -140,35 +134,30 @@ def __init__(self, field: Selectable, mode: Mode = Mode.FULL_REPLACE): self.mode = mode - class Sample(Stage): def __init__(self, options: "SampleOptions"): super().__init__() self.options = options - class Select(Stage): def __init__(self, projections: Dict[str, Expr]): super().__init__() self.projections = projections - class Sort(Stage): def __init__(self, orders: List[Ordering]): super().__init__() self.orders = orders - class Union(Stage): def __init__(self, other: Pipeline): super().__init__() self.other = other - class Unnest(Stage): def __init__(self, field: Field, options: Optional["UnnestOptions"] = None): super().__init__() @@ -176,14 +165,12 @@ def __init__(self, field: Field, options: Optional["UnnestOptions"] = None): self.options = options - class Where(Stage): def __init__(self, condition: FilterCondition): super().__init__() self.condition = condition - class FindNearestOptions: def __init__( self, From 149c61f15405823eec273f3c6aac9696ab628d98 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 22 Jan 2025 16:30:15 -0800 Subject: [PATCH 005/131] got code to run --- .../firestore_v1/pipeline_expressions.py | 6 +-- google/cloud/firestore_v1/pipeline_stages.py | 43 +++++++------------ 2 files changed, 18 insertions(+), 31 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 7e1c691ba..eb4472fe8 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -229,17 +229,17 @@ def __init__(self, value: Expr, distinct: bool): super().__init__("sum", [value]) -class Avg(Function, Accumulator): +class Avg(Accumulator): def __init__(self, value: Expr, distinct: bool): super(Function, self).__init__("avg", [value]) -class Count(Function, Accumulator): +class Count(Accumulator): def __init__(self, value: Expr = None): super(Function, self).__init__("count", [value] if value else []) -class CountIf(Function, Accumulator): +class CountIf(Function): def __init__(self, value: Expr, distinct: bool): super(Function, self).__init__("countif", [value] if value else []) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 806f9588b..27761a4aa 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -1,27 +1,14 @@ -from typing import Any, Dict, Iterable, List, Optional, Union +from __future__ import annotations +from typing import Any, Dict, Iterable, List, Optional +from enum import Enum -from google.cloud.firestore_v1.types import value -from google.cloud.firestore_v1.types.pipeline import Stage as GrpcStage -from google.cloud.firestore_v1.types.query import StructuredQuery - -from google.cloud.firestore_v1.base_document import DocumentReference -from google.cloud.firestore_v1.field_path import FieldPath -from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1.pipeline_expressions import ( Accumulator, - CompositeFilter, - Direction, - DistanceMeasure, Expr, ExprWithAlias, Field, - FieldFilter, - Filter, FilterCondition, - Ordering, - Scalar, Selectable, - UnaryFilter, ) @@ -78,7 
+65,7 @@ def __init__(self, documents: List[str]): self.documents = documents @staticmethod - def of(*documents: DocumentReference) -> "Documents": + def of(*documents: "DocumentReference") -> "Documents": doc_paths = ["/" + doc.path for doc in documents] return Documents(doc_paths) @@ -88,7 +75,7 @@ def __init__( self, property: Expr, vector: List[float], - distance_measure: DistanceMeasure, + distance_measure: "DistanceMeasure", options: Optional["FindNearestOptions"] = None, ): super().__init__("find_nearest") @@ -123,10 +110,10 @@ def __init__(self, fields: List[Field]): class Replace(Stage): - class Mode: - FULL_REPLACE = value.Value(string_value="full_replace") - MERGE_PREFER_NEXT = value.Value(string_value="merge_prefer_nest") - MERGE_PREFER_PARENT = value.Value(string_value="merge_prefer_parent") + class Mode(Enum): + FULL_REPLACE = "full_replace" + MERGE_PREFER_NEXT = "merge_prefer_nest" + MERGE_PREFER_PARENT = "merge_prefer_parent" def __init__(self, field: Selectable, mode: Mode = Mode.FULL_REPLACE): super().__init__() @@ -147,13 +134,13 @@ def __init__(self, projections: Dict[str, Expr]): class Sort(Stage): - def __init__(self, orders: List[Ordering]): + def __init__(self, orders: List["Ordering"]): super().__init__() self.orders = orders class Union(Stage): - def __init__(self, other: Pipeline): + def __init__(self, other: "Pipeline"): super().__init__() self.other = other @@ -182,11 +169,11 @@ def __init__( class SampleOptions: - class Mode: - DOCUMENTS = value.Value(string_value="documents") - PERCENT = value.Value(string_value="percent") + class Mode(Enum): + DOCUMENTS = "documents" + PERCENT = "percent" - def __init__(self, n: Union[int, float], mode: Mode): + def __init__(self, n: int | float, mode: Mode): self.n = n self.mode = mode From 320cea168352269f8fbdd3212afca3fa3bf84adb Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 22 Jan 2025 16:30:58 -0800 Subject: [PATCH 006/131] moved helpers --- google/cloud/firestore_v1/pipeline_stages.py | 48 ++++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 27761a4aa..b13ceb7a4 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -11,6 +11,30 @@ Selectable, ) +class FindNearestOptions: + def __init__( + self, + limit: Optional[int] = None, + distance_field: Optional[Field] = None, + ): + self.limit = limit + self.distance_field = distance_field + + +class SampleOptions: + class Mode(Enum): + DOCUMENTS = "documents" + PERCENT = "percent" + + def __init__(self, n: int | float, mode: Mode): + self.n = n + self.mode = mode + + +class UnnestOptions: + def __init__(self, index_field: str): + self.index_field = index_field + class Stage: def __init__(self, custom_name: Optional[str] = None): @@ -157,27 +181,3 @@ def __init__(self, condition: FilterCondition): super().__init__() self.condition = condition - -class FindNearestOptions: - def __init__( - self, - limit: Optional[int] = None, - distance_field: Optional[Field] = None, - ): - self.limit = limit - self.distance_field = distance_field - - -class SampleOptions: - class Mode(Enum): - DOCUMENTS = "documents" - PERCENT = "percent" - - def __init__(self, n: int | float, mode: Mode): - self.n = n - self.mode = mode - - -class UnnestOptions: - def __init__(self, index_field: str): - self.index_field = index_field From cd2963d551792ce6fe2b89e079640808a5807dd4 Mon Sep 17 00:00:00 2001 From: Daniel 
Sanche Date: Wed, 22 Jan 2025 16:44:55 -0800 Subject: [PATCH 007/131] added basic pipelines.py file --- google/cloud/firestore_v1/pipeline.py | 101 +++++++++++++++++++ google/cloud/firestore_v1/pipeline_stages.py | 2 +- 2 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 google/cloud/firestore_v1/pipeline.py diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py new file mode 100644 index 000000000..ec8c5c9bb --- /dev/null +++ b/google/cloud/firestore_v1/pipeline.py @@ -0,0 +1,101 @@ +from __future__ import annotations +from typing import Any, Dict, Iterable, List, Optional +from google.cloud.firestore_v1 import pipeline_stages as stages + +from google.cloud.firestore_v1.pipeline_expressions import ( + Accumulator, + Expr, + ExprWithAlias, + Field, + FilterCondition, + Selectable, +) + + +class Pipeline: + def __init__(self): + self.stages = [] + + def add_fields(self, fields: Dict[str, Expr]) -> Pipeline: + self.stages.append(stages.AddFields(fields)) + return self + + def remove_fields(self, fields: List[Field]) -> Pipeline: + self.stages.append(stages.RemoveFields(fields)) + return self + + def select(self, projections: Dict[str, Expr]) -> Pipeline: + self.stages.append(stages.Select(projections)) + return self + + def where(self, condition: FilterCondition) -> Pipeline: + self.stages.append(stages.Where(condition)) + return self + + def find_nearest( + self, + field: str | Expr, + vector: "Vector", + distance_measure: "FindNearest.DistanceMeasure", + limit: int | None, + options: Optional[stages.FindNearestOptions] = None, + ) -> Pipeline: + self.stages.append(stages.FindNearest(field, vector, distance_measure, options)) + return self + + def sort(self, orders: List[stages.Ordering]) -> Pipeline: + self.stages.append(stages.Sort(orders)) + return self + + def replace( + self, + field: Selectable, + mode: stages.Replace.Mode = stages.Replace.Mode.FULL_REPLACE, + ) -> Pipeline: + self.stages.append(stages.Replace(field, mode)) + return self + + def sample(self, options: stages.SampleOptions) -> Pipeline: + self.stages.append(stages.Sample(options)) + return self + + def union(self, other: Pipeline) -> Pipeline: + self.stages.append(stages.Union(other)) + return self + + def unnest( + self, + field_name: str, + options: Optional[stages.UnnestOptions] = None, + ) -> Pipeline: + self.stages.append(stages.Unnest(field_name, options)) + return self + + def generic_stage(self, name: str, params: List[Any]) -> Pipeline: + self.stages.append(stages.GenericStage(name, params)) + return self + + def offset(self, offset: int) -> Pipeline: + self.stages.append(stages.Offset(offset)) + return self + + def limit(self, limit: int) -> Pipeline: + self.stages.append(stages.Limit(limit)) + return self + + def aggregate( + self, + accumulators: Optional[Dict[str, Accumulator]] = None, + ) -> Pipeline: + self.stages.append(stages.Aggregate(accumulators=accumulators)) + return self + + def distinct(self, fields: Dict[str, Expr]) -> Pipeline: + self.stages.append(stages.Distinct(fields)) + return self + + def execute(self) -> list["PipelineResult"]: + return [] + + async def execute_async(self) -> List["PipelineResult"]: + return [] diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index b13ceb7a4..fba390d26 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -11,6 +11,7 @@ Selectable, ) + class FindNearestOptions: def __init__( self, @@ -180,4 
+181,3 @@ class Where(Stage): def __init__(self, condition: FilterCondition): super().__init__() self.condition = condition - From 3cbc2d519af5a4109f129ed4c2ac68bfa48f3fc7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 23 Jan 2025 16:34:05 -0800 Subject: [PATCH 008/131] added yaml test file --- tests/system/pipeline_e2e.yaml | 158 +++++++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 tests/system/pipeline_e2e.yaml diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml new file mode 100644 index 000000000..825014aa0 --- /dev/null +++ b/tests/system/pipeline_e2e.yaml @@ -0,0 +1,158 @@ +data: + books: + book1: + title: "The Hitchhiker's Guide to the Galaxy" + author: "Douglas Adams" + genre: "Science Fiction" + published: 1979 + rating: 4.2 + tags: + - comedy + - space + - adventure + awards: + hugo: true + nebula: false + book2: + title: "Pride and Prejudice" + author: "Jane Austen" + genre: "Romance" + published: 1813 + rating: 4.5 + tags: + - classic + - social commentary + - love + awards: + none: true + book3: + title: "One Hundred Years of Solitude" + author: "Gabriel García Márquez" + genre: "Magical Realism" + published: 1967 + rating: 4.3 + tags: + - family + - history + - fantasy + awards: + nobel: true + nebula: false + book4: + title: "The Lord of the Rings" + author: "J.R.R. Tolkien" + genre: "Fantasy" + published: 1954 + rating: 4.7 + tags: + - adventure + - magic + - epic + awards: + hugo: false + nebula: false + book5: + title: "The Handmaid's Tale" + author: "Margaret Atwood" + genre: "Dystopian" + published: 1985 + rating: 4.1 + tags: + - feminism + - totalitarianism + - resistance + awards: + arthur c. clarke: true + booker prize: false + book6: + title: "Crime and Punishment" + author: "Fyodor Dostoevsky" + genre: "Psychological Thriller" + published: 1866 + rating: 4.3 + tags: + - philosophy + - crime + - redemption + awards: + none: true + book7: + title: "To Kill a Mockingbird" + author: "Harper Lee" + genre: "Southern Gothic" + published: 1960 + rating: 4.2 + tags: + - racism + - injustice + - coming-of-age + awards: + pulitzer: true + book8: + title: "1984" + author: "George Orwell" + genre: "Dystopian" + published: 1949 + rating: 4.2 + tags: + - surveillance + - totalitarianism + - propaganda + awards: + prometheus: true + book9: + title: "The Great Gatsby" + author: "F. 
Scott Fitzgerald" + genre: "Modernist" + published: 1925 + rating: 4.0 + tags: + - wealth + - american dream + - love + awards: + none: true + book10: + title: "Dune" + author: "Frank Herbert" + genre: "Science Fiction" + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - ecology + awards: + hugo: true + nebula: true +tests: + - description: "test accumulators" + pipeline: + - Collection: "books" + - Where: + - Eq: + left: + Field: "genre" + right: + Constant: "Science Fiction" + - Aggregate: + accumulators: + - ExprWithAlias: + expr: CountAll + alias: "count" + - ExprWithAlias: + expr: + Avg: + value: "rating" + alias: "avg_rating" + - ExprWithAlias: + expr: + Max: + value: "rating" + alias: "max_rating" + results: + - count: 10 + max_rating: 4.6 + avg_rating: 4.4 + + From 0c2241da3ad796906f298705e3fd87ba6b16ff2c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Jan 2025 11:34:35 -0800 Subject: [PATCH 009/131] wrote basic parser for pipeline yaml --- .../firestore_v1/pipeline_expressions.py | 4 + google/cloud/firestore_v1/pipeline_stages.py | 3 +- tests/system/pipeline_e2e.yaml | 11 ++- tests/system/test_pipeline_acceptance.py | 94 +++++++++++++++++++ 4 files changed, 105 insertions(+), 7 deletions(-) create mode 100644 tests/system/test_pipeline_acceptance.py diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index eb4472fe8..dee209c34 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -6,6 +6,10 @@ class Expr: execution of a pipeline. """ + def __repr__(self): + items = ("%s = %r" % (k, v) for k, v in self.__dict__.items()) + return "<%s: {%s}>" % (self.__class__.__name__, ', '.join(items)) + class Constant(Expr): def __init__(self, value: Any): diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index fba390d26..ca7e51ea8 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -11,7 +11,6 @@ Selectable, ) - class FindNearestOptions: def __init__( self, @@ -41,7 +40,6 @@ class Stage: def __init__(self, custom_name: Optional[str] = None): self.name = custom_name or type(self).__name__.lower() - class AddFields(Stage): def __init__(self, fields: Dict[str, Expr]): super().__init__("add_fields") @@ -181,3 +179,4 @@ class Where(Stage): def __init__(self, condition: FilterCondition): super().__init__() self.condition = condition + diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 825014aa0..ad7eed55d 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -130,11 +130,12 @@ tests: pipeline: - Collection: "books" - Where: - - Eq: - left: - Field: "genre" - right: - Constant: "Science Fiction" + condition: + Eq: + left: + Field: "genre" + right: + Constant: "Science Fiction" - Aggregate: accumulators: - ExprWithAlias: diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py new file mode 100644 index 000000000..00d5b72bb --- /dev/null +++ b/tests/system/test_pipeline_acceptance.py @@ -0,0 +1,94 @@ +from __future__ import annotations +import ast +import sys +import os +import black +import pytest +import yaml +from typing import Any + +# from google.cloud.firestore_v1.pipeline_stages import * +from google.cloud.firestore_v1 import pipeline_stages +from google.cloud.firestore_v1 import pipeline_expressions + +test_dir_name = 
os.path.dirname(__file__) + + +def loader(): + # load test cases + with open(f"{test_dir_name}/pipeline_e2e.yaml") as f: + test_cases = yaml.safe_load(f) + for test in test_cases["tests"]: + yield test + +def parse_pipeline(pipeline: list[dict[str, Any], str]): + """ + parse a yaml list of pipeline stages into firestore.pipeline_stages.Stage classes + """ + result_list = [] + for stage in pipeline: + # stage will be either a map of the stage_name and its args, or just the stage_name itself + stage_name: str = stage if isinstance(stage, str) else list(stage.keys())[0] + stage_cls: type[pipeline_stages.Stage] = getattr(pipeline_stages, stage_name) + # breakpoint() + # find arguments if given + if isinstance(stage, dict): + stage_yaml_args = stage[stage_name] + if isinstance(stage_yaml_args, dict): + # yaml has a mapping of arguments. Treat as kwargs + stage_obj = stage_cls(**parse_expressions(stage_yaml_args)) + elif isinstance(stage_yaml_args, list): + # yaml has an array of arguments. Treat as args + stage_obj = stage_cls(*parse_expressions(stage_yaml_args)) + else: + # yaml has a single argument + stage_obj = stage_cls(parse_expressions(stage_yaml_args)) + else: + # yaml has no arguments + stage_obj = stage_cls() + result_list.append(stage_obj) + return result_list + + +def parse_expressions(yaml_element: Any): + if isinstance(yaml_element, list): + return [parse_expressions(v) for v in yaml_element] + elif isinstance(yaml_element, dict): + return {parse_expressions(k): parse_expressions(v) for k,v in yaml_element.items()} + elif hasattr(pipeline_expressions, yaml_element): + return getattr(pipeline_expressions, yaml_element) + else: + return yaml_element + + +@pytest.mark.parametrize( + "test_dict", loader(), ids=lambda x: f"{x.get('description', '')}" +) +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="ast.unparse requires python3.9 or higher" +) +def test_e2e_scenario(test_dict): + + pipeline = parse_pipeline(test_dict["pipeline"]) + + # before_ast = ast.parse(test_dict["before"]) + # got_ast = before_ast + # for transformer_info in test_dict["transformers"]: + # # transformer can be passed as a string, or a dict with name and args + # if isinstance(transformer_info, str): + # transformer_class = globals()[transformer_info] + # transformer_args = {} + # else: + # transformer_class = globals()[transformer_info["name"]] + # transformer_args = transformer_info.get("args", {}) + # transformer = transformer_class(**transformer_args) + # got_ast = transformer.visit(got_ast) + # if got_ast is None: + # final_str = "" + # else: + # final_str = black.format_str(ast.unparse(got_ast), mode=black.FileMode()) + # if test_dict.get("after") is None: + # expected_str = "" + # else: + # expected_str = black.format_str(test_dict["after"], mode=black.FileMode()) + # assert final_str == expected_str, f"Expected:\n{expected_str}\nGot:\n{final_str}" From 9c8982cca3649f076fe54d9c9df76ff057005417 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Jan 2025 14:27:19 -0800 Subject: [PATCH 010/131] encoded extra java system tests --- tests/system/pipeline_e2e.yaml | 991 ++++++++++++++++++++++- tests/system/test_pipeline_acceptance.py | 2 +- 2 files changed, 979 insertions(+), 14 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index ad7eed55d..154a3ea83 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -126,34 +126,999 @@ data: hugo: true nebula: true tests: - - description: "test accumulators" + - description: 
"testAggregates - count" pipeline: - - Collection: "books" + - Collection: books + - Aggregate: + accumulators: + - ExprWithAlias: + expr: CountAll + alias: count + results: + - count: 10 + - description: "testAggregates - avg, count, max" + pipeline: + - Collection: books - Where: condition: Eq: left: - Field: "genre" + Field: genre right: - Constant: "Science Fiction" + Constant: Science Fiction - Aggregate: accumulators: - ExprWithAlias: expr: CountAll - alias: "count" + alias: count - ExprWithAlias: expr: Avg: - value: "rating" - alias: "avg_rating" + value: rating + alias: avg_rating - ExprWithAlias: expr: Max: - value: "rating" - alias: "max_rating" + value: rating + alias: max_rating results: - - count: 10 - max_rating: 4.6 + - count: 2 avg_rating: 4.4 - - + max_rating: 4.6 + - description: testGroupBysWithoutAccumulators + pipeline: + - Collection: books + - Where: + condition: + Lt: + left: + Field: published + right: + Constant: 1900 + - Aggregate: + groups: + - genre + error: "Cannot groupBy without accumulators" + - description: testDistinct + pipeline: + - Collection: books + - Where: + condition: + Lt: + left: + Field: published + right: + Constant: 1900 + - Distinct: + - ExprWithAlias: + expr: + ToLower: + value: + Field: genre + alias: lower_genre + results: + - lower_genre: romance + - lower_genre: psychological thriller + - description: testGroupBysAndAggregate + pipeline: + - Collection: books + - Where: + condition: + Lt: + left: + Field: published + right: + Constant: 1984 + - Aggregate: + accumulators: + - ExprWithAlias: + expr: + Avg: + value: rating + alias: avg_rating + groups: + - genre + - Where: + condition: + Gt: + left: + Field: avg_rating + right: + Constant: 4.3 + results: + - avg_rating: 4.7 + genre: Fantasy + - avg_rating: 4.5 + genre: Romance + - avg_rating: 4.4 + genre: Science Fiction + - description: testMinMax + pipeline: + - Collection: books + - Aggregate: + accumulators: + - ExprWithAlias: + expr: CountAll + alias: count + - ExprWithAlias: + expr: + Max: + value: rating + alias: max_rating + - ExprWithAlias: + expr: + Min: + value: published + alias: min_published + results: + - count: 10 + max_rating: 4.7 + min_published: 1813 + - description: selectSpecificFields + pipeline: + - Collection: books + - Select: + - title + - author + - Sort: + - OrderBy: + - Field: author + - ASCENDING + results: + - title: "The Hitchhiker's Guide to the Galaxy" + author: "Douglas Adams" + - title: "Pride and Prejudice" + author: "Jane Austen" + - title: "The Handmaid's Tale" + author: "Margaret Atwood" + - title: "Crime and Punishment" + author: "Fyodor Dostoevsky" + - title: "The Great Gatsby" + author: "F. Scott Fitzgerald" + - title: "Dune" + author: "Frank Herbert" + - title: "To Kill a Mockingbird" + author: "Harper Lee" + - title: "One Hundred Years of Solitude" + author: "Gabriel García Márquez" + - title: "1984" + author: "George Orwell" + - title: "The Lord of the Rings" + author: "J.R.R. 
Tolkien" + - description: addAndRemoveFields + pipeline: + - Collection: books + - AddFields: + - ExprWithAlias: + expr: + StrConcat: + - Field: author + - Constant: _ + - Field: title + alias: author_title + - ExprWithAlias: + expr: + StrConcat: + - Field: title + - Constant: _ + - Field: author + alias: title_author + - RemoveFields: + - title_author + - tags + - awards + - rating + - title + - Field: published + - Field: genre + - Field: nestedField # Field does not exist, should be ignored + - Sort: + - OrderBy: + - Field: author_title + - ASCENDING + results: + - author: Douglas Adams + author_title: Douglas Adams_The Hitchhiker's Guide to the Galaxy + - author: Jane Austen + author_title: Jane Austen_Pride and Prejudice + - author: Margaret Atwood + author_title: Margaret Atwood_The Handmaid's Tale + - author: Fyodor Dostoevsky + author_title: Fyodor Dostoevsky_Crime and Punishment + - author: F. Scott Fitzgerald + author_title: F. Scott Fitzgerald_The Great Gatsby + - author: Frank Herbert + author_title: Frank Herbert_Dune + - author: Harper Lee + author_title: Harper Lee_To Kill a Mockingbird + - author: Gabriel García Márquez + author_title: Gabriel García Márquez_One Hundred Years of Solitude + - author: George Orwell + author_title: George Orwell_1984 + - author: J.R.R. Tolkien + author_title: J.R.R. Tolkien_The Lord of the Rings + - description: whereByMultipleConditions + pipeline: + - Collection: books + - Where: + condition: + And: + - Gt: + left: + Field: rating + right: + Constant: 4.5 + - Eq: + left: + Field: genre + right: + Constant: Science Fiction + results: + - title: Dune + author: Frank Herbert + genre: Science Fiction + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - ecology + awards: + hugo: true + nebula: true + - description: whereByOrCondition + pipeline: + - Collection: books + - Where: + condition: + Or: + - Eq: + left: + Field: genre + right: + Constant: Romance + - Eq: + left: + Field: genre + right: + Constant: Dystopian + - Select: + - title + results: + - title: Pride and Prejudice + - title: The Handmaid's Tale + - title: 1984 + - description: testPipelineWithOffsetAndLimit + pipeline: + - Collection: books + - Sort: + - OrderBy: + - Field: author + - ASCENDING + - Offset: 5 + - Limit: 3 + - Select: + - title + - author + results: + - title: 1984 + author: George Orwell + - title: To Kill a Mockingbird + author: Harper Lee + - title: The Lord of the Rings + author: J.R.R. 
Tolkien + - description: testArrayContains + pipeline: + - Collection: books + - Where: + condition: + ArrayContains: + left: + Field: tags + right: + Constant: comedy + results: + - title: The Hitchhiker's Guide to the Galaxy + author: Douglas Adams + genre: Science Fiction + published: 1979 + rating: 4.2 + tags: + - comedy + - space + - adventure + awards: + hugo: true + nebula: false + - description: testArrayContainsAny + pipeline: + - Collection: books + - Where: + condition: + ArrayContainsAny: + left: + Field: tags + right: + Constant: + - comedy + - classic + - Select: + - title + results: + - title: The Hitchhiker's Guide to the Galaxy + - title: Pride and Prejudice + - description: testArrayContainsAll + pipeline: + - Collection: books + - Where: + condition: + ArrayContainsAll: + left: + Field: tags + right: + Constant: + - adventure + - magic + - Select: + - title + results: + - title: The Lord of the Rings + - description: testArrayLength + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + ArrayLength: + value: + Field: tags + alias: tagsCount + - Where: + condition: + Eq: + left: + Field: tagsCount + right: + Constant: 3 + results: # All documents have 3 tags + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - description: testArrayConcat + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + ArrayConcat: + - Field: tags + - Constant: + - newTag1 + - newTag2 + alias: modifiedTags + - Limit: 1 + results: + - modifiedTags: + - comedy + - space + - adventure + - newTag1 + - newTag2 + - description: testStrConcat + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + StrConcat: + - Field: author + - Constant: " - " + - Field: title + alias: bookInfo + - Limit: 1 + results: + - bookInfo: Douglas Adams - The Hitchhiker's Guide to the Galaxy + - description: testStartsWith + pipeline: + - Collection: books + - Where: + condition: + StartsWith: + left: + Field: title + right: + Constant: The + - Select: + - title + - Sort: + - OrderBy: + - Field: title + - ASCENDING + results: + - title: The Great Gatsby + - title: The Handmaid's Tale + - title: The Hitchhiker's Guide to the Galaxy + - title: The Lord of the Rings + - description: testEndsWith + pipeline: + - Collection: books + - Where: + condition: + EndsWith: + left: + Field: title + right: + Constant: y + - Select: + - title + - Sort: + - OrderBy: + - Field: title + - DESCENDING + results: + - title: The Hitchhiker's Guide to the Galaxy + - title: The Great Gatsby + - description: testLength + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + CharLength: + value: + Field: title + alias: titleLength + - title + - Where: + condition: + Gt: + left: + Field: titleLength + right: + Constant: 20 + results: + - titleLength: 32 + title: The Hitchhiker's Guide to the Galaxy + - titleLength: 27 + title: One Hundred Years of Solitude + - description: testStringFunctions - Reverse + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + Reverse: + value: + Field: title + alias: reversed_title + - Where: + condition: + Eq: + left: + Field: author + right: + Constant: Douglas Adams + results: + - reversed_title: yxalaG ot ediug s'reknhiHcH ehT + - description: testStringFunctions - ReplaceFirst + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + ReplaceFirst: + value: + Field: title + pattern: The + 
replacement: A + alias: replaced_title + - Where: + condition: + Eq: + left: + Field: author + right: + Constant: Douglas Adams + results: + - replaced_title: A Hitchhiker's Guide to the Galaxy + - description: testStringFunctions - ReplaceAll + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + ReplaceAll: + value: + Field: title + pattern: " " + replacement: _ + alias: replaced_title + - Where: + condition: + Eq: + left: + Field: author + right: + Constant: Douglas Adams + results: + - replaced_title: The_Hitchhiker's_Guide_to_the_Galaxy + - description: testStringFunctions - CharLength + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + CharLength: + value: + Field: title + alias: title_length + - Where: + condition: + Eq: + left: + Field: author + right: + Constant: Douglas Adams + results: + - title_length: 30 + - description: testStringFunctions - ByteLength + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + ByteLength: + value: + StrConcat: + - Field: title + - Constant: _银河系漫游指南 + alias: title_byte_length + - Where: + condition: + Eq: + left: + Field: author + right: + Constant: Douglas Adams + results: + - title_byte_length: 42 + - description: testToLowercase + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + ToLower: + value: + Field: title + alias: lowercaseTitle + - Limit: 1 + results: + - lowercaseTitle: the hitchhiker's guide to the galaxy + - description: testToUppercase + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + ToUpper: + value: + Field: author + alias: uppercaseAuthor + - Limit: 1 + results: + - uppercaseAuthor: DOUGLAS ADAMS + - description: testTrim + pipeline: + - Collection: books + - AddFields: + - ExprWithAlias: + expr: + StrConcat: + - Constant: " " + - Field: title + - Constant: " " + alias: spacedTitle + - Select: + - ExprWithAlias: + expr: + Trim: + value: + Field: spacedTitle + alias: trimmedTitle + - spacedTitle + - Limit: 1 + results: + - trimmedTitle: The Hitchhiker's Guide to the Galaxy + spacedTitle: " The Hitchhiker's Guide to the Galaxy " + - description: testLike + pipeline: + - Collection: books + - Where: + condition: + Like: + left: + Field: title + right: + Constant: "%Guide%" + - Select: + - title + results: + - title: The Hitchhiker's Guide to the Galaxy + - description: testRegexContains + pipeline: + - Collection: books + - Where: + condition: + RegexContains: + left: + Field: title + right: + Constant: "(?i)(the|of)" + results: + - title: The Hitchhiker's Guide to the Galaxy + - title: One Hundred Years of Solitude + - title: The Lord of the Rings + - title: To Kill a Mockingbird + - title: The Great Gatsby + - description: testRegexMatches + pipeline: + - Collection: books + - Where: + condition: + RegexMatch: + left: + Field: title + right: + Constant: ".*(?i)(the|of).*" + results: + - title: The Hitchhiker's Guide to the Galaxy + - title: One Hundred Years of Solitude + - title: The Lord of the Rings + - title: To Kill a Mockingbird + - title: The Great Gatsby + - description: testArithmeticOperations + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + Add: + - Field: rating + - Constant: 1 + alias: ratingPlusOne + - ExprWithAlias: + expr: + Subtract: + - Field: published + - Constant: 1900 + alias: yearsSince1900 + - ExprWithAlias: + expr: + Multiply: + - Field: rating + - Constant: 10 + alias: ratingTimesTen + - ExprWithAlias: + expr: + Divide: + - Field: rating + - Constant: 2 + alias: 
ratingDividedByTwo + - Limit: 1 + results: + - ratingPlusOne: 5.2 + yearsSince1900: 79 + ratingTimesTen: 42.0 + ratingDividedByTwo: 2.1 + - description: testComparisonOperators + pipeline: + - Collection: books + - Where: + condition: + And: + - Gt: + left: + Field: rating + right: + Constant: 4.2 + - Lte: + left: + Field: rating + right: + Constant: 4.5 + - Neq: + left: + Field: genre + right: + Constant: Science Fiction + - Select: + - rating + - title + - Sort: + - OrderBy: + expr: + Field: title + direction: ASCENDING + results: + - rating: 4.3 + title: Crime and Punishment + - rating: 4.3 + title: One Hundred Years of Solitude + - rating: 4.5 + title: Pride and Prejudice + - description: testLogicalOperators + pipeline: + - Collection: books + - Where: + condition: + Or: + - And: + - Gt: + left: + Field: rating + right: + Constant: 4.5 + - Eq: + left: + Field: genre + right: + Constant: Science Fiction + - Lt: + left: + Field: published + right: + Constant: 1900 + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + results: + - title: Crime and Punishment + - title: Dune + - title: Pride and Prejudice + - description: testChecks + pipeline: + - Collection: books + - Where: + condition: + Not: + IsNaN: + Field: rating + - Select: + - ExprWithAlias: + expr: + Eq: + left: + Field: rating + right: + Constant: null + alias: ratingIsNull + - ExprWithAlias: + expr: + Not: + IsNaN: + Field: rating + alias: ratingIsNotNaN + - Limit: 1 + results: + - ratingIsNull: false + ratingIsNotNaN: true + - description: testLogicalMinMax + pipeline: + - Collection: books + - Where: + condition: + Eq: + left: + Field: author + right: + Constant: Douglas Adams + - Select: + - ExprWithAlias: + expr: + LogicalMax: + - Field: rating + - Constant: 4.5 + alias: max_rating + - ExprWithAlias: + expr: + LogicalMax: + - Field: published + - Constant: 1900 + alias: max_published + results: + - max_rating: 4.5 + max_published: 1979 + - description: testLogicalMinMax - min + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + LogicalMin: + - Field: rating + - Constant: 4.5 + alias: min_rating + - ExprWithAlias: + expr: + LogicalMin: + - Field: published + - Constant: 1900 + alias: min_published + results: + - min_rating: 4.2 + min_published: 1900 + - description: testMapGet + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + MapGet: + map: + Field: awards + key: hugo + alias: hugoAward + - Field: title + - Where: + condition: + Eq: + left: + Field: hugoAward + right: + Constant: true + results: + - hugoAward: true + title: The Hitchhiker's Guide to the Galaxy + - hugoAward: true + title: Dune + - description: testDistanceFunctions + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + expr: + CosineDistance: + - Vector: + - Constant: 0.1 + - Constant: 0.1 + - Vector: + - Constant: 0.5 + - Constant: 0.8 + alias: cosineDistance + - ExprWithAlias: + expr: + DotProduct: + - Vector: + - Constant: 0.1 + - Constant: 0.1 + - Vector: + - Constant: 0.5 + - Constant: 0.8 + alias: dotProductDistance + - ExprWithAlias: + expr: + EuclideanDistance: + - Vector: + - Constant: 0.1 + - Constant: 0.1 + - Vector: + - Constant: 0.5 + - Constant: 0.8 + alias: euclideanDistance + - Limit: 1 + results: + - cosineDistance: 0.02560880430538015 + dotProductDistance: 0.13 + euclideanDistance: 0.806225774829855 + - description: testNestedFields + pipeline: + - Collection: books + - Where: + condition: + Eq: + left: + Field: awards.hugo + right: + Constant: true + - 
Select: + - title + - Field: awards.hugo + results: + - title: The Hitchhiker's Guide to the Galaxy + awards.hugo: true + - title: Dune + awards.hugo: true + - description: testPipelineInTransactions + pipeline: + - Collection: books + - Where: + condition: + Eq: + left: + Field: awards.hugo + right: + Constant: true + - Select: + - title + - Field: awards.hugo + - __name__ + results: + - title: The Hitchhiker's Guide to the Galaxy + awards.hugo: true + - title: Dune + awards.hugo: true +# - description: testReplace +# pipeline: +# - Collection: books +# - Replace: awards +# results: +# - title: The Hitchhiker's Guide to the Galaxy +# author: Douglas Adams +# genre: Science Fiction +# published: 1979 +# rating: 4.2 +# tags: +# - comedy +# - space +# - adventure +# hugo: true +# nebula: false +# # ... other results with replaced awards + #- description: testSampleLimit + # pipeline: + # - Collection: books + # - Sample: + # method: LIMIT + # n: 3 + # results: # Results will vary due to randomness + # - # document data + # - # document data + # - # document data + #- description: testSamplePercentage + # pipeline: + # - Collection: books + # - Sample: + # method: PERCENTAGE + # n: 60 + # results: # Results will vary due to randomness + # - # document data + # - # document data + # - # document data + # - # document data + # - # document data + # - # document data + #- description: testUnion + # pipeline: + # - Union: + # - Collection: books + # - Collection: books + # results: # Results will be duplicated + # - # document data + # - # document data + # # ... 20 results total + - description: testUnnest + pipeline: + - Collection: books + - Where: + condition: + Eq: + left: + Field: title + right: + Constant: The Hitchhiker's Guide to the Galaxy + - Unnest: tags + results: + - tags: comedy + - tags: space + - tags: adventure diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 00d5b72bb..ef97e25c9 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -55,7 +55,7 @@ def parse_expressions(yaml_element: Any): return [parse_expressions(v) for v in yaml_element] elif isinstance(yaml_element, dict): return {parse_expressions(k): parse_expressions(v) for k,v in yaml_element.items()} - elif hasattr(pipeline_expressions, yaml_element): + elif isinstance(yaml_element, str) and hasattr(pipeline_expressions, yaml_element): return getattr(pipeline_expressions, yaml_element) else: return yaml_element From 863bd1ddd7e281d035069e32e7dd7cb2ba4b7c8f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Jan 2025 15:59:49 -0800 Subject: [PATCH 011/131] reconstruct pipeline expr objects --- tests/system/test_pipeline_acceptance.py | 32 ++++++++++++++++-------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index ef97e25c9..357fcec98 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -21,6 +21,18 @@ def loader(): for test in test_cases["tests"]: yield test +def _apply_yaml_args(cls, yaml_args): + if isinstance(yaml_args, dict): + # yaml has a mapping of arguments. Treat as kwargs + return cls(**parse_expressions(yaml_args)) + elif isinstance(yaml_args, list): + # yaml has an array of arguments. 
Treat as args + return cls(*parse_expressions(yaml_args)) + else: + # yaml has a single argument + return cls(parse_expressions(yaml_args)) + + def parse_pipeline(pipeline: list[dict[str, Any], str]): """ parse a yaml list of pipeline stages into firestore.pipeline_stages.Stage classes @@ -34,15 +46,7 @@ def parse_pipeline(pipeline: list[dict[str, Any], str]): # find arguments if given if isinstance(stage, dict): stage_yaml_args = stage[stage_name] - if isinstance(stage_yaml_args, dict): - # yaml has a mapping of arguments. Treat as kwargs - stage_obj = stage_cls(**parse_expressions(stage_yaml_args)) - elif isinstance(stage_yaml_args, list): - # yaml has an array of arguments. Treat as args - stage_obj = stage_cls(*parse_expressions(stage_yaml_args)) - else: - # yaml has a single argument - stage_obj = stage_cls(parse_expressions(stage_yaml_args)) + stage_obj = _apply_yaml_args(stage_cls, stage_yaml_args) else: # yaml has no arguments stage_obj = stage_cls() @@ -54,7 +58,15 @@ def parse_expressions(yaml_element: Any): if isinstance(yaml_element, list): return [parse_expressions(v) for v in yaml_element] elif isinstance(yaml_element, dict): - return {parse_expressions(k): parse_expressions(v) for k,v in yaml_element.items()} + if len(yaml_element) == 1 and isinstance(list(yaml_element)[0], str) and hasattr(pipeline_expressions, list(yaml_element)[0]): + # build pipeline expressions if possible + cls_str = list(yaml_element)[0] + cls = parse_expressions(cls_str) + yaml_args = yaml_element[cls_str] + return _apply_yaml_args(cls, yaml_args) + else: + # otherwise, return dict + return {parse_expressions(k): parse_expressions(v) for k,v in yaml_element.items()} elif isinstance(yaml_element, str) and hasattr(pipeline_expressions, yaml_element): return getattr(pipeline_expressions, yaml_element) else: From 7cd2c11e2cd9447f05895570e49fda5fb2520be4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Jan 2025 16:42:20 -0800 Subject: [PATCH 012/131] got yaml to run --- .../firestore_v1/pipeline_expressions.py | 62 +++++++---- google/cloud/firestore_v1/pipeline_stages.py | 38 +++++-- tests/system/pipeline_e2e.yaml | 103 +++++++----------- 3 files changed, 109 insertions(+), 94 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index dee209c34..f5f6824fb 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1,6 +1,12 @@ from typing import Any, Iterable, List, Mapping +class Ordering: + + def __init__(self, expr, order_dir): + self.expr = expr + self.order_dir = order_dir + class Expr: """Represents an expression that can be evaluated to a value within the execution of a pipeline. 
@@ -55,8 +61,8 @@ def __init__(self, left: Expr, right: Expr): class MapGet(Function): - def __init__(self, map: Expr, name: str): - super().__init__("map_get", [map, Constant(name)]) + def __init__(self, map_: Expr, key: str): + super().__init__("map_get", [map_, Constant(key)]) class Mod(Function): @@ -75,13 +81,13 @@ def __init__(self, value: Expr): class ReplaceAll(Function): - def __init__(self, value: Expr, find: Expr, replacement: Expr): - super().__init__("replace_all", [value, find, replacement]) + def __init__(self, value: Expr, pattern: Expr, replacement: Expr): + super().__init__("replace_all", [value, pattern, replacement]) class ReplaceFirst(Function): - def __init__(self, value: Expr, find: Expr, replacement: Expr): - super().__init__("replace_first", [value, find, replacement]) + def __init__(self, value: Expr, pattern: Expr, replacement: Expr): + super().__init__("replace_first", [value, pattern, replacement]) class Reverse(Function): @@ -90,8 +96,8 @@ def __init__(self, expr: Expr): class StrConcat(Function): - def __init__(self, first: Expr, exprs: List[Expr]): - super().__init__("str_concat", [first] + exprs) + def __init__(self, *exprs: Expr): + super().__init__("str_concat", exprs) class Subtract(Function): @@ -125,13 +131,13 @@ def __init__(self, input: Expr): class ToLower(Function): - def __init__(self, expr: Expr): - super().__init__("to_lower", [expr]) + def __init__(self, value: Expr): + super().__init__("to_lower", [value]) class ToUpper(Function): - def __init__(self, expr: Expr): - super().__init__("to_upper", [expr]) + def __init__(self, value: Expr): + super().__init__("to_upper", [value]) class Trim(Function): @@ -219,23 +225,23 @@ class Accumulator(Function): class Max(Accumulator): - def __init__(self, value: Expr, distinct: bool): + def __init__(self, value: Expr, distinct: bool=False): super().__init__("max", [value]) class Min(Accumulator): - def __init__(self, value: Expr, distinct: bool): + def __init__(self, value: Expr, distinct: bool=False): super().__init__("min", [value]) class Sum(Accumulator): - def __init__(self, value: Expr, distinct: bool): + def __init__(self, value: Expr, distinct: bool=False): super().__init__("sum", [value]) class Avg(Accumulator): - def __init__(self, value: Expr, distinct: bool): - super(Function, self).__init__("avg", [value]) + def __init__(self, value: Expr, distinct: bool=False): + super().__init__("avg", [value]) class Count(Accumulator): @@ -244,26 +250,35 @@ def __init__(self, value: Expr = None): class CountIf(Function): - def __init__(self, value: Expr, distinct: bool): + def __init__(self, value: Expr, distinct: bool=False): super(Function, self).__init__("countif", [value] if value else []) class Selectable: """Points at something in the database?""" + def _to_map(self): + raise NotImplementedError + class AccumulatorTarget(Selectable): - def __init__(self, accumulator: Accumulator, field_name: str, distinct: bool): + def __init__(self, accumulator: Accumulator, field_name: str, distinct: bool=False): self.accumulator = accumulator self.field_name = field_name self.distinct = distinct + def _to_map(self): + return self.field_name, self.accumulator + class ExprWithAlias(Expr, Selectable): def __init__(self, expr: Expr, alias: str): self.expr = expr self.alias = alias + def _to_map(self): + return self.alias, self.expr + class Field(Expr, Selectable): DOCUMENT_ID = "__name__" @@ -271,13 +286,16 @@ class Field(Expr, Selectable): def __init__(self, path: str): self.path = path + def _to_map(self): + return 
self.path, self + class FilterCondition(Function): """Filters the given data in some way.""" class And(FilterCondition): - def __init__(self, conditions: List["FilterCondition"]): + def __init__(self, *conditions: "FilterCondition"): super().__init__("and", conditions) @@ -335,7 +353,7 @@ def __init__(self, left: Expr, others: List[Expr]): super().__init__("in", [left, ListOfExprs(others)]) -class IsNan(FilterCondition): +class IsNaN(FilterCondition): def __init__(self, value: Expr): super().__init__("is_nan", [value]) @@ -366,7 +384,7 @@ def __init__(self, condition: Expr): class Or(FilterCondition): - def __init__(self, conditions: List["FilterCondition"]): + def __init__(self, *conditions: "FilterCondition"): super().__init__("or", conditions) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index ca7e51ea8..be24a7081 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -40,17 +40,18 @@ class Stage: def __init__(self, custom_name: Optional[str] = None): self.name = custom_name or type(self).__name__.lower() + class AddFields(Stage): - def __init__(self, fields: Dict[str, Expr]): + def __init__(self, *fields: Selectable): super().__init__("add_fields") - self.fields = fields + self.fields = dict(f._to_map() for f in fields) class Aggregate(Stage): def __init__( self, - groups: Optional[Dict[str, Expr]] = None, - accumulators: Optional[Dict[str, Accumulator]] = None, + groups={}, + accumulators={} ): super().__init__() self.groups = groups or {} @@ -77,9 +78,14 @@ def __init__(self): class Distinct(Stage): - def __init__(self, groups: Dict[str, Expr]): + def __init__(self, *fields: str | Selectable): super().__init__() - self.groups = groups + self.fields = dict( + f._to_map() + if isinstance(f, Selectable) + else (f,Field(f)) + for f in fields + ) class Documents(Stage): @@ -127,9 +133,14 @@ def __init__(self, offset: int): class RemoveFields(Stage): - def __init__(self, fields: List[Field]): + def __init__(self, *fields: str | Field): super().__init__("remove_fields") - self.fields = fields + self.fields = dict( + f._to_map() + if isinstance(f, Selectable) + else (f,Field(f)) + for f in fields + ) class Replace(Stage): @@ -151,9 +162,16 @@ def __init__(self, options: "SampleOptions"): class Select(Stage): - def __init__(self, projections: Dict[str, Expr]): + def __init__(self, *fields: str | Selectable): super().__init__() - self.projections = projections + self.projections = dict( + f._to_map() + if isinstance(f, Selectable) + else (f,Field(f)) + for f in fields + ) + + class Sort(Stage): diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 154a3ea83..eca247be9 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -413,10 +413,8 @@ tests: - Where: condition: ArrayContains: - left: - Field: tags - right: - Constant: comedy + - tags + - comedy results: - title: The Hitchhiker's Guide to the Galaxy author: Douglas Adams @@ -436,12 +434,10 @@ tests: - Where: condition: ArrayContainsAny: - left: - Field: tags - right: - Constant: - - comedy - - classic + array: tags + elements: + - comedy + - classic - Select: - title results: @@ -453,12 +449,11 @@ tests: - Where: condition: ArrayContainsAll: - left: + array: Field: tags - right: - Constant: - - adventure - - magic + elements: + - adventure + - magic - Select: - title results: @@ -470,8 +465,7 @@ tests: - ExprWithAlias: expr: ArrayLength: - value: - Field: tags + - 
Field: tags alias: tagsCount - Where: condition: @@ -498,10 +492,11 @@ tests: - ExprWithAlias: expr: ArrayConcat: - - Field: tags - - Constant: - - newTag1 - - newTag2 + array: + Field: tags + rest: + - newTag1 + - newTag2 alias: modifiedTags - Limit: 1 results: @@ -531,10 +526,8 @@ tests: - Where: condition: StartsWith: - left: - Field: title - right: - Constant: The + - Field: title + - Constant: the - Select: - title - Sort: @@ -552,10 +545,8 @@ tests: - Where: condition: EndsWith: - left: - Field: title - right: - Constant: y + - Field: title + - Constant: y - Select: - title - Sort: @@ -572,8 +563,7 @@ tests: - ExprWithAlias: expr: CharLength: - value: - Field: title + - Field: title alias: titleLength - title - Where: @@ -595,8 +585,7 @@ tests: - ExprWithAlias: expr: Reverse: - value: - Field: title + - Field: title alias: reversed_title - Where: condition: @@ -656,8 +645,7 @@ tests: - ExprWithAlias: expr: CharLength: - value: - Field: title + - Field: title alias: title_length - Where: condition: @@ -675,10 +663,9 @@ tests: - ExprWithAlias: expr: ByteLength: - value: - StrConcat: - - Field: title - - Constant: _银河系漫游指南 + - StrConcat: + - Field: title + - Constant: _银河系漫游指南 alias: title_byte_length - Where: condition: @@ -730,8 +717,7 @@ tests: - ExprWithAlias: expr: Trim: - value: - Field: spacedTitle + - Field: spacedTitle alias: trimmedTitle - spacedTitle - Limit: 1 @@ -744,10 +730,8 @@ tests: - Where: condition: Like: - left: - Field: title - right: - Constant: "%Guide%" + - Field: title + - Constant: "%Guide%" - Select: - title results: @@ -758,10 +742,8 @@ tests: - Where: condition: RegexContains: - left: - Field: title - right: - Constant: "(?i)(the|of)" + - Field: title + - Constant: "(?i)(the|of)" results: - title: The Hitchhiker's Guide to the Galaxy - title: One Hundred Years of Solitude @@ -774,10 +756,8 @@ tests: - Where: condition: RegexMatch: - left: - Field: title - right: - Constant: ".*(?i)(the|of).*" + - Field: title + - Constant: ".*(?i)(the|of).*" results: - title: The Hitchhiker's Guide to the Galaxy - title: One Hundred Years of Solitude @@ -891,9 +871,9 @@ tests: - Collection: books - Where: condition: - Not: - IsNaN: - Field: rating + - Not: + - IsNaN: + - Field: rating - Select: - ExprWithAlias: expr: @@ -905,9 +885,9 @@ tests: alias: ratingIsNull - ExprWithAlias: expr: - Not: - IsNaN: - Field: rating + - Not: + - IsNaN: + - Field: rating alias: ratingIsNotNaN - Limit: 1 results: @@ -965,9 +945,8 @@ tests: - ExprWithAlias: expr: MapGet: - map: - Field: awards - key: hugo + - Field: awards + - hugo alias: hugoAward - Field: title - Where: From c32893497501526115f660eab35455f2e12bfab0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 24 Jan 2025 17:06:29 -0800 Subject: [PATCH 013/131] add data loading code --- tests/system/test_pipeline_acceptance.py | 34 +++++++++++++++++++----- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 357fcec98..03e881605 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -1,8 +1,6 @@ from __future__ import annotations -import ast import sys import os -import black import pytest import yaml from typing import Any @@ -11,15 +9,43 @@ from google.cloud.firestore_v1 import pipeline_stages from google.cloud.firestore_v1 import pipeline_expressions +from google.cloud.firestore import Client + +FIRESTORE_TEST_DB = os.environ.get("SYSTEM_TESTS_DATABASE", "system-tests-named-db") 
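+# NOTE: illustrative invocation; both values are read from the environment:
+#   GCLOUD_PROJECT=my-project SYSTEM_TESTS_DATABASE=system-tests-named-db \
+#       pytest tests/system/test_pipeline_acceptance.py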
+FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT")
+
 test_dir_name = os.path.dirname(__file__)
 
+
 def loader():
     # load test cases
     with open(f"{test_dir_name}/pipeline_e2e.yaml") as f:
         test_cases = yaml.safe_load(f)
+    # load data
+    data = test_cases["data"]
+    client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_TEST_DB)
+    try:
+        # setup data
+        batch = client.batch()
+        for collection_name, documents in data.items():
+            collection_ref = client.collection(collection_name)
+            for document_id, document_data in documents.items():
+                document_ref = collection_ref.document(document_id)
+                batch.set(document_ref, document_data)
+        batch.commit()
+
+        # run tests
-    for test in test_cases["tests"]:
-        yield test
+        for test in test_cases["tests"]:
+            yield test
+    finally:
+        # clear data
+        for collection_name, documents in data.items():
+            collection_ref = client.collection(collection_name)
+            for document_id in documents:
+                document_ref = collection_ref.document(document_id)
+                document_ref.delete()
+
 
 def _apply_yaml_args(cls, yaml_args):
@@ -76,11 +102,7 @@ def parse_expressions(yaml_element: Any):
 @pytest.mark.parametrize(
     "test_dict", loader(), ids=lambda x: f"{x.get('description', '')}"
 )
-@pytest.mark.skipif(
-    sys.version_info < (3, 9), reason="ast.unparse requires python3.9 or higher"
-)
 def test_e2e_scenario(test_dict):
-
     pipeline = parse_pipeline(test_dict["pipeline"])
     # before_ast = ast.parse(test_dict["before"])

From de0a862fc9e604f3867c9711a6684143f840bbf6 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 26 Feb 2025 16:26:59 -0800
Subject: [PATCH 014/131] updated protos

---
 .../services/firestore/async_client.py        |  53 +++++++
 .../firestore_v1/services/firestore/client.py |  67 ++++++++
 .../services/firestore/transports/base.py     |   5 +
 .../services/firestore/transports/grpc.py     |  28 ++++
 .../firestore/transports/grpc_asyncio.py      |  28 ++++
 .../services/firestore/transports/rest.py     | 122 +++++++++++++++
 google/cloud/firestore_v1/types/__init__.py   |  16 ++
 google/cloud/firestore_v1/types/document.py   | 134 ++++++++++++++++
 .../cloud/firestore_v1/types/explain_stats.py |  53 +++++++
 google/cloud/firestore_v1/types/firestore.py  | 143 ++++++++++++++++++
 google/cloud/firestore_v1/types/pipeline.py   |  61 ++++++++
 google/cloud/firestore_v1/types/query.py      |   1 +
 12 files changed, 711 insertions(+)
 create mode 100644 google/cloud/firestore_v1/types/explain_stats.py
 create mode 100644 google/cloud/firestore_v1/types/pipeline.py

diff --git a/google/cloud/firestore_v1/services/firestore/async_client.py b/google/cloud/firestore_v1/services/firestore/async_client.py
index ec1d55e76..a11916f2d 100644
--- a/google/cloud/firestore_v1/services/firestore/async_client.py
+++ b/google/cloud/firestore_v1/services/firestore/async_client.py
@@ -51,6 +51,7 @@
 from google.cloud.firestore_v1.types import common
 from google.cloud.firestore_v1.types import document
 from google.cloud.firestore_v1.types import document as gf_document
+from google.cloud.firestore_v1.types import explain_stats
 from google.cloud.firestore_v1.types import firestore
 from google.cloud.firestore_v1.types import query
 from google.cloud.firestore_v1.types import query_profile
@@ -1182,6 +1183,58 @@ async def sample_run_query():
         # Done; return the response.
return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Awaitable[AsyncIterable[firestore.ExecutePipelineResponse]]: + r"""Executes a pipeline query. + + Args: + request (Optional[Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]]): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + AsyncIterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]: + The response for [Firestore.Execute][]. + """ + # Create or coerce a protobuf request object. + request = firestore.ExecutePipelineRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.execute_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("database", request.database),)), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + def run_aggregation_query( self, request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None, diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py index 888c88e80..c8974eaa4 100644 --- a/google/cloud/firestore_v1/services/firestore/client.py +++ b/google/cloud/firestore_v1/services/firestore/client.py @@ -55,6 +55,7 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile @@ -1565,6 +1566,72 @@ def sample_run_query(): # Done; return the response. return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Iterable[firestore.ExecutePipelineResponse]: + r"""Executes a pipeline query. + + Args: + request (Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + Iterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]: + The response for [Firestore.Execute][]. 
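+
+                Iterating the stream yields document batches; e.g., an
+                illustrative sketch (``request`` is a previously built
+                ``ExecutePipelineRequest``)::
+
+                    for response in client.execute_pipeline(request):
+                        for doc in response.results:
+                            ...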
+ """ + # Create or coerce a protobuf request object. + # Minor optimization to avoid making a copy if the user passes + # in a firestore.ExecutePipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, firestore.ExecutePipelineRequest): + request = firestore.ExecutePipelineRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.execute_pipeline] + + header_params = {} + + routing_param_regex = re.compile("^projects/(?P[^/]+)(?:/.*)?$") + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("project_id"): + header_params["project_id"] = regex_match.group("project_id") + + routing_param_regex = re.compile( + "^projects/[^/]+/databases/(?P[^/]+)(?:/.*)?$" + ) + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("database_id"): + header_params["database_id"] = regex_match.group("database_id") + + if header_params: + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(header_params), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + def run_aggregation_query( self, request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None, diff --git a/google/cloud/firestore_v1/services/firestore/transports/base.py b/google/cloud/firestore_v1/services/firestore/transports/base.py index d22e6ce3b..a5f1f52ed 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/base.py @@ -286,6 +286,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), + self.execute_pipeline: gapic_v1.method.wrap_method( + self.execute_pipeline, + default_timeout=None, + client_info=client_info, + ), self.run_aggregation_query: gapic_v1.method.wrap_method( self.run_aggregation_query, default_retry=retries.Retry( diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc.py b/google/cloud/firestore_v1/services/firestore/transports/grpc.py index 7d334a539..508fa93db 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc.py @@ -486,6 +486,34 @@ def run_query( ) return self._stubs["run_query"] + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse + ]: + r"""Return a callable for the execute pipeline method over gRPC. + + Executes a pipeline query. + + Returns: + Callable[[~.ExecutePipelineRequest], + ~.ExecutePipelineResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
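+        # ExecutePipeline is a server-streaming RPC, so the stub is built
+        # with unary_stream: one request goes out and a stream of
+        # ExecutePipelineResponse messages comes back. Like the other
+        # stubs on this transport, it is cached in self._stubs on first use.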
+ if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self.grpc_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py index c8eaab433..ae0dc1c04 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py @@ -498,6 +498,34 @@ def run_query( ) return self._stubs["run_query"] + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], Awaitable[firestore.ExecutePipelineResponse] + ]: + r"""Return a callable for the execute pipeline method over gRPC. + + Executes a pipeline query. + + Returns: + Callable[[~.ExecutePipelineRequest], + Awaitable[~.ExecutePipelineResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self.grpc_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest.py b/google/cloud/firestore_v1/services/firestore/transports/rest.py index c85f4f2ed..6364c4e3f 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest.py @@ -115,6 +115,14 @@ def pre_delete_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata + def pre_execute_pipeline(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_execute_pipeline(self, response): + logging.log(f"Received response: {response}") + return response + def pre_get_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -302,6 +310,29 @@ def pre_delete_document( """ return request, metadata + def pre_execute_pipeline( + self, + request: firestore.ExecutePipelineRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[firestore.ExecutePipelineRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for execute_pipeline + + Override in a subclass to manipulate the request or metadata + before they are sent to the Firestore server. + """ + return request, metadata + + def post_execute_pipeline( + self, response: rest_streaming.ResponseIterator + ) -> rest_streaming.ResponseIterator: + """Post-rpc interceptor for execute_pipeline + + Override in a subclass to manipulate the response + after it is returned by the Firestore server but before + it is returned to user code. 
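+
+        Note: for this streaming method, ``response`` is a
+        ``rest_streaming.ResponseIterator``; an override that wraps or
+        replaces it should preserve the iterator interface.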
+ """ + return response + def pre_get_document( self, request: firestore.GetDocumentRequest, metadata: Sequence[Tuple[str, str]] ) -> Tuple[firestore.GetDocumentRequest, Sequence[Tuple[str, str]]]: @@ -1114,6 +1145,8 @@ def __call__( query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" + class _ExecutePipeline(FirestoreRestStub): + def __hash__(self): # Send the request headers = dict(metadata) @@ -1197,6 +1230,77 @@ def __call__( query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" + class _ExecutePipeline(FirestoreRestStub): + def __hash__(self): + return hash("ExecutePipeline") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: firestore.ExecutePipelineRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> rest_streaming.ResponseIterator: + r"""Call the execute pipeline method over HTTP. + + Args: + request (~.firestore.ExecutePipelineRequest): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.firestore.ExecutePipelineResponse: + The response for [Firestore.Execute][]. + """ + + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline", + "body": "*", + }, + ] + request, metadata = self._interceptor.pre_execute_pipeline( + request, metadata + ) + pb_request = firestore.ExecutePipelineRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], + including_default_value_fields=False, + use_integers_for_enums=False, + ) + uri = transcoded_request["uri"] + method = transcoded_request["method"] + + # Jsonify the query params + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + query_params.update(self._get_unset_required_fields(query_params)) # Send the request headers = dict(metadata) @@ -1206,6 +1310,7 @@ def __call__( timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -1213,6 +1318,13 @@ def __call__( if response.status_code >= 400: raise core_exceptions.from_http_response(response) + # Return the response + resp = rest_streaming.ResponseIterator( + response, firestore.ExecutePipelineResponse + ) + resp = self._interceptor.post_execute_pipeline(resp) + return resp + class _GetDocument(FirestoreRestStub): def __hash__(self): return hash("GetDocument") @@ -2053,6 +2165,16 @@ def delete_document( # In C++ this would require a dynamic_cast return self._DeleteDocument(self._session, self._host, self._interceptor) # type: ignore + @property + def 
execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ExecutePipeline(self._session, self._host, self._interceptor) # type: ignore + @property def get_document( self, diff --git a/google/cloud/firestore_v1/types/__init__.py b/google/cloud/firestore_v1/types/__init__.py index 433c8a012..1e6b0c729 100644 --- a/google/cloud/firestore_v1/types/__init__.py +++ b/google/cloud/firestore_v1/types/__init__.py @@ -28,9 +28,14 @@ from .document import ( ArrayValue, Document, + Function, MapValue, + Pipeline, Value, ) +from .explain_stats import ( + ExplainStats, +) from .firestore import ( BatchGetDocumentsRequest, BatchGetDocumentsResponse, @@ -42,6 +47,8 @@ CommitResponse, CreateDocumentRequest, DeleteDocumentRequest, + ExecutePipelineRequest, + ExecutePipelineResponse, GetDocumentRequest, ListCollectionIdsRequest, ListCollectionIdsResponse, @@ -62,6 +69,9 @@ WriteRequest, WriteResponse, ) +from .pipeline import ( + StructuredPipeline, +) from .query import ( Cursor, StructuredAggregationQuery, @@ -92,8 +102,11 @@ "TransactionOptions", "ArrayValue", "Document", + "Function", "MapValue", + "Pipeline", "Value", + "ExplainStats", "BatchGetDocumentsRequest", "BatchGetDocumentsResponse", "BatchWriteRequest", @@ -104,6 +117,8 @@ "CommitResponse", "CreateDocumentRequest", "DeleteDocumentRequest", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "GetDocumentRequest", "ListCollectionIdsRequest", "ListCollectionIdsResponse", @@ -123,6 +138,7 @@ "UpdateDocumentRequest", "WriteRequest", "WriteResponse", + "StructuredPipeline", "Cursor", "StructuredAggregationQuery", "StructuredQuery", diff --git a/google/cloud/firestore_v1/types/document.py b/google/cloud/firestore_v1/types/document.py index 4def67f9a..432e043df 100644 --- a/google/cloud/firestore_v1/types/document.py +++ b/google/cloud/firestore_v1/types/document.py @@ -31,6 +31,8 @@ "Value", "ArrayValue", "MapValue", + "Function", + "Pipeline", }, ) @@ -246,6 +248,23 @@ class Value(proto.Message): oneof="value_type", message="MapValue", ) + field_reference_value: str = proto.Field( + proto.STRING, + number=19, + oneof="value_type", + ) + function_value: "Function" = proto.Field( + proto.MESSAGE, + number=20, + oneof="value_type", + message="Function", + ) + pipeline_value: "Pipeline" = proto.Field( + proto.MESSAGE, + number=21, + oneof="value_type", + message="Pipeline", + ) class ArrayValue(proto.Message): @@ -285,4 +304,119 @@ class MapValue(proto.Message): ) +class Function(proto.Message): + r"""Represents an unevaluated scalar expression. + + For example, the expression ``like(user_name, "%alice%")`` is + represented as: + + :: + + name: "like" + args { field_reference: "user_name" } + args { string_value: "%alice%" } + + Attributes: + name (str): + Required. The name of the function to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + function expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. 
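+
+            For example, a hypothetical ``distinct`` flag (the set of
+            supported option names is defined by the backend, not by
+            this client) would be encoded as:
+
+            ::
+
+                options { key: "distinct" value { boolean_value: true } }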
+ """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + +class Pipeline(proto.Message): + r"""A Firestore query represented as an ordered list of + operations / stages. + + Attributes: + stages (MutableSequence[google.cloud.firestore_v1.types.Pipeline.Stage]): + Required. Ordered list of stages to evaluate. + """ + + class Stage(proto.Message): + r"""A single operation within a pipeline. + + A stage is made up of a unique name, and a list of arguments. The + exact number of arguments & types is dependent on the stage type. + + To give an example, the stage ``filter(state = "MD")`` would be + encoded as: + + :: + + name: "filter" + args { + function_value { + name: "eq" + args { field_reference_value: "state" } + args { string_value: "MD" } + } + } + + See public documentation for the full list. + + Attributes: + name (str): + Required. The name of the stage to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + stage expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + stages: MutableSequence[Stage] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=Stage, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/explain_stats.py b/google/cloud/firestore_v1/types/explain_stats.py new file mode 100644 index 000000000..9d12e8c31 --- /dev/null +++ b/google/cloud/firestore_v1/types/explain_stats.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.protobuf import any_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "ExplainStats", + }, +) + + +class ExplainStats(proto.Message): + r"""Explain stats for an RPC request, includes both the optimized + plan and execution stats. + + Attributes: + data (google.protobuf.any_pb2.Any): + The format depends on the ``output_format`` options in the + request. + + The only option today is ``TEXT``, which is a + ``google.protobuf.StringValue``. 
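+
+            For example, a ``TEXT`` payload can be unpacked with the
+            protobuf ``Any`` API (an illustrative sketch)::
+
+                from google.protobuf import wrappers_pb2
+
+                text = wrappers_pb2.StringValue()
+                explain_stats.data.Unpack(text)
+                print(text.value)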
+ """ + + data: any_pb2.Any = proto.Field( + proto.MESSAGE, + number=1, + message=any_pb2.Any, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/firestore.py b/google/cloud/firestore_v1/types/firestore.py index 22388676f..c2d9f8475 100644 --- a/google/cloud/firestore_v1/types/firestore.py +++ b/google/cloud/firestore_v1/types/firestore.py @@ -22,6 +22,8 @@ from google.cloud.firestore_v1.types import aggregation_result from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats as gf_explain_stats +from google.cloud.firestore_v1.types import pipeline from google.cloud.firestore_v1.types import query as gf_query from google.cloud.firestore_v1.types import query_profile from google.cloud.firestore_v1.types import write @@ -48,6 +50,8 @@ "RollbackRequest", "RunQueryRequest", "RunQueryResponse", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "RunAggregationQueryRequest", "RunAggregationQueryResponse", "PartitionQueryRequest", @@ -835,6 +839,145 @@ class RunQueryResponse(proto.Message): ) +class ExecutePipelineRequest(proto.Message): + r"""The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + database (str): + Required. Database identifier, in the form + ``projects/{project}/databases/{database}``. + structured_pipeline (google.cloud.firestore_v1.types.StructuredPipeline): + A pipelined operation. + + This field is a member of `oneof`_ ``pipeline_type``. + transaction (bytes): + Run the query within an already active + transaction. + The value here is the opaque transaction ID to + execute the query in. + + This field is a member of `oneof`_ ``consistency_selector``. + new_transaction (google.cloud.firestore_v1.types.TransactionOptions): + Execute the pipeline in a new transaction. + The identifier of the newly created transaction + will be returned in the first response on the + stream. This defaults to a read-only + transaction. + + This field is a member of `oneof`_ ``consistency_selector``. + read_time (google.protobuf.timestamp_pb2.Timestamp): + Execute the pipeline in a snapshot + transaction at the given time. + This must be a microsecond precision timestamp + within the past one hour, or if Point-in-Time + Recovery is enabled, can additionally be a whole + minute timestamp within the past 7 days. + + This field is a member of `oneof`_ ``consistency_selector``. 
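+
+    For example, a request pinned to a point-in-time snapshot might be
+    built as follows (an illustrative sketch; ``my_pipeline`` and
+    ``my_read_time`` are placeholders)::
+
+        request = firestore.ExecutePipelineRequest(
+            database="projects/my-project/databases/(default)",
+            structured_pipeline=my_pipeline,
+            read_time=my_read_time,
+        )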
+ """ + + database: str = proto.Field( + proto.STRING, + number=1, + ) + structured_pipeline: pipeline.StructuredPipeline = proto.Field( + proto.MESSAGE, + number=2, + oneof="pipeline_type", + message=pipeline.StructuredPipeline, + ) + transaction: bytes = proto.Field( + proto.BYTES, + number=5, + oneof="consistency_selector", + ) + new_transaction: common.TransactionOptions = proto.Field( + proto.MESSAGE, + number=6, + oneof="consistency_selector", + message=common.TransactionOptions, + ) + read_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=7, + oneof="consistency_selector", + message=timestamp_pb2.Timestamp, + ) + + +class ExecutePipelineResponse(proto.Message): + r"""The response for [Firestore.Execute][]. + + Attributes: + transaction (bytes): + Newly created transaction identifier. + + This field is only specified as part of the first response + from the server, alongside the ``results`` field when the + original request specified + [ExecuteRequest.new_transaction][]. + results (MutableSequence[google.cloud.firestore_v1.types.Document]): + An ordered batch of results returned executing a pipeline. + + The batch size is variable, and can even be zero for when + only a partial progress message is returned. + + The fields present in the returned documents are only those + that were explicitly requested in the pipeline, this include + those like [``__name__``][google.firestore.v1.Document.name] + & + [``__update_time__``][google.firestore.v1.Document.update_time]. + This is explicitly a divergence from ``Firestore.RunQuery`` + / ``Firestore.GetDocument`` RPCs which always return such + fields even when they are not specified in the + [``mask``][google.firestore.v1.DocumentMask]. + execution_time (google.protobuf.timestamp_pb2.Timestamp): + The time at which the document(s) were read. + + This may be monotonically increasing; in this case, the + previous documents in the result stream are guaranteed not + to have changed between their ``execution_time`` and this + one. + + If the query returns no results, a response with + ``execution_time`` and no ``results`` will be sent, and this + represents the time at which the operation was run. + explain_stats (google.cloud.firestore_v1.types.ExplainStats): + Query explain stats. + Contains all metadata related to pipeline + planning and execution, specific contents depend + on the supplied pipeline options. + """ + + transaction: bytes = proto.Field( + proto.BYTES, + number=1, + ) + results: MutableSequence[gf_document.Document] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=gf_document.Document, + ) + execution_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=3, + message=timestamp_pb2.Timestamp, + ) + explain_stats: gf_explain_stats.ExplainStats = proto.Field( + proto.MESSAGE, + number=4, + message=gf_explain_stats.ExplainStats, + ) + + class RunAggregationQueryRequest(proto.Message): r"""The request for [Firestore.RunAggregationQuery][google.firestore.v1.Firestore.RunAggregationQuery]. diff --git a/google/cloud/firestore_v1/types/pipeline.py b/google/cloud/firestore_v1/types/pipeline.py new file mode 100644 index 000000000..0aed187cf --- /dev/null +++ b/google/cloud/firestore_v1/types/pipeline.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.cloud.firestore_v1.types import document + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "StructuredPipeline", + }, +) + + +class StructuredPipeline(proto.Message): + r"""A Firestore query represented as an ordered list of operations / + stages. + + This is considered the top-level function which plans & executes a + query. It is logically equivalent to ``query(stages, options)``, but + prevents the client from having to build a function wrapper. + + Attributes: + pipeline (google.cloud.firestore_v1.types.Pipeline): + Required. The pipeline query to execute. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional query-level arguments. + """ + + pipeline: document.Pipeline = proto.Field( + proto.MESSAGE, + number=1, + message=document.Pipeline, + ) + options: MutableMapping[str, document.Value] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=2, + message=document.Value, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/query.py b/google/cloud/firestore_v1/types/query.py index 2fda44ebe..3e53208aa 100644 --- a/google/cloud/firestore_v1/types/query.py +++ b/google/cloud/firestore_v1/types/query.py @@ -44,6 +44,7 @@ class StructuredQuery(proto.Message): 4. order_by + start_at + end_at 5. offset 6. limit + 7. 
find_nearest Attributes: select (google.cloud.firestore_v1.types.StructuredQuery.Projection): From 45c83aa8a3b3f8e6c1a5819b31362963ec5195bd Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 26 Feb 2025 16:27:07 -0800 Subject: [PATCH 015/131] added pyyaml to system test dependencies --- noxfile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/noxfile.py b/noxfile.py index 41f545a68..1a5625a30 100644 --- a/noxfile.py +++ b/noxfile.py @@ -62,6 +62,7 @@ SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ "pytest-asyncio==0.21.2", "six", + "pyyaml", ] SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = [] SYSTEM_TEST_DEPENDENCIES: List[str] = [] From fc57f04f75bb63a41062902332c624a60e2f852e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 27 Jan 2025 16:30:03 -0800 Subject: [PATCH 016/131] added Expr methods --- .../firestore_v1/pipeline_expressions.py | 200 +++++++++++++++++- google/cloud/firestore_v1/pipeline_stages.py | 4 +- 2 files changed, 201 insertions(+), 3 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index f5f6824fb..49c534418 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable, List, Mapping +from typing import Any, Iterable, List, Mapping, Union class Ordering: @@ -7,6 +7,14 @@ def __init__(self, expr, order_dir): self.expr = expr self.order_dir = order_dir + @staticmethod + def ascending(expr): + return Ordering(expr, "asc") + + @staticmethod + def descending(expr): + return Ordering(expr, "desc") + class Expr: """Represents an expression that can be evaluated to a value within the execution of a pipeline. @@ -16,11 +24,197 @@ def __repr__(self): items = ("%s = %r" % (k, v) for k, v in self.__dict__.items()) return "<%s: {%s}>" % (self.__class__.__name__, ', '.join(items)) + @staticmethod + def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": + return o if isinstance(o, Expr) else Constant(o) + + def add(self, other: Any) -> "Add": + return Add(self, self._cast_to_expr_or_convert_to_constant(other)) + + def subtract(self, other: Any) -> "Subtract": + return Subtract(self, self._cast_to_expr_or_convert_to_constant(other)) + + def multiply(self, other: Any) -> "Multiply": + return Multiply(self, self._cast_to_expr_or_convert_to_constant(other)) + + def divide(self, other: Any) -> "Divide": + return Divide(self, self._cast_to_expr_or_convert_to_constant(other)) + + def mod(self, other: Any) -> "Mod": + return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) + + def logical_max(self, other: Any) -> "LogicalMax": + return LogicalMax(self, self._cast_to_expr_or_convert_to_constant(other)) + + def logical_min(self, other: Any) -> "LogicalMin": + return LogicalMin(self, self._cast_to_expr_or_convert_to_constant(other)) + + def eq(self, other: Any) -> "Eq": + return Eq(self, self._cast_to_expr_or_convert_to_constant(other)) + + def neq(self, other: Any) -> "Neq": + return Neq(self, self._cast_to_expr_or_convert_to_constant(other)) + + def gt(self, other: Any) -> "Gt": + return Gt(self, self._cast_to_expr_or_convert_to_constant(other)) + + def gte(self, other: Any) -> "Gte": + return Gte(self, self._cast_to_expr_or_convert_to_constant(other)) + + def lt(self, other: Any) -> "Lt": + return Lt(self, self._cast_to_expr_or_convert_to_constant(other)) + + def lte(self, other: Any) -> "Lte": + return Lte(self, self._cast_to_expr_or_convert_to_constant(other)) + + def in_(self, 
*others: Any) -> "In": + return In(self, ListOfExprs([self._cast_to_expr_or_convert_to_constant(o) for o in others])) + + def not_in(self, *others: Any) -> "Not": + return Not(self.in_(*others)) + + def array_concat(self, array: List[Any]) -> "ArrayConcat": + return ArrayConcat(self, ListOfExprs([self._cast_to_expr_or_convert_to_constant(o) for o in array])) + + def array_contains(self, element: Any) -> "ArrayContains": + return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) + + def array_contains_all(self, elements: List[Any]) -> "ArrayContainsAll": + return ArrayContainsAll(self, ListOfExprs([self._cast_to_expr_or_convert_to_constant(e) for e in elements])) + + def array_contains_any(self, elements: List[Any]) -> "ArrayContainsAny": + return ArrayContainsAny(self, ListOfExprs([self._cast_to_expr_or_convert_to_constant(e) for e in elements])) + + def array_length(self) -> "ArrayLength": + return ArrayLength(self) + + def array_reverse(self) -> "ArrayReverse": + return ArrayReverse(self) + + def is_nan(self) -> "IsNaN": + return IsNaN(self) + + def exists(self) -> "Exists": + return Exists(self) + + def sum(self) -> "Sum": + return Sum(self, False) + + def avg(self) -> "Avg": + return Avg(self, False) + + def count(self) -> "Count": + return Count(self) + + def min(self) -> "Min": + return Min(self, False) + + def max(self) -> "Max": + return Max(self, False) + + def char_length(self) -> "CharLength": + return CharLength(self) + + def byte_length(self) -> "ByteLength": + return ByteLength(self) + + def like(self, pattern: Any) -> "Like": + return Like(self, self._cast_to_expr_or_convert_to_constant(pattern)) + + def regex_contains(self, regex: Any) -> "RegexContains": + return RegexContains(self, self._cast_to_expr_or_convert_to_constant(regex)) + + def regex_matches(self, regex: Any) -> "RegexMatch": + return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) + + def str_contains(self, substring: Any) -> "StrContains": + return StrContains(self, self._cast_to_expr_or_convert_to_constant(substring)) + + def starts_with(self, prefix: Any) -> "StartsWith": + return StartsWith(self, self._cast_to_expr_or_convert_to_constant(prefix)) + + def ends_with(self, postfix: Any) -> "EndsWith": + return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) + + def str_concat(self, *elements: Any) -> "StrConcat": + return StrConcat(*[self._cast_to_expr_or_convert_to_constant(el) for el in elements]) + + def to_lower(self) -> "ToLower": + return ToLower(self) + + def to_upper(self) -> "ToUpper": + return ToUpper(self) + + def trim(self) -> "Trim": + return Trim(self) + + def reverse(self) -> "Reverse": + return Reverse(self) + + def replace_first(self, find: Any, replace: Any) -> "ReplaceFirst": + return ReplaceFirst(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) + + def replace_all(self, find: Any, replace: Any) -> "ReplaceAll": + return ReplaceAll(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) + + def map_get(self, key: str) -> "MapGet": + return MapGet(self, key) + + def cosine_distance(self, other: Any) -> "CosineDistance": + return CosineDistance(self, self._cast_to_expr_or_convert_to_constant(other)) + + def euclidean_distance(self, other: Any) -> "EuclideanDistance": + return EuclideanDistance(self, self._cast_to_expr_or_convert_to_constant(other)) + + def dot_product(self, other: Any) -> "DotProduct": + return 
DotProduct(self, self._cast_to_expr_or_convert_to_constant(other)) + + def vector_length(self) -> "VectorLength": + return VectorLength(self) + + def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": + return TimestampToUnixMicros(self) + + def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": + return UnixMicrosToTimestamp(self) + + def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": + return TimestampToUnixMillis(self) + + def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": + return UnixMillisToTimestamp(self) + + def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": + return TimestampToUnixSeconds(self) + + def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": + return UnixSecondsToTimestamp(self) + + def timestamp_add(self, unit: Any, amount: Any) -> "TimestampAdd": + return TimestampAdd(self, self._cast_to_expr_or_convert_to_constant(unit), self._cast_to_expr_or_convert_to_constant(amount)) + + def timestamp_sub(self, unit: Any, amount: Any) -> "TimestampSub": + return TimestampSub(self, self._cast_to_expr_or_convert_to_constant(unit), self._cast_to_expr_or_convert_to_constant(amount)) + + def ascending(self) -> Ordering: + return Ordering.ascending(self) + + def descending(self) -> Ordering: + return Ordering.descending(self) + + def as_(self, alias: str) -> "ExprWithAlias": + return ExprWithAlias(self, alias) + class Constant(Expr): def __init__(self, value: Any): self.value = value + @staticmethod + def of(value): + return Constant(value) + + class ListOfExprs(Expr): def __init__(self, exprs: List[Expr]): @@ -286,6 +480,10 @@ class Field(Expr, Selectable): def __init__(self, path: str): self.path = path + @staticmethod + def of(path: str) + return Field(path) + def _to_map(self): return self.path, self diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index be24a7081..374b4c4aa 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -50,8 +50,8 @@ def __init__(self, *fields: Selectable): class Aggregate(Stage): def __init__( self, - groups={}, - accumulators={} + groups:dict[str, Expr] | None = None, + accumulators:dict[str, Accumulator] | None = None, ): super().__init__() self.groups = groups or {} From 0224f940688f21bf1b3547f3af8ebd34f361fbdd Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 27 Jan 2025 17:00:40 -0800 Subject: [PATCH 017/131] trying to improve accumulator api --- google/cloud/firestore_v1/pipeline_stages.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 374b4c4aa..4a71db5a1 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -50,8 +50,8 @@ def __init__(self, *fields: Selectable): class Aggregate(Stage): def __init__( self, - groups:dict[str, Expr] | None = None, - accumulators:dict[str, Accumulator] | None = None, + *accumulators: AccumulatorTarget, + groups: list[str | Selectable] | None = None, ): super().__init__() self.groups = groups or {} From cdee9376842a0a8bf8936d473a9793086859a13d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 11:12:59 -0800 Subject: [PATCH 018/131] improved aggregate/accumulators --- .../firestore_v1/pipeline_expressions.py | 19 +++-------- google/cloud/firestore_v1/pipeline_stages.py | 16 ++++++--- tests/system/pipeline_e2e.yaml | 34 +++++++++---------- 3 files 
changed, 32 insertions(+), 37 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 49c534418..0b2d76439 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable, List, Mapping, Union +from typing import Any, Iterable, List, Mapping, Union, Generic, TypeVar class Ordering: @@ -455,18 +455,9 @@ def _to_map(self): raise NotImplementedError -class AccumulatorTarget(Selectable): - def __init__(self, accumulator: Accumulator, field_name: str, distinct: bool=False): - self.accumulator = accumulator - self.field_name = field_name - self.distinct = distinct - - def _to_map(self): - return self.field_name, self.accumulator - - -class ExprWithAlias(Expr, Selectable): - def __init__(self, expr: Expr, alias: str): +T = TypeVar('T', bound=Expr) +class ExprWithAlias(Expr, Selectable, Generic[T]): + def __init__(self, expr: T, alias: str): self.expr = expr self.alias = alias @@ -481,7 +472,7 @@ def __init__(self, path: str): self.path = path @staticmethod - def of(path: str) + def of(path: str): return Field(path) def _to_map(self): diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 4a71db5a1..86715356b 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Sequence from enum import Enum from google.cloud.firestore_v1.pipeline_expressions import ( @@ -50,12 +50,18 @@ def __init__(self, *fields: Selectable): class Aggregate(Stage): def __init__( self, - *accumulators: AccumulatorTarget, - groups: list[str | Selectable] | None = None, + *extra_accumulators: ExprWithAlias[Accumulator], + accumulators: Sequence[ExprWithAlias[Accumulator]] = (), + groups: Sequence[str | Selectable] = (), ): super().__init__() - self.groups = groups or {} - self.accumulators = accumulators or {} + self.groups: dict[str, Expr] = dict( + f._to_map() + if isinstance(f, Selectable) + else (f,Field(f)) + for f in groups + ) + self.accumulators: dict[str, Expr] = dict(f._to_map() for f in [*accumulators, *extra_accumulators]) class Collection(Stage): diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index eca247be9..adfb25525 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -130,10 +130,9 @@ tests: pipeline: - Collection: books - Aggregate: - accumulators: - - ExprWithAlias: - expr: CountAll - alias: count + - ExprWithAlias: + expr: CountAll + alias: count results: - count: 10 - description: "testAggregates - avg, count, max" @@ -147,20 +146,19 @@ tests: right: Constant: Science Fiction - Aggregate: - accumulators: - - ExprWithAlias: - expr: CountAll - alias: count - - ExprWithAlias: - expr: - Avg: - value: rating - alias: avg_rating - - ExprWithAlias: - expr: - Max: - value: rating - alias: max_rating + - ExprWithAlias: + expr: CountAll + alias: count + - ExprWithAlias: + expr: + Avg: + value: rating + alias: avg_rating + - ExprWithAlias: + expr: + Max: + value: rating + alias: max_rating results: - count: 2 avg_rating: 4.4 From 5d35c540c68c41cb8ba3c60c8f77f0fd47c4a247 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 13:18:36 -0800 Subject: [PATCH 019/131] fixed naming in yaml --- 
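
A sketch of the fluent expression API these YAML names map onto, assuming
the Field/Ordering helpers as written in PATCH 016 (names may still shift
in later patches):

    from google.cloud.firestore_v1.pipeline_expressions import Field

    # a YAML "Ordering: [Field: author, ASCENDING]" entry corresponds to:
    order = Field.of("author").ascending()
    # and an Eq/Gt condition under "Where:" corresponds to:
    condition = Field.of("rating").gt(4.5)
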
tests/system/pipeline_e2e.yaml | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index adfb25525..774ae14f1 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -131,7 +131,7 @@ tests: - Collection: books - Aggregate: - ExprWithAlias: - expr: CountAll + expr: Count alias: count results: - count: 10 @@ -147,7 +147,7 @@ tests: Constant: Science Fiction - Aggregate: - ExprWithAlias: - expr: CountAll + expr: Count alias: count - ExprWithAlias: expr: @@ -236,7 +236,7 @@ tests: - Aggregate: accumulators: - ExprWithAlias: - expr: CountAll + expr: Count alias: count - ExprWithAlias: expr: @@ -259,7 +259,7 @@ tests: - title - author - Sort: - - OrderBy: + - Ordering: - Field: author - ASCENDING results: @@ -311,7 +311,7 @@ tests: - Field: genre - Field: nestedField # Field does not exist, should be ignored - Sort: - - OrderBy: + - Ordering: - Field: author_title - ASCENDING results: @@ -390,7 +390,7 @@ tests: pipeline: - Collection: books - Sort: - - OrderBy: + - Ordering: - Field: author - ASCENDING - Offset: 5 @@ -529,7 +529,7 @@ tests: - Select: - title - Sort: - - OrderBy: + - Ordering: - Field: title - ASCENDING results: @@ -548,7 +548,7 @@ tests: - Select: - title - Sort: - - OrderBy: + - Ordering: - Field: title - DESCENDING results: @@ -821,10 +821,9 @@ tests: - rating - title - Sort: - - OrderBy: - expr: - Field: title - direction: ASCENDING + - Ordering: + - Field: title + - direction: ASCENDING results: - rating: 4.3 title: Crime and Punishment From 11373e33b33af5bc8473c1f527bbb57ddcaa551b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 13:19:08 -0800 Subject: [PATCH 020/131] create objects in parsing code --- tests/system/test_pipeline_acceptance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 03e881605..231cabb2d 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -87,14 +87,14 @@ def parse_expressions(yaml_element: Any): if len(yaml_element) == 1 and isinstance(list(yaml_element)[0], str) and hasattr(pipeline_expressions, list(yaml_element)[0]): # build pipeline expressions if possible cls_str = list(yaml_element)[0] - cls = parse_expressions(cls_str) + cls = getattr(pipeline_expressions, cls_str) yaml_args = yaml_element[cls_str] return _apply_yaml_args(cls, yaml_args) else: # otherwise, return dict return {parse_expressions(k): parse_expressions(v) for k,v in yaml_element.items()} elif isinstance(yaml_element, str) and hasattr(pipeline_expressions, yaml_element): - return getattr(pipeline_expressions, yaml_element) + return getattr(pipeline_expressions, yaml_element)() else: return yaml_element From d2b4153e59bc69e81d80968d4494083f580a5597 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 13:25:47 -0800 Subject: [PATCH 021/131] use order enum --- .../cloud/firestore_v1/pipeline_expressions.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 0b2d76439..985dfd6e6 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1,19 +1,18 @@ from typing import Any, Iterable, List, Mapping, Union, Generic, TypeVar +from enum import Enum +from enum import 
auto +class OrderingDirection(Enum): + ASCENDING = auto() + DESCENDING = auto() + class Ordering: - def __init__(self, expr, order_dir): + def __init__(self, expr, order_dir: OrderingDirection | str): self.expr = expr - self.order_dir = order_dir - - @staticmethod - def ascending(expr): - return Ordering(expr, "asc") + self.order_dir = OrderingDirection[order_dir] if isinstance(order_dir, str) else order_dir - @staticmethod - def descending(expr): - return Ordering(expr, "desc") class Expr: """Represents an expression that can be evaluated to a value within the From bfdfba3961323addfda9d740a658201a5a86a777 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 13:28:02 -0800 Subject: [PATCH 022/131] standardize how I deal with map stages --- google/cloud/firestore_v1/pipeline_stages.py | 64 +++++++++++--------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 86715356b..c72ca5525 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -44,7 +44,10 @@ def __init__(self, custom_name: Optional[str] = None): class AddFields(Stage): def __init__(self, *fields: Selectable): super().__init__("add_fields") - self.fields = dict(f._to_map() for f in fields) + self.fields = list(fields) + + def _fields_map(self) -> dict[str, Expr]: + return dict(f._to_map() for f in self.fields) class Aggregate(Stage): @@ -55,13 +58,16 @@ def __init__( groups: Sequence[str | Selectable] = (), ): super().__init__() - self.groups: dict[str, Expr] = dict( - f._to_map() - if isinstance(f, Selectable) - else (f,Field(f)) - for f in groups - ) - self.accumulators: dict[str, Expr] = dict(f._to_map() for f in [*accumulators, *extra_accumulators]) + self.groups: list[Selectable] = [Field(f) if isinstance(f, str) else f for f in groups] + self.accumulators: list[ExprWithAlias[Accumulator]] = [*accumulators, *extra_accumulators] + + @property + def _group_map(self) -> dict[str, Expr]: + return dict(f._to_map() for f in self.groups) + + @property + def _accumulators_map(self) -> dict[str, Expr]: + return dict(f._to_map() for f in self.accumulators) class Collection(Stage): @@ -86,18 +92,22 @@ def __init__(self): class Distinct(Stage): def __init__(self, *fields: str | Selectable): super().__init__() - self.fields = dict( + self.fields: list[Selectable] = [Field(f) if isinstance(f, str) else f for f in fields] + + @property + def _fields_dict(self) -> dict[str, Selectable]: + return dict( f._to_map() if isinstance(f, Selectable) else (f,Field(f)) - for f in fields + for f in self.fields ) class Documents(Stage): - def __init__(self, documents: List[str]): + def __init__(self, *documents: str): super().__init__() - self.documents = documents + self.documents = list(documents) @staticmethod def of(*documents: "DocumentReference") -> "Documents": @@ -121,9 +131,9 @@ def __init__( class GenericStage(Stage): - def __init__(self, name: str, params: List[Any]): + def __init__(self, name: str, *params: Any): super().__init__(name) - self.params = params + self.params = list(params) class Limit(Stage): @@ -141,12 +151,11 @@ def __init__(self, offset: int): class RemoveFields(Stage): def __init__(self, *fields: str | Field): super().__init__("remove_fields") - self.fields = dict( - f._to_map() - if isinstance(f, Selectable) - else (f,Field(f)) - for f in fields - ) + self.fields = [Field(f) if isinstance(f, str) else f for f in fields] + + @property + def 
_fields_map(self) -> dict[str, Field]:
+        return dict(f._to_map() for f in self.fields)
 
 
 class Replace(Stage):
@@ -170,20 +179,19 @@ def __init__(self, options: "SampleOptions"):
 class Select(Stage):
     def __init__(self, *fields: str | Selectable):
         super().__init__()
-        self.projections = dict(
-            f._to_map()
-            if isinstance(f, Selectable)
-            else (f,Field(f))
-            for f in fields
-        )
+        self.projections = [Field(f) if isinstance(f, str) else f for f in fields]
+
+    @property
+    def _projections_map(self) -> dict[str, Expr]:
+        return dict(f._to_map() for f in self.projections)
 
 
 class Sort(Stage):
-    def __init__(self, orders: List["Ordering"]):
+    def __init__(self, *orders: "Ordering"):
         super().__init__()
-        self.orders = orders
+        self.orders = list(orders)
 
 
 class Union(Stage):
From afa823a6c34cb083a87083bd6ff873150382cadc Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 28 Jan 2025 13:28:32 -0800
Subject: [PATCH 023/131] fix broken super().__init__ calls

---
 google/cloud/firestore_v1/pipeline_expressions.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index 985dfd6e6..1d9e3d14d 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -439,12 +439,12 @@ def __init__(self, value: Expr, distinct: bool=False):
 
 class Count(Accumulator):
     def __init__(self, value: Expr = None):
-        super(Function, self).__init__("count", [value] if value else [])
+        super().__init__("count", [value] if value else [])
 
 
 class CountIf(Function):
     def __init__(self, value: Expr, distinct: bool=False):
-        super(Function, self).__init__("countif", [value] if value else [])
+        super().__init__("countif", [value] if value else [])
 
 
 class Selectable:
From 33481276ef1644ac6f9f6730f04d68cd0dd70f9c Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 28 Jan 2025 13:28:51 -0800
Subject: [PATCH 024/131] added repr for custom classes

---
 .../firestore_v1/pipeline_expressions.py     | 20 +++++++++++++++++--
 google/cloud/firestore_v1/pipeline_stages.py |  14 +++++++++++++
 tests/system/test_pipeline_acceptance.py     |  1 +
 3 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index 1d9e3d14d..cbce37dc8 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -13,6 +13,12 @@ def __init__(self, expr, order_dir: OrderingDirection | str):
         self.expr = expr
         self.order_dir = OrderingDirection[order_dir] if isinstance(order_dir, str) else order_dir
 
+    def __repr__(self):
+        if self.order_dir is OrderingDirection.ASCENDING:
+            order_str = ".ascending()"
+        else:
+            order_str = ".descending()"
+        return f"{self.expr!r}{order_str}"
 
 class Expr:
     """Represents an expression that can be evaluated to a value within the
@@ -20,8 +26,7 @@ class Expr:
     """
 
     def __repr__(self):
-        items = ("%s = %r" % (k, v) for k, v in self.__dict__.items())
-        return "<%s: {%s}>" % (self.__class__.__name__, ', '.join(items))
+        return f"{self.__class__.__name__}()"
 
     @staticmethod
     def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr":
@@ -213,6 +218,9 @@ def __init__(self, value: Any):
     def of(value):
         return Constant(value)
 
+    def __repr__(self):
+        return f"Constant.of({self.value!r})"
+
 
 class ListOfExprs(Expr):
@@ -227,6 +235,8 @@ def __init__(self, name: str, params: List[Expr]):
         self.name = name
         self.params = params
 
+    def __repr__(self):
+        
return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" class Divide(Function): def __init__(self, left: Expr, right: Expr): @@ -463,6 +473,9 @@ def __init__(self, expr: T, alias: str): def _to_map(self): return self.alias, self.expr + def __repr__(self): + return f"{self.expr}.as('{self.alias}')" + class Field(Expr, Selectable): DOCUMENT_ID = "__name__" @@ -477,6 +490,9 @@ def of(path: str): def _to_map(self): return self.path, self + def __repr__(self): + return f"Field.of({self.path!r})" + class FilterCondition(Function): """Filters the given data in some way.""" diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index c72ca5525..77639096f 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -40,6 +40,10 @@ class Stage: def __init__(self, custom_name: Optional[str] = None): self.name = custom_name or type(self).__name__.lower() + def __repr__(self): + items = ("%s=%r" % (k, v) for k, v in self.__dict__.items() if k != "name") + return f"{self.__class__.__name__}({', '.join(items)})" + class AddFields(Stage): def __init__(self, *fields: Selectable): @@ -70,6 +74,16 @@ def _accumulators_map(self) -> dict[str, Expr]: return dict(f._to_map() for f in self.accumulators) + def __repr__(self): + accumulator_str = ', '.join(repr(v) for v in self.accumulators) + group_str = "" + if self.groups: + if self.accumulators: + group_str = ", " + group_str += f"groups={self.groups}" + return f"{self.__class__.__name__}({accumulator_str}{group_str})" + + class Collection(Stage): def __init__(self, path: str): super().__init__() diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 231cabb2d..ad96019ab 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -104,6 +104,7 @@ def parse_expressions(yaml_element: Any): ) def test_e2e_scenario(test_dict): pipeline = parse_pipeline(test_dict["pipeline"]) + print(pipeline) # before_ast = ast.parse(test_dict["before"]) # got_ast = before_ast From e3c995dfeab20bcc75d321833df4ce739b7950c4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 13:40:14 -0800 Subject: [PATCH 025/131] added repr to pipeline --- google/cloud/firestore_v1/pipeline.py | 13 +++++++++++-- tests/system/test_pipeline_acceptance.py | 4 ++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index ec8c5c9bb..d3b959da8 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -13,8 +13,17 @@ class Pipeline: - def __init__(self): - self.stages = [] + def __init__(self, *stages: stages.Stage): + self.stages = list(stages) + + def __repr__(self): + if not self.stages: + return "Pipeline()" + elif len(self.stages) == 1: + return f"Pipeline({self.stages[0]!r})" + else: + stages_str = ",\n ".join([repr(s) for s in self.stages]) + return f"Pipeline(\n {stages_str}\n)" def add_fields(self, fields: Dict[str, Expr]) -> Pipeline: self.stages.append(stages.AddFields(fields)) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index ad96019ab..ff397087d 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -8,6 +8,7 @@ # from google.cloud.firestore_v1.pipeline_stages import * from google.cloud.firestore_v1 import pipeline_stages from 
google.cloud.firestore_v1 import pipeline_expressions +from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore import Client @@ -17,7 +18,6 @@ test_dir_name = os.path.dirname(__file__) - def loader(): # load test cases with open(f"{test_dir_name}/pipeline_e2e.yaml") as f: @@ -77,7 +77,7 @@ def parse_pipeline(pipeline: list[dict[str, Any], str]): # yaml has no arguments stage_obj = stage_cls() result_list.append(stage_obj) - return result_list + return Pipeline(*result_list) def parse_expressions(yaml_element: Any): From 2d604e156b97c5dbdb1789c07acd14b6a47182d4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 13:47:29 -0800 Subject: [PATCH 026/131] fixed vector formatting --- tests/system/pipeline_e2e.yaml | 36 +++++++++++++++++----------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 774ae14f1..0f9dcc73f 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -965,32 +965,32 @@ tests: - ExprWithAlias: expr: CosineDistance: - - Vector: - - Constant: 0.1 - - Constant: 0.1 - - Vector: - - Constant: 0.5 - - Constant: 0.8 + vector1: + - 0.1 + - 0.1 + vector2: + - 0.5 + - 0.8 alias: cosineDistance - ExprWithAlias: expr: DotProduct: - - Vector: - - Constant: 0.1 - - Constant: 0.1 - - Vector: - - Constant: 0.5 - - Constant: 0.8 + vector1: + - 0.1 + - 0.1 + vector2: + - 0.5 + - 0.8 alias: dotProductDistance - ExprWithAlias: expr: EuclideanDistance: - - Vector: - - Constant: 0.1 - - Constant: 0.1 - - Vector: - - Constant: 0.5 - - Constant: 0.8 + vector1: + - 0.1 + - 0.1 + vector2: + - 0.5 + - 0.8 alias: euclideanDistance - Limit: 1 results: From 3bed8074fe17e36556e2b2072ac2de3602a9bdb8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 14:37:11 -0800 Subject: [PATCH 027/131] only treat capitalized strings as possible exprs --- tests/system/test_pipeline_acceptance.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index ff397087d..6ea6a1884 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -79,12 +79,16 @@ def parse_pipeline(pipeline: list[dict[str, Any], str]): result_list.append(stage_obj) return Pipeline(*result_list) +def _is_expr_string(yaml_str): + return isinstance(yaml_str, str) and \ + yaml_str[0].isupper() and \ + hasattr(pipeline_expressions, yaml_str) def parse_expressions(yaml_element: Any): if isinstance(yaml_element, list): return [parse_expressions(v) for v in yaml_element] elif isinstance(yaml_element, dict): - if len(yaml_element) == 1 and isinstance(list(yaml_element)[0], str) and hasattr(pipeline_expressions, list(yaml_element)[0]): + if len(yaml_element) == 1 and _is_expr_string(list(yaml_element)[0]): # build pipeline expressions if possible cls_str = list(yaml_element)[0] cls = getattr(pipeline_expressions, cls_str) @@ -93,7 +97,7 @@ def parse_expressions(yaml_element: Any): else: # otherwise, return dict return {parse_expressions(k): parse_expressions(v) for k,v in yaml_element.items()} - elif isinstance(yaml_element, str) and hasattr(pipeline_expressions, yaml_element): + elif _is_expr_string(yaml_element): return getattr(pipeline_expressions, yaml_element)() else: return yaml_element From 9b9eaa87b493f87f95c3af16a047237607b2edbf Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 14:54:35 -0800 Subject: [PATCH 
028/131] Where uses positional args in yaml --- tests/system/pipeline_e2e.yaml | 88 +++++++++++----------------------- 1 file changed, 29 insertions(+), 59 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 0f9dcc73f..ed6b4d0f3 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -139,8 +139,7 @@ tests: pipeline: - Collection: books - Where: - condition: - Eq: + - Eq: left: Field: genre right: @@ -167,8 +166,7 @@ tests: pipeline: - Collection: books - Where: - condition: - Lt: + - Lt: left: Field: published right: @@ -181,8 +179,7 @@ tests: pipeline: - Collection: books - Where: - condition: - Lt: + - Lt: left: Field: published right: @@ -201,8 +198,7 @@ tests: pipeline: - Collection: books - Where: - condition: - Lt: + - Lt: left: Field: published right: @@ -217,8 +213,7 @@ tests: groups: - genre - Where: - condition: - Gt: + - Gt: left: Field: avg_rating right: @@ -339,8 +334,7 @@ tests: pipeline: - Collection: books - Where: - condition: - And: + - And: - Gt: left: Field: rating @@ -368,8 +362,7 @@ tests: pipeline: - Collection: books - Where: - condition: - Or: + - Or: - Eq: left: Field: genre @@ -409,8 +402,7 @@ tests: pipeline: - Collection: books - Where: - condition: - ArrayContains: + - ArrayContains: - tags - comedy results: @@ -430,8 +422,7 @@ tests: pipeline: - Collection: books - Where: - condition: - ArrayContainsAny: + - ArrayContainsAny: array: tags elements: - comedy @@ -445,8 +436,7 @@ tests: pipeline: - Collection: books - Where: - condition: - ArrayContainsAll: + - ArrayContainsAll: array: Field: tags elements: @@ -466,8 +456,7 @@ tests: - Field: tags alias: tagsCount - Where: - condition: - Eq: + - Eq: left: Field: tagsCount right: @@ -522,8 +511,7 @@ tests: pipeline: - Collection: books - Where: - condition: - StartsWith: + - StartsWith: - Field: title - Constant: the - Select: @@ -541,8 +529,7 @@ tests: pipeline: - Collection: books - Where: - condition: - EndsWith: + - EndsWith: - Field: title - Constant: y - Select: @@ -565,8 +552,7 @@ tests: alias: titleLength - title - Where: - condition: - Gt: + - Gt: left: Field: titleLength right: @@ -586,8 +572,7 @@ tests: - Field: title alias: reversed_title - Where: - condition: - Eq: + - Eq: left: Field: author right: @@ -607,8 +592,7 @@ tests: replacement: A alias: replaced_title - Where: - condition: - Eq: + - Eq: left: Field: author right: @@ -628,8 +612,7 @@ tests: replacement: _ alias: replaced_title - Where: - condition: - Eq: + - Eq: left: Field: author right: @@ -646,8 +629,7 @@ tests: - Field: title alias: title_length - Where: - condition: - Eq: + - Eq: left: Field: author right: @@ -666,8 +648,7 @@ tests: - Constant: _银河系漫游指南 alias: title_byte_length - Where: - condition: - Eq: + - Eq: left: Field: author right: @@ -726,8 +707,7 @@ tests: pipeline: - Collection: books - Where: - condition: - Like: + - Like: - Field: title - Constant: "%Guide%" - Select: @@ -738,8 +718,7 @@ tests: pipeline: - Collection: books - Where: - condition: - RegexContains: + - RegexContains: - Field: title - Constant: "(?i)(the|of)" results: @@ -752,8 +731,7 @@ tests: pipeline: - Collection: books - Where: - condition: - RegexMatch: + - RegexMatch: - Field: title - Constant: ".*(?i)(the|of).*" results: @@ -800,8 +778,7 @@ tests: pipeline: - Collection: books - Where: - condition: - And: + - And: - Gt: left: Field: rating @@ -835,8 +812,7 @@ tests: pipeline: - Collection: books - Where: - condition: - Or: + - Or: - And: - Gt: left: @@ -867,7 +843,6 @@ tests: 
pipeline: - Collection: books - Where: - condition: - Not: - IsNaN: - Field: rating @@ -894,8 +869,7 @@ tests: pipeline: - Collection: books - Where: - condition: - Eq: + - Eq: left: Field: author right: @@ -947,8 +921,7 @@ tests: alias: hugoAward - Field: title - Where: - condition: - Eq: + - Eq: left: Field: hugoAward right: @@ -1001,8 +974,7 @@ tests: pipeline: - Collection: books - Where: - condition: - Eq: + - Eq: left: Field: awards.hugo right: @@ -1019,8 +991,7 @@ tests: pipeline: - Collection: books - Where: - condition: - Eq: + - Eq: left: Field: awards.hugo right: @@ -1087,8 +1058,7 @@ tests: pipeline: - Collection: books - Where: - condition: - Eq: + - Eq: left: Field: title right: From e35accf96ea31cb1646bac5a66897fbdc75061cf Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 15:03:50 -0800 Subject: [PATCH 029/131] support union stage --- tests/system/pipeline_e2e.yaml | 18 ++++++++---------- tests/system/test_pipeline_acceptance.py | 3 +++ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index ed6b4d0f3..64e47ee4e 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -999,7 +999,7 @@ tests: - Select: - title - Field: awards.hugo - - __name__ + - Field: "__name__" results: - title: The Hitchhiker's Guide to the Galaxy awards.hugo: true @@ -1045,15 +1045,13 @@ tests: # - # document data # - # document data # - # document data - #- description: testUnion - # pipeline: - # - Union: - # - Collection: books - # - Collection: books - # results: # Results will be duplicated - # - # document data - # - # document data - # # ... 20 results total + - description: testUnion + pipeline: + - Collection: books + - Union: + - Pipeline: + - Collection: books + results_num: 20 # Results will be duplicated - description: testUnnest pipeline: - Collection: books diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 6ea6a1884..2fa8bab69 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -94,6 +94,9 @@ def parse_expressions(yaml_element: Any): cls = getattr(pipeline_expressions, cls_str) yaml_args = yaml_element[cls_str] return _apply_yaml_args(cls, yaml_args) + elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline": + # find Pipeline objects for Union expressions + return parse_pipeline(yaml_element["Pipeline"]) else: # otherwise, return dict return {parse_expressions(k): parse_expressions(v) for k,v in yaml_element.items()} From 72301a1a1eba8a264d3658a36c8c92279ba2e585 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 15:13:11 -0800 Subject: [PATCH 030/131] fixed testReplace --- tests/system/pipeline_e2e.yaml | 37 ++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 64e47ee4e..5956b8af8 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -1005,23 +1005,6 @@ tests: awards.hugo: true - title: Dune awards.hugo: true -# - description: testReplace -# pipeline: -# - Collection: books -# - Replace: awards -# results: -# - title: The Hitchhiker's Guide to the Galaxy -# author: Douglas Adams -# genre: Science Fiction -# published: 1979 -# rating: 4.2 -# tags: -# - comedy -# - space -# - adventure -# hugo: true -# nebula: false -# # ... 
other results with replaced awards #- description: testSampleLimit # pipeline: # - Collection: books @@ -1045,6 +1028,26 @@ tests: # - # document data # - # document data # - # document data + - description: testReplace + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - "The Hitchhiker's Guide to the Galaxy" + - Replace: awards + results: + - title: The Hitchhiker's Guide to the Galaxy + author: Douglas Adams + genre: Science Fiction + published: 1979 + rating: 4.2 + tags: + - comedy + - space + - adventure + hugo: true + nebula: false - description: testUnion pipeline: - Collection: books From 07d25dcde932d098b52deb55729c9377222ba902 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 15:20:52 -0800 Subject: [PATCH 031/131] fixed sample --- google/cloud/firestore_v1/pipeline.py | 4 +-- google/cloud/firestore_v1/pipeline_stages.py | 22 +++++------- tests/system/pipeline_e2e.yaml | 35 +++++++------------- 3 files changed, 23 insertions(+), 38 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index d3b959da8..1dcbd60e9 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -64,8 +64,8 @@ def replace( self.stages.append(stages.Replace(field, mode)) return self - def sample(self, options: stages.SampleOptions) -> Pipeline: - self.stages.append(stages.Sample(options)) + def sample(self, n: int, mode: Sample.Mode = Sample.Mode.DOCUMENTS) -> Pipeline: + self.stages.append(stages.Sample(n, mode) return self def union(self, other: Pipeline) -> Pipeline: diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 77639096f..75cb613a7 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Any, Dict, Iterable, List, Optional, Sequence from enum import Enum +from enum import auto from google.cloud.firestore_v1.pipeline_expressions import ( Accumulator, @@ -21,16 +22,6 @@ def __init__( self.distance_field = distance_field -class SampleOptions: - class Mode(Enum): - DOCUMENTS = "documents" - PERCENT = "percent" - - def __init__(self, n: int | float, mode: Mode): - self.n = n - self.mode = mode - - class UnnestOptions: def __init__(self, index_field: str): self.index_field = index_field @@ -185,10 +176,15 @@ def __init__(self, field: Selectable, mode: Mode = Mode.FULL_REPLACE): class Sample(Stage): - def __init__(self, options: "SampleOptions"): - super().__init__() - self.options = options + class Mode(Enum): + DOCUMENTS = auto() + PERCENTAGE = auto() + + def __init__(self, n: int, mode: Mode = Mode.DOCUMENTS): + super().__init__() + self.n = n + self.mode = mode class Select(Stage): def __init__(self, *fields: str | Selectable): diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 5956b8af8..75083f68f 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -1005,29 +1005,6 @@ tests: awards.hugo: true - title: Dune awards.hugo: true - #- description: testSampleLimit - # pipeline: - # - Collection: books - # - Sample: - # method: LIMIT - # n: 3 - # results: # Results will vary due to randomness - # - # document data - # - # document data - # - # document data - #- description: testSamplePercentage - # pipeline: - # - Collection: books - # - Sample: - # method: PERCENTAGE - # n: 60 - # results: # Results will vary due to randomness - # - # 
document data - # - # document data - # - # document data - # - # document data - # - # document data - # - # document data - description: testReplace pipeline: - Collection: books @@ -1048,6 +1025,18 @@ tests: - adventure hugo: true nebula: false + - description: testSampleLimit + pipeline: + - Collection: books + - Sample: 3 + results_num: 3 # Results will vary due to randomness + - description: testSamplePercentage + pipeline: + - Collection: books + - Sample: + - 60 + - PERCENTAGE + results_num: 6 # Results will vary due to randomness - description: testUnion pipeline: - Collection: books From 38fb7fa07cbef78814ac4a679d89ddb5bd630d4e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 15:33:48 -0800 Subject: [PATCH 032/131] iterating on sample options --- google/cloud/firestore_v1/pipeline.py | 5 ++-- .../firestore_v1/pipeline_expressions.py | 25 +++++++++++++------ google/cloud/firestore_v1/pipeline_stages.py | 12 ++++----- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 1dcbd60e9..b1b356d71 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -9,6 +9,7 @@ Field, FilterCondition, Selectable, + SampleOptions, ) @@ -64,8 +65,8 @@ def replace( self.stages.append(stages.Replace(field, mode)) return self - def sample(self, n: int, mode: Sample.Mode = Sample.Mode.DOCUMENTS) -> Pipeline: - self.stages.append(stages.Sample(n, mode) + def sample(self, limit_or_options: int | SampleOptions) -> Pipeline: + self.stages.append(stages.Sample(limit_or_options)) return self def union(self, other: Pipeline) -> Pipeline: diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index cbce37dc8..209f20f58 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1,25 +1,34 @@ from typing import Any, Iterable, List, Mapping, Union, Generic, TypeVar from enum import Enum from enum import auto - - -class OrderingDirection(Enum): - ASCENDING = auto() - DESCENDING = auto() +from dataclass import dataclass class Ordering: - def __init__(self, expr, order_dir: OrderingDirection | str): + class Direction(Enum): + ASCENDING = auto() + DESCENDING = auto() + + def __init__(self, expr, order_dir: Direction | str): self.expr = expr - self.order_dir = OrderingDirection[order_dir] if isinstance(order_dir, str) else order_dir + self.order_dir = Ordering.Direction[order_dir] if isinstance(order_dir, str) else order_dir def __repr__(self): - if self.order_dir is OrderingDirection.ASCENDING: + if self.order_dir is Ordering.Direction.ASCENDING: order_str = ".ascending()" else: order_str = ".descending()" return f"{self.expr!r}{order_str}" +@dataclass +class SampleOptions: + class Mode(Enum): + DOCUMENTS = auto() + PERCENTAGE = auto() + + n: int + mode: Mode + class Expr: """Represents an expression that can be evaluated to a value within the execution of a pipeline. 
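
A sketch of the resulting call patterns, assuming a Pipeline instance
``pipeline`` plus the SampleOptions dataclass above and the updated Sample
stage below:

    from google.cloud.firestore_v1.pipeline_expressions import SampleOptions

    pipeline.sample(3)  # int shorthand; wrapped as SampleOptions(3, Mode.DOCUMENTS)
    pipeline.sample(SampleOptions(60, SampleOptions.Mode.PERCENTAGE))
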
diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 75cb613a7..03e65e376 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -10,6 +10,7 @@ Field, FilterCondition, Selectable, + SampleOptions, ) class FindNearestOptions: @@ -177,14 +178,11 @@ def __init__(self, field: Selectable, mode: Mode = Mode.FULL_REPLACE): class Sample(Stage): - class Mode(Enum): - DOCUMENTS = auto() - PERCENTAGE = auto() - - def __init__(self, n: int, mode: Mode = Mode.DOCUMENTS): + def __init__(self, limit_or_options: int | SampleOptions): super().__init__() - self.n = n - self.mode = mode + if isinstance(limit_or_options, int): + limit_or_options = SampleOptions(limit_or_options, SampleOptions.Mode.DOCUMENTS) + self.options = limit_or_options class Select(Stage): def __init__(self, *fields: str | Selectable): From 0c3852ffdb60d146e4d07b7267eb15037f12c6fb Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 15:45:29 -0800 Subject: [PATCH 033/131] fixed import --- google/cloud/firestore_v1/pipeline_expressions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 209f20f58..a66bfe9ff 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1,7 +1,7 @@ from typing import Any, Iterable, List, Mapping, Union, Generic, TypeVar from enum import Enum from enum import auto -from dataclass import dataclass +from dataclasses import dataclass class Ordering: From 1d95e544942ba6922851f78d23e11ed72f56afad Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 16:11:52 -0800 Subject: [PATCH 034/131] removing kwargs --- tests/system/pipeline_e2e.yaml | 195 ++++++++++++--------------------- 1 file changed, 69 insertions(+), 126 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 75083f68f..164861e5d 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -130,7 +130,7 @@ tests: pipeline: - Collection: books - Aggregate: - - ExprWithAlias: + - - ExprWithAlias: expr: Count alias: count results: @@ -140,12 +140,10 @@ tests: - Collection: books - Where: - Eq: - left: - Field: genre - right: - Constant: Science Fiction + - Field: genre + - Constant: Science Fiction - Aggregate: - - ExprWithAlias: + - - ExprWithAlias: expr: Count alias: count - ExprWithAlias: @@ -167,29 +165,24 @@ tests: - Collection: books - Where: - Lt: - left: - Field: published - right: - Constant: 1900 + - Field: published + - Constant: 1900 - Aggregate: - groups: - - genre + - [] + - - genre error: "Cannot groupBy without accumulators" - description: testDistinct pipeline: - Collection: books - Where: - Lt: - left: - Field: published - right: - Constant: 1900 + - Field: published + - Constant: 1900 - Distinct: - ExprWithAlias: expr: ToLower: - value: - Field: genre + - Field: genre alias: lower_genre results: - lower_genre: romance @@ -199,25 +192,19 @@ tests: - Collection: books - Where: - Lt: - left: - Field: published - right: - Constant: 1984 + - Field: published + - Constant: 1984 - Aggregate: - accumulators: - - ExprWithAlias: - expr: - Avg: - value: rating - alias: avg_rating - groups: - - genre + - - ExprWithAlias: + expr: + Avg: + value: rating + alias: avg_rating + - - genre - Where: - Gt: - left: - Field: avg_rating - right: - Constant: 4.3 + - Field: 
avg_rating + - Constant: 4.3 results: - avg_rating: 4.7 genre: Fantasy @@ -229,8 +216,7 @@ tests: pipeline: - Collection: books - Aggregate: - accumulators: - - ExprWithAlias: + - - ExprWithAlias: expr: Count alias: count - ExprWithAlias: @@ -336,15 +322,11 @@ tests: - Where: - And: - Gt: - left: - Field: rating - right: - Constant: 4.5 + - Field: rating + - Constant: 4.5 - Eq: - left: - Field: genre - right: - Constant: Science Fiction + - Field: genre + - Constant: Science Fiction results: - title: Dune author: Frank Herbert @@ -364,15 +346,11 @@ tests: - Where: - Or: - Eq: - left: - Field: genre - right: - Constant: Romance + - Field: genre + - Constant: Romance - Eq: - left: - Field: genre - right: - Constant: Dystopian + - Field: genre + - Constant: Dystopian - Select: - title results: @@ -553,10 +531,8 @@ tests: - title - Where: - Gt: - left: - Field: titleLength - right: - Constant: 20 + - Field: titleLength + - Constant: 20 results: - titleLength: 32 title: The Hitchhiker's Guide to the Galaxy @@ -573,10 +549,8 @@ tests: alias: reversed_title - Where: - Eq: - left: - Field: author - right: - Constant: Douglas Adams + - Field: author + - Constant: Douglas Adams results: - reversed_title: yxalaG ot ediug s'reknhiHcH ehT - description: testStringFunctions - ReplaceFirst @@ -593,10 +567,8 @@ tests: alias: replaced_title - Where: - Eq: - left: - Field: author - right: - Constant: Douglas Adams + - Field: author + - Constant: Douglas Adams results: - replaced_title: A Hitchhiker's Guide to the Galaxy - description: testStringFunctions - ReplaceAll @@ -613,10 +585,8 @@ tests: alias: replaced_title - Where: - Eq: - left: - Field: author - right: - Constant: Douglas Adams + - Field: author + - Constant: Douglas Adams results: - replaced_title: The_Hitchhiker's_Guide_to_the_Galaxy - description: testStringFunctions - CharLength @@ -630,10 +600,8 @@ tests: alias: title_length - Where: - Eq: - left: - Field: author - right: - Constant: Douglas Adams + - Field: author + - Constant: Douglas Adams results: - title_length: 30 - description: testStringFunctions - ByteLength @@ -649,10 +617,8 @@ tests: alias: title_byte_length - Where: - Eq: - left: - Field: author - right: - Constant: Douglas Adams + - Field: author + - Constant: Douglas Adams results: - title_byte_length: 42 - description: testToLowercase @@ -780,20 +746,14 @@ tests: - Where: - And: - Gt: - left: - Field: rating - right: - Constant: 4.2 + - Field: rating + - Constant: 4.2 - Lte: - left: - Field: rating - right: - Constant: 4.5 + - Field: rating + - Constant: 4.5 - Neq: - left: - Field: genre - right: - Constant: Science Fiction + - Field: genre + - Constant: Science Fiction - Select: - rating - title @@ -815,20 +775,14 @@ tests: - Or: - And: - Gt: - left: - Field: rating - right: - Constant: 4.5 + - Field: rating + - Constant: 4.5 - Eq: - left: - Field: genre - right: - Constant: Science Fiction + - Field: genre + - Constant: Science Fiction - Lt: - left: - Field: published - right: - Constant: 1900 + - Field: published + - Constant: 1900 - Select: - title - Sort: @@ -850,10 +804,8 @@ tests: - ExprWithAlias: expr: Eq: - left: - Field: rating - right: - Constant: null + - Field: rating + - Constant: null alias: ratingIsNull - ExprWithAlias: expr: @@ -870,10 +822,8 @@ tests: - Collection: books - Where: - Eq: - left: - Field: author - right: - Constant: Douglas Adams + - Field: author + - Constant: Douglas Adams - Select: - ExprWithAlias: expr: @@ -922,10 +872,8 @@ tests: - Field: title - Where: - Eq: - left: - Field: hugoAward - 
right: - Constant: true + - Field: hugoAward + - Constant: true results: - hugoAward: true title: The Hitchhiker's Guide to the Galaxy @@ -975,10 +923,8 @@ tests: - Collection: books - Where: - Eq: - left: - Field: awards.hugo - right: - Constant: true + - Field: awards.hugo + - Constant: true - Select: - title - Field: awards.hugo @@ -992,10 +938,8 @@ tests: - Collection: books - Where: - Eq: - left: - Field: awards.hugo - right: - Constant: true + - Field: awards.hugo + - Constant: true - Select: - title - Field: awards.hugo @@ -1034,8 +978,9 @@ tests: pipeline: - Collection: books - Sample: - - 60 - - PERCENTAGE + - SampleOptions: + - 60 + - PERCENTAGE results_num: 6 # Results will vary due to randomness - description: testUnion pipeline: @@ -1049,10 +994,8 @@ tests: - Collection: books - Where: - Eq: - left: - Field: title - right: - Constant: The Hitchhiker's Guide to the Galaxy + - Field: title + - Constant: The Hitchhiker's Guide to the Galaxy - Unnest: tags results: - tags: comedy From a3daa070e27d82d797759e6ea5d2cc5163d2b0e8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 16:14:32 -0800 Subject: [PATCH 035/131] removing keargs --- tests/system/pipeline_e2e.yaml | 35 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 164861e5d..ad3257377 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -148,13 +148,11 @@ tests: alias: count - ExprWithAlias: expr: - Avg: - value: rating + Avg: rating alias: avg_rating - ExprWithAlias: expr: - Max: - value: rating + Max: rating alias: max_rating results: - count: 2 @@ -197,8 +195,7 @@ tests: - Aggregate: - - ExprWithAlias: expr: - Avg: - value: rating + Avg: rating alias: avg_rating - - genre - Where: @@ -221,13 +218,11 @@ tests: alias: count - ExprWithAlias: expr: - Max: - value: rating + Max: rating alias: max_rating - ExprWithAlias: expr: - Min: - value: published + Min: published alias: min_published results: - count: 10 @@ -560,10 +555,9 @@ tests: - ExprWithAlias: expr: ReplaceFirst: - value: - Field: title - pattern: The - replacement: A + - Field: title + - The + - A alias: replaced_title - Where: - Eq: @@ -578,10 +572,9 @@ tests: - ExprWithAlias: expr: ReplaceAll: - value: - Field: title - pattern: " " - replacement: _ + - Field: title + - " " + - "_" alias: replaced_title - Where: - Eq: @@ -628,8 +621,7 @@ tests: - ExprWithAlias: expr: ToLower: - value: - Field: title + - Field: title alias: lowercaseTitle - Limit: 1 results: @@ -641,8 +633,7 @@ tests: - ExprWithAlias: expr: ToUpper: - value: - Field: author + - Field: author alias: uppercaseAuthor - Limit: 1 results: From 8e45b689b9a8ff4ba73ff861ab9cfab6b9e3dc2f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 28 Jan 2025 16:29:47 -0800 Subject: [PATCH 036/131] use only positional arguments in yaml --- tests/system/pipeline_e2e.yaml | 237 +++++++++-------------- tests/system/test_pipeline_acceptance.py | 5 +- 2 files changed, 93 insertions(+), 149 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index ad3257377..eab5b1be2 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -131,8 +131,8 @@ tests: - Collection: books - Aggregate: - - ExprWithAlias: - expr: Count - alias: count + - Count + - "count" results: - count: 10 - description: "testAggregates - avg, count, max" @@ -144,16 +144,14 @@ tests: - Constant: Science Fiction - Aggregate: - - 
ExprWithAlias: - expr: Count - alias: count + - Count + - "count" - ExprWithAlias: - expr: - Avg: rating - alias: avg_rating + - Avg: rating + - "avg_rating" - ExprWithAlias: - expr: - Max: rating - alias: max_rating + - Max: rating + - "max_rating" results: - count: 2 avg_rating: 4.4 @@ -178,10 +176,9 @@ tests: - Constant: 1900 - Distinct: - ExprWithAlias: - expr: - ToLower: + - ToLower: - Field: genre - alias: lower_genre + - "lower_genre" results: - lower_genre: romance - lower_genre: psychological thriller @@ -194,9 +191,8 @@ tests: - Constant: 1984 - Aggregate: - - ExprWithAlias: - expr: - Avg: rating - alias: avg_rating + - Avg: rating + - "avg_rating" - - genre - Where: - Gt: @@ -214,16 +210,14 @@ tests: - Collection: books - Aggregate: - - ExprWithAlias: - expr: Count - alias: count + - Count + - "count" - ExprWithAlias: - expr: - Max: rating - alias: max_rating + - Max: rating + - "max_rating" - ExprWithAlias: - expr: - Min: published - alias: min_published + - Min: published + - "min_published" results: - count: 10 max_rating: 4.7 @@ -264,19 +258,17 @@ tests: - Collection: books - AddFields: - ExprWithAlias: - expr: - StrConcat: + - StrConcat: - Field: author - Constant: _ - Field: title - alias: author_title + - "author_title" - ExprWithAlias: - expr: - StrConcat: + - StrConcat: - Field: title - Constant: _ - Field: author - alias: title_author + - "title_author" - RemoveFields: - title_author - tags @@ -396,10 +388,8 @@ tests: - Collection: books - Where: - ArrayContainsAny: - array: tags - elements: - - comedy - - classic + - tags + - ["comedy", "classic"] - Select: - title results: @@ -410,11 +400,8 @@ tests: - Collection: books - Where: - ArrayContainsAll: - array: - Field: tags - elements: - - adventure - - magic + - Field: tags + - ["adventure", "magic"] - Select: - title results: @@ -424,16 +411,13 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - ArrayLength: + - ArrayLength: - Field: tags - alias: tagsCount + - "tagsCount" - Where: - Eq: - left: - Field: tagsCount - right: - Constant: 3 + - Field: tagsCount + - Constant: 3 results: # All documents have 3 tags - tagsCount: 3 - tagsCount: 3 @@ -450,14 +434,10 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - ArrayConcat: - array: - Field: tags - rest: - - newTag1 - - newTag2 - alias: modifiedTags + - ArrayConcat: + - Field: tags + - ["newTag1", "newTag2"] + - "modifiedTags" - Limit: 1 results: - modifiedTags: @@ -471,12 +451,11 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - StrConcat: + - StrConcat: - Field: author - Constant: " - " - Field: title - alias: bookInfo + - "bookInfo" - Limit: 1 results: - bookInfo: Douglas Adams - The Hitchhiker's Guide to the Galaxy @@ -519,10 +498,9 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - CharLength: + - CharLength: - Field: title - alias: titleLength + - "titleLength" - title - Where: - Gt: @@ -538,10 +516,9 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - Reverse: + - Reverse: - Field: title - alias: reversed_title + - "reversed_title" - Where: - Eq: - Field: author @@ -553,12 +530,11 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - ReplaceFirst: + - ReplaceFirst: - Field: title - The - A - alias: replaced_title + - "replaced_title" - Where: - Eq: - Field: author @@ -570,12 +546,11 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - ReplaceAll: + - ReplaceAll: - Field: title - " " - "_" - alias: replaced_title + - "replaced_title" - Where: - Eq: - Field: 
author @@ -587,10 +562,9 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - CharLength: + - CharLength: - Field: title - alias: title_length + - "title_length" - Where: - Eq: - Field: author @@ -602,12 +576,11 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - ByteLength: + - ByteLength: - StrConcat: - Field: title - Constant: _银河系漫游指南 - alias: title_byte_length + - "title_byte_length" - Where: - Eq: - Field: author @@ -619,10 +592,9 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - ToLower: + - ToLower: - Field: title - alias: lowercaseTitle + - "lowercaseTitle" - Limit: 1 results: - lowercaseTitle: the hitchhiker's guide to the galaxy @@ -631,10 +603,9 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - ToUpper: + - ToUpper: - Field: author - alias: uppercaseAuthor + - "uppercaseAuthor" - Limit: 1 results: - uppercaseAuthor: DOUGLAS ADAMS @@ -643,18 +614,16 @@ tests: - Collection: books - AddFields: - ExprWithAlias: - expr: - StrConcat: + - StrConcat: - Constant: " " - Field: title - Constant: " " - alias: spacedTitle + - "spacedTitle" - Select: - ExprWithAlias: - expr: - Trim: + - Trim: - Field: spacedTitle - alias: trimmedTitle + - "trimmedTitle" - spacedTitle - Limit: 1 results: @@ -702,29 +671,25 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - Add: + - Add: - Field: rating - Constant: 1 - alias: ratingPlusOne + - "ratingPlusOne" - ExprWithAlias: - expr: - Subtract: + - Subtract: - Field: published - Constant: 1900 - alias: yearsSince1900 + - "yearsSince1900" - ExprWithAlias: - expr: - Multiply: + - Multiply: - Field: rating - Constant: 10 - alias: ratingTimesTen + - "ratingTimesTen" - ExprWithAlias: - expr: - Divide: + - Divide: - Field: rating - Constant: 2 - alias: ratingDividedByTwo + - "ratingDividedByTwo" - Limit: 1 results: - ratingPlusOne: 5.2 @@ -793,17 +758,15 @@ tests: - Field: rating - Select: - ExprWithAlias: - expr: - Eq: + - Eq: - Field: rating - Constant: null - alias: ratingIsNull + - "ratingIsNull" - ExprWithAlias: - expr: - - Not: + - Not: - IsNaN: - Field: rating - alias: ratingIsNotNaN + - "ratingIsNotNaN" - Limit: 1 results: - ratingIsNull: false @@ -817,17 +780,15 @@ tests: - Constant: Douglas Adams - Select: - ExprWithAlias: - expr: - LogicalMax: + - LogicalMax: - Field: rating - Constant: 4.5 - alias: max_rating + - "max_rating" - ExprWithAlias: - expr: - LogicalMax: + - LogicalMax: - Field: published - Constant: 1900 - alias: max_published + - "max_published" results: - max_rating: 4.5 max_published: 1979 @@ -836,17 +797,15 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - LogicalMin: + - LogicalMin: - Field: rating - Constant: 4.5 - alias: min_rating + - "min_rating" - ExprWithAlias: - expr: - LogicalMin: + - LogicalMin: - Field: published - Constant: 1900 - alias: min_published + - "min_published" results: - min_rating: 4.2 min_published: 1900 @@ -855,11 +814,10 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - MapGet: + - MapGet: - Field: awards - hugo - alias: hugoAward + - "hugoAward" - Field: title - Where: - Eq: @@ -875,35 +833,20 @@ tests: - Collection: books - Select: - ExprWithAlias: - expr: - CosineDistance: - vector1: - - 0.1 - - 0.1 - vector2: - - 0.5 - - 0.8 - alias: cosineDistance + - CosineDistance: + - [0.1, 0.1] + - [0.5, 0.8] + - "cosineDistance" - ExprWithAlias: - expr: - DotProduct: - vector1: - - 0.1 - - 0.1 - vector2: - - 0.5 - - 0.8 - alias: dotProductDistance + - DotProduct: + - [0.1, 0.1] + - [0.5, 0.8] + - 
"dotProductDistance" - ExprWithAlias: - expr: - EuclideanDistance: - vector1: - - 0.1 - - 0.1 - vector2: - - 0.5 - - 0.8 - alias: euclideanDistance + - EuclideanDistance: + - [0.1, 0.1] + - [0.5, 0.8] + - "euclideanDistance" - Limit: 1 results: - cosineDistance: 0.02560880430538015 diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 2fa8bab69..1da0558ee 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -49,8 +49,9 @@ def loader(): def _apply_yaml_args(cls, yaml_args): if isinstance(yaml_args, dict): - # yaml has a mapping of arguments. Treat as kwargs - return cls(**parse_expressions(yaml_args)) + # reject mapping arguments: use only positional arguments in yaml + # for cross-language simplicity + raise ValueError(f"found kwargs for class: {cls}") elif isinstance(yaml_args, list): # yaml has an array of arguments. Treat as args return cls(*parse_expressions(yaml_args)) From 6b2a6054eb64d950a8d8a0a88c490ac20bca06f2 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 27 Feb 2025 11:10:56 -0800 Subject: [PATCH 037/131] added _to_pb to stages --- google/cloud/firestore_v1/pipeline.py | 3 + google/cloud/firestore_v1/pipeline_stages.py | 84 ++++++++++++++++++-- 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index b1b356d71..36b4fdcba 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -26,6 +26,9 @@ def __repr__(self): stages_str = ",\n ".join([repr(s) for s in self.stages]) return f"Pipeline(\n {stages_str}\n)" + def _to_pb(self) -> Pipeline: + return Pipeline(stages=[s._to_pb() for s in self.stages]) + def add_fields(self, fields: Dict[str, Expr]) -> Pipeline: self.stages.append(stages.AddFields(fields)) return self diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 03e65e376..c3f3021c9 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -3,6 +3,8 @@ from enum import Enum from enum import auto +from google.cloud.firestore_v1.types.document import Pipeline +from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.pipeline_expressions import ( Accumulator, Expr, @@ -32,6 +34,17 @@ class Stage: def __init__(self, custom_name: Optional[str] = None): self.name = custom_name or type(self).__name__.lower() + def _to_pb(self) -> Pipeline.Stage: + return Pipeline.Stage(name=self.name, args=self._pb_args(), options=self._pb_options()) + + def _pb_args(self) -> list[Value]: + """Return Ordered list of arguments the given stage expects""" + return [] + + def _pb_options(self) -> dict[str, Value]: + """Return optional named arguments that certain functions may support.""" + return {} + def __repr__(self): items = ("%s=%r" % (k, v) for k, v in self.__dict__.items() if k != "name") return f"{self.__class__.__name__}({', '.join(items)})" @@ -42,6 +55,9 @@ def __init__(self, *fields: Selectable): super().__init__("add_fields") self.fields = list(fields) + def _pb_args(self) -> list[Value]: + return [Value(map_value=self._fields_map())] + def _fields_map(self) -> dict[str, Expr]: return dict(f._to_map() for f in self.fields) @@ -57,6 +73,9 @@ def __init__( self.groups: list[Selectable] = [Field(f) if isinstance(f, str) else f for f in groups] self.accumulators: list[ExprWithAlias[Accumulator]] = [*accumulators, 
*extra_accumulators] + def _pb_args(self) -> list[Value]: + raise NotImplementedError + @property def _group_map(self) -> dict[str, Expr]: return dict(f._to_map() for f in self.groups) @@ -80,21 +99,25 @@ class Collection(Stage): def __init__(self, path: str): super().__init__() if not path.startswith("/"): - path = "/" + path + path = f"/{path}" self.path = path + def _pb_args(self): + return [Value(reference_value=self.path)] class CollectionGroup(Stage): def __init__(self, collection_id: str): super().__init__("collection_group") self.collection_id = collection_id + def _pb_args(self): + return [Value(string_value=self.collection_id)] + class Database(Stage): def __init__(self): super().__init__() - class Distinct(Stage): def __init__(self, *fields: str | Selectable): super().__init__() @@ -109,6 +132,9 @@ def _fields_dict(self) -> dict[str, Selectable]: for f in self.fields ) + def _pb_args(self) -> list[Value]: + raise NotImplementedError + class Documents(Stage): def __init__(self, *documents: str): @@ -120,6 +146,9 @@ def of(*documents: "DocumentReference") -> "Documents": doc_paths = ["/" + doc.path for doc in documents] return Documents(doc_paths) + def _pb_args(self): + return [Value(list_value=self.documents)] + class FindNearest(Stage): def __init__( @@ -135,24 +164,38 @@ def __init__( self.distance_measure = distance_measure self.options = options or FindNearestOptions() + def _pb_args(self): + raise NotImplementedError + + def _pb_options(self) -> dict[str, Value]: + raise NotImplementedError class GenericStage(Stage): def __init__(self, name: str, *params: Any): super().__init__(name) self.params = list(params) + def _pb_args(self): + raise NotImplementedError + class Limit(Stage): def __init__(self, limit: int): super().__init__() self.limit = limit + def _pb_args(self): + return [Value(integer_value=self.limit)] + class Offset(Stage): def __init__(self, offset: int): super().__init__() self.offset = offset + def _pb_args(self): + return [Value(integer_value=self.offset)] + class RemoveFields(Stage): def __init__(self, *fields: str | Field): @@ -163,6 +206,9 @@ def __init__(self, *fields: str | Field): def _fields_map(self) -> dict[str, Field]: dict(f._to_map() for f in self.fields) + def _pb_args(self) -> list[Value]: + return [Value(map_value=self._fields_map())] + class Replace(Stage): class Mode(Enum): @@ -175,14 +221,23 @@ def __init__(self, field: Selectable, mode: Mode = Mode.FULL_REPLACE): self.field = field self.mode = mode + def _pb_args(self): + raise NotImplementedError + class Sample(Stage): def __init__(self, limit_or_options: int | SampleOptions): super().__init__() if isinstance(limit_or_options, int): - limit_or_options = SampleOptions(limit_or_options, SampleOptions.Mode.DOCUMENTS) - self.options = limit_or_options + options = SampleOptions(limit_or_options, SampleOptions.Mode.DOCUMENTS) + else: + options = limit_or_options + self.options: SampleOptions = options + + def _pb_args(self): + return [Value(integer_value=self.options.limit), Value(string_value=self.options.mode.value)] + class Select(Stage): def __init__(self, *fields: str | Selectable): @@ -193,7 +248,8 @@ def __init__(self, *fields: str | Selectable): def _projections_map(self) -> dict[str, Expr]: return dict(f._to_map() for f in self.projections) - + def _pb_args(self) -> list[Value]: + return [Value(map_value=self._projections_map())] class Sort(Stage): @@ -201,12 +257,18 @@ def __init__(self, *orders: "Ordering"): super().__init__() self.orders = list(orders) + def _pb_args(self): + return 
[Value(map_value=o._to_map()) for o in self.orders] + class Union(Stage): def __init__(self, other: "Pipeline"): super().__init__() self.other = other + def _pb_args(self): + return [Value(map_value=self.other._to_map())] + class Unnest(Stage): def __init__(self, field: Field, options: Optional["UnnestOptions"] = None): @@ -214,9 +276,21 @@ def __init__(self, field: Field, options: Optional["UnnestOptions"] = None): self.field = field self.options = options + def _pb_args(self): + raise NotImplementedError + + def _pb_options(self): + options = {} + if self.options is not None: + options["index_field"] = Value(string_value=self.options.index_field) + return options + class Where(Stage): def __init__(self, condition: FilterCondition): super().__init__() self.condition = condition + def _pb_args(self): + return [Value(map_value=self.condition._to_map())] + From 6d2aa90702b03b218b926b9f76e0dac9918ebe94 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 27 Feb 2025 12:16:45 -0800 Subject: [PATCH 038/131] added _pb to pipeline expressions --- .../firestore_v1/pipeline_expressions.py | 53 ++++++++++++++++++- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index a66bfe9ff..6d428db86 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1,7 +1,12 @@ from typing import Any, Iterable, List, Mapping, Union, Generic, TypeVar from enum import Enum from enum import auto +import datetime from dataclasses import dataclass +from google.cloud.firestore_v1.types.document import Value + +CONSTANT_ELEMENTS = Union[str, int, float, bool, datetime.datetime, bytes, tuple[float, float]] +CONSTANT_TYPES = Union[CONSTANT_ELEMENTS, list[Union[CONSTANT_ELEMENTS, list, dict]], dict[str, Union[CONSTANT_ELEMENTS, list, dict]]] class Ordering: @@ -37,6 +42,9 @@ class Expr: def __repr__(self): return f"{self.__class__.__name__}()" + def _to_pb(self) -> Value: + raise NotImplementedError + @staticmethod def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": return o if isinstance(o, Expr) else Constant(o) @@ -220,22 +228,48 @@ def as_(self, alias: str) -> "ExprWithAlias": class Constant(Expr): - def __init__(self, value: Any): + def __init__(self, value: CONSTANT_TYPES): self.value = value @staticmethod - def of(value): + def of(value:CONSTANT_TYPES): return Constant(value) def __repr__(self): return f"Constant.of({self.value!r})" + def _to_pb(self): + if self.value is None: + return Value(null_value=Value.NullValue.NULL_VALUE) + elif isinstance(self.value, bool): + return Value(boolean_value=self.value) + elif isinstance(self.value, int): + return Value(integer_value=self.value) + elif isinstance(self.value, float): + return Value(double_value=self.value) + elif isinstance(self.value, datetime.datetime): + return Value(timestamp_value=self.value.timestamp()) + elif isinstance(self.value, str): + return Value(string_value=self.value) + elif isinstance(self.value, bytes): + return Value(bytes_value=self.value) + elif isinstance(self.value, tuple) and len(self.value) == 2 and isinstance(self.value[0], float) and isinstance(self.value[1], float): + return Value(geo_point_value=self.value) + elif isinstance(self.value, list): + return Value(array_value={"values":[Constant(v)._to_pb() for v in self.value]}) + elif isinstance(self.value, dict): + return Value(map_value={"fields": {k: Constant(v)._to_pb() for k, v in self.value.items()}}) 
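+        # dispatch note: bool is checked before int above (isinstance(True, int)
+        # is True in Python), so Constant(True)._to_pb() yields
+        # Value(boolean_value=True) rather than an integer_value; any type not
+        # matched by one of the branches above falls through to the error below.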
+ else: + raise ValueError(f"Unsupported type: {type(self.value)}") class ListOfExprs(Expr): def __init__(self, exprs: List[Expr]): self.exprs = exprs + def _to_pb(self): + return Value(array_value={"values": [e._to_pb() for e in self.exprs]}) + class Function(Expr): """A type of Expression that takes in inputs and gives outputs.""" @@ -247,6 +281,13 @@ def __init__(self, name: str, params: List[Expr]): def __repr__(self): return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" + def _to_pb(self): + return Value( + function_value={ + "name": self.name, "args": [p._to_pb() for p in self.params] + } + ) + class Divide(Function): def __init__(self, left: Expr, right: Expr): super().__init__("divide", [left, right]) @@ -485,6 +526,11 @@ def _to_map(self): def __repr__(self): return f"{self.expr}.as('{self.alias}')" + def _to_pb(self): + return Value( + map_value={"fields": {self.alias: self.expr._to_pb()}} + ) + class Field(Expr, Selectable): DOCUMENT_ID = "__name__" @@ -502,6 +548,9 @@ def _to_map(self): def __repr__(self): return f"Field.of({self.path!r})" + def _to_pb(self): + return Value(field_path=self.path) + class FilterCondition(Function): """Filters the given data in some way.""" From e51d9dc842b4e7836ce8e8d0bb05496dc8d9a337 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 27 Feb 2025 12:56:11 -0800 Subject: [PATCH 039/131] added unimplemented stage._to_pbs --- .../firestore_v1/pipeline_expressions.py | 10 +++--- google/cloud/firestore_v1/pipeline_stages.py | 36 ++++++++++--------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 6d428db86..552641aa6 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -507,7 +507,7 @@ def __init__(self, value: Expr, distinct: bool=False): super().__init__("countif", [value] if value else []) -class Selectable: +class Selectable(Expr): """Points at something in the database?""" def _to_map(self): @@ -515,13 +515,13 @@ def _to_map(self): T = TypeVar('T', bound=Expr) -class ExprWithAlias(Expr, Selectable, Generic[T]): +class ExprWithAlias(Selectable, Generic[T]): def __init__(self, expr: T, alias: str): self.expr = expr self.alias = alias def _to_map(self): - return self.alias, self.expr + return self.alias, self.expr._to_pb() def __repr__(self): return f"{self.expr}.as('{self.alias}')" @@ -532,7 +532,7 @@ def _to_pb(self): ) -class Field(Expr, Selectable): +class Field(Selectable): DOCUMENT_ID = "__name__" def __init__(self, path: str): @@ -543,7 +543,7 @@ def of(path: str): return Field(path) def _to_map(self): - return self.path, self + return self.path, self._to_pb() def __repr__(self): return f"Field.of({self.path!r})" diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index c3f3021c9..308513861 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -74,7 +74,10 @@ def __init__( self.accumulators: list[ExprWithAlias[Accumulator]] = [*accumulators, *extra_accumulators] def _pb_args(self) -> list[Value]: - raise NotImplementedError + return [ + Value(map_value={"fields": self._accumulators_map}), + Value(map_value={"groups": self._group_map}), + ] @property def _group_map(self) -> dict[str, Expr]: @@ -123,17 +126,8 @@ def __init__(self, *fields: str | Selectable): super().__init__() self.fields: list[Selectable] = 
[Field(f) if isinstance(f, str) else f for f in fields] - @property - def _fields_dict(self) -> dict[str, Selectable]: - return dict( - f._to_map() - if isinstance(f, Selectable) - else (f,Field(f)) - for f in self.fields - ) - def _pb_args(self) -> list[Value]: - raise NotImplementedError + raise Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}}) class Documents(Stage): @@ -165,18 +159,26 @@ def __init__( self.options = options or FindNearestOptions() def _pb_args(self): - raise NotImplementedError + return [ + self.property._to_pb(), + Value(array_value={"values": self.vector}), + ] def _pb_options(self) -> dict[str, Value]: - raise NotImplementedError + options = {} + if self.options and self.options.limit is not None: + options["limit"] = Value(integer_value=self.options.limit) + if self.options and self.options.distance_field is not None: + options["distance_field"] = self.options.distance_field._to_pb() + return options class GenericStage(Stage): - def __init__(self, name: str, *params: Any): + def __init__(self, name: str, *params: Value): super().__init__(name) self.params = list(params) def _pb_args(self): - raise NotImplementedError + return self.params class Limit(Stage): @@ -222,7 +224,7 @@ def __init__(self, field: Selectable, mode: Mode = Mode.FULL_REPLACE): self.mode = mode def _pb_args(self): - raise NotImplementedError + return [self.field._to_pb(), Value(string_value=self.mode.value)] class Sample(Stage): @@ -277,7 +279,7 @@ def __init__(self, field: Field, options: Optional["UnnestOptions"] = None): self.options = options def _pb_args(self): - raise NotImplementedError + return [self.field._to_pb()] def _pb_options(self): options = {} From 8f6dea5de42693f1fe0748b8e348cfe8d17e6d9b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 27 Feb 2025 13:02:07 -0800 Subject: [PATCH 040/131] fixed proto formatting --- google/cloud/firestore_v1/pipeline_stages.py | 39 +++++--------------- 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 308513861..627c2e7c4 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -56,11 +56,7 @@ def __init__(self, *fields: Selectable): self.fields = list(fields) def _pb_args(self) -> list[Value]: - return [Value(map_value=self._fields_map())] - - def _fields_map(self) -> dict[str, Expr]: - return dict(f._to_map() for f in self.fields) - + return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}})] class Aggregate(Stage): def __init__( @@ -75,19 +71,10 @@ def __init__( def _pb_args(self) -> list[Value]: return [ - Value(map_value={"fields": self._accumulators_map}), - Value(map_value={"groups": self._group_map}), + Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.accumulators]}}), + Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.groups]}}) ] - @property - def _group_map(self) -> dict[str, Expr]: - return dict(f._to_map() for f in self.groups) - - @property - def _accumulators_map(self) -> dict[str, Expr]: - return dict(f._to_map() for f in self.accumulators) - - def __repr__(self): accumulator_str = ', '.join(repr(v) for v in self.accumulators) group_str = "" @@ -141,7 +128,7 @@ def of(*documents: "DocumentReference") -> "Documents": return Documents(doc_paths) def _pb_args(self): - return [Value(list_value=self.documents)] + return [Value(list_value={"values": 
[Value(string_value=doc) for doc in self.documents]})] class FindNearest(Stage): @@ -204,12 +191,8 @@ def __init__(self, *fields: str | Field): super().__init__("remove_fields") self.fields = [Field(f) if isinstance(f, str) else f for f in fields] - @property - def _fields_map(self) -> dict[str, Field]: - dict(f._to_map() for f in self.fields) - def _pb_args(self) -> list[Value]: - return [Value(map_value=self._fields_map())] + return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}})] class Replace(Stage): @@ -246,12 +229,8 @@ def __init__(self, *fields: str | Selectable): super().__init__() self.projections = [Field(f) if isinstance(f, str) else f for f in fields] - @property - def _projections_map(self) -> dict[str, Expr]: - return dict(f._to_map() for f in self.projections) - def _pb_args(self) -> list[Value]: - return [Value(map_value=self._projections_map())] + return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.projections]}})]} class Sort(Stage): @@ -260,7 +239,7 @@ def __init__(self, *orders: "Ordering"): self.orders = list(orders) def _pb_args(self): - return [Value(map_value=o._to_map()) for o in self.orders] + return [Value(map_value={"fields": {m[0]: m[1] for m in [o._to_map() for o in self.orders]}})] class Union(Stage): @@ -269,7 +248,7 @@ def __init__(self, other: "Pipeline"): self.other = other def _pb_args(self): - return [Value(map_value=self.other._to_map())] + return [Value(pipeline_value=self.other._to_pb())] class Unnest(Stage): @@ -294,5 +273,5 @@ def __init__(self, condition: FilterCondition): self.condition = condition def _pb_args(self): - return [Value(map_value=self.condition._to_map())] + return [Value(map_value={"fields": {m[0]: m[1] for m in [self.condition._to_map()]}})] From 654e5f712716f9447b0438793a34dc1a601e79df Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 27 Feb 2025 16:31:02 -0800 Subject: [PATCH 041/131] got protos to build with no errors --- google/cloud/firestore_v1/pipeline.py | 6 +- .../firestore_v1/pipeline_expressions.py | 37 ++++++++--- google/cloud/firestore_v1/pipeline_stages.py | 27 ++++---- tests/system/pipeline_e2e.yaml | 61 ++++++++++--------- tests/system/test_pipeline_acceptance.py | 11 ++-- 5 files changed, 82 insertions(+), 60 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 36b4fdcba..73da00201 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Any, Dict, Iterable, List, Optional from google.cloud.firestore_v1 import pipeline_stages as stages - +from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb from google.cloud.firestore_v1.pipeline_expressions import ( Accumulator, Expr, @@ -26,8 +26,8 @@ def __repr__(self): stages_str = ",\n ".join([repr(s) for s in self.stages]) return f"Pipeline(\n {stages_str}\n)" - def _to_pb(self) -> Pipeline: - return Pipeline(stages=[s._to_pb() for s in self.stages]) + def _to_pb(self) -> Pipeline_pb: + return Pipeline_pb(stages=[s._to_pb() for s in self.stages]) def add_fields(self, fields: Dict[str, Expr]) -> Pipeline: self.stages.append(stages.AddFields(fields)) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 552641aa6..f996f5604 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -5,7 
+5,7 @@ from dataclasses import dataclass from google.cloud.firestore_v1.types.document import Value -CONSTANT_ELEMENTS = Union[str, int, float, bool, datetime.datetime, bytes, tuple[float, float]] +CONSTANT_ELEMENTS = Union[str, int, float, bool, datetime.datetime, bytes, tuple[float, float], None] CONSTANT_TYPES = Union[CONSTANT_ELEMENTS, list[Union[CONSTANT_ELEMENTS, list, dict]], dict[str, Union[CONSTANT_ELEMENTS, list, dict]]] class Ordering: @@ -15,7 +15,7 @@ class Direction(Enum): DESCENDING = auto() def __init__(self, expr, order_dir: Direction | str): - self.expr = expr + self.expr = expr if isinstance(expr, Expr) else Field.of(expr) self.order_dir = Ordering.Direction[order_dir] if isinstance(order_dir, str) else order_dir def __repr__(self): @@ -25,15 +25,28 @@ def __repr__(self): order_str = ".descending()" return f"{self.expr!r}{order_str}" + def _to_pb(self) -> Value: + return Value( + map_value={"fields": + { + "direction": Value(string_value=self.order_dir.name), + "expression": self.expr._to_pb() + } + } + ) + @dataclass class SampleOptions: class Mode(Enum): - DOCUMENTS = auto() - PERCENTAGE = auto() + DOCUMENTS = "documents" + PERCENTAGE = "percent" n: int mode: Mode + def __post_init__(self): + self.mode = SampleOptions.Mode(self.mode) if isinstance(self.mode, str) else self.mode + class Expr: """Represents an expression that can be evaluated to a value within the execution of a pipeline. @@ -228,7 +241,7 @@ def as_(self, alias: str) -> "ExprWithAlias": class Constant(Expr): - def __init__(self, value: CONSTANT_TYPES): + def __init__(self, value: CONSTANT_TYPES=None): self.value = value @staticmethod @@ -240,7 +253,7 @@ def __repr__(self): def _to_pb(self): if self.value is None: - return Value(null_value=Value.NullValue.NULL_VALUE) + return Value(null_value=0) elif isinstance(self.value, bool): return Value(boolean_value=self.value) elif isinstance(self.value, int): @@ -265,7 +278,7 @@ def _to_pb(self): class ListOfExprs(Expr): def __init__(self, exprs: List[Expr]): - self.exprs = exprs + self.exprs: list[Expr] = [Field.of(e) if isinstance(e, str) else e for e in exprs] def _to_pb(self): return Value(array_value={"values": [e._to_pb() for e in self.exprs]}) @@ -276,7 +289,7 @@ class Function(Expr): def __init__(self, name: str, params: List[Expr]): self.name = name - self.params = params + self.params = [Field.of(p) if isinstance(p, str) else p for p in params] def __repr__(self): return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" @@ -295,11 +308,15 @@ def __init__(self, left: Expr, right: Expr): class DotProduct(Function): def __init__(self, vector1: Expr, vector2: Expr): + vector1 = Constant(vector1) if isinstance(vector1, list) else vector1 + vector2 = Constant(vector2) if isinstance(vector2, list) else vector2 super().__init__("dot_product", [vector1, vector2]) class EuclideanDistance(Function): def __init__(self, vector1: Expr, vector2: Expr): + vector1 = Constant(vector1) if isinstance(vector1, list) else vector1 + vector2 = Constant(vector2) if isinstance(vector2, list) else vector2 super().__init__("euclidean_distance", [vector1, vector2]) @@ -470,6 +487,8 @@ def __init__(self, value: Expr): class CosineDistance(Function): def __init__(self, vector1: Expr, vector2: Expr): + vector1 = Constant(vector1) if isinstance(vector1, list) else vector1 + vector2 = Constant(vector2) if isinstance(vector2, list) else vector2 super().__init__("cosine_distance", [vector1, vector2]) @@ -549,7 +568,7 @@ def __repr__(self): return 
f"Field.of({self.path!r})" def _to_pb(self): - return Value(field_path=self.path) + return Value(field_reference_value=self.path) class FilterCondition(Function): diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 627c2e7c4..c3cb71c0f 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -35,11 +35,11 @@ def __init__(self, custom_name: Optional[str] = None): self.name = custom_name or type(self).__name__.lower() def _to_pb(self) -> Pipeline.Stage: - return Pipeline.Stage(name=self.name, args=self._pb_args(), options=self._pb_options()) + return Pipeline.Stage(name=self.name, args=[*self._pb_args()], options=self._pb_options()) - def _pb_args(self) -> list[Value]: + def _pb_args(self) -> tuple[Value, ...]: """Return Ordered list of arguments the given stage expects""" - return [] + return () def _pb_options(self) -> dict[str, Value]: """Return optional named arguments that certain functions may support.""" @@ -114,7 +114,7 @@ def __init__(self, *fields: str | Selectable): self.fields: list[Selectable] = [Field(f) if isinstance(f, str) else f for f in fields] def _pb_args(self) -> list[Value]: - raise Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}}) + return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}})] class Documents(Stage): @@ -149,6 +149,7 @@ def _pb_args(self): return [ self.property._to_pb(), Value(array_value={"values": self.vector}), + Value(string_value=self.distance_measure.value), ] def _pb_options(self) -> dict[str, Value]: @@ -201,10 +202,10 @@ class Mode(Enum): MERGE_PREFER_NEXT = "merge_prefer_nest" MERGE_PREFER_PARENT = "merge_prefer_parent" - def __init__(self, field: Selectable, mode: Mode = Mode.FULL_REPLACE): + def __init__(self, field: Selectable | str, mode: Mode | str = Mode.FULL_REPLACE): super().__init__() - self.field = field - self.mode = mode + self.field = Field(field) if isinstance(field, str) else field + self.mode = self.Mode[mode] if isinstance(mode, str) else mode def _pb_args(self): return [self.field._to_pb(), Value(string_value=self.mode.value)] @@ -221,7 +222,7 @@ def __init__(self, limit_or_options: int | SampleOptions): self.options: SampleOptions = options def _pb_args(self): - return [Value(integer_value=self.options.limit), Value(string_value=self.options.mode.value)] + return [Value(integer_value=self.options.n), Value(string_value=self.options.mode.value)] class Select(Stage): @@ -230,7 +231,7 @@ def __init__(self, *fields: str | Selectable): self.projections = [Field(f) if isinstance(f, str) else f for f in fields] def _pb_args(self) -> list[Value]: - return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.projections]}})]} + return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.projections]}})] class Sort(Stage): @@ -239,7 +240,7 @@ def __init__(self, *orders: "Ordering"): self.orders = list(orders) def _pb_args(self): - return [Value(map_value={"fields": {m[0]: m[1] for m in [o._to_map() for o in self.orders]}})] + return [o._to_pb() for o in self.orders] class Union(Stage): @@ -252,9 +253,9 @@ def _pb_args(self): class Unnest(Stage): - def __init__(self, field: Field, options: Optional["UnnestOptions"] = None): + def __init__(self, field: Field | str, options: Optional["UnnestOptions"] = None): super().__init__() - self.field = field + self.field: Field = Field(field) if isinstance(field, str) 
else field self.options = options def _pb_args(self): @@ -273,5 +274,5 @@ def __init__(self, condition: FilterCondition): self.condition = condition def _pb_args(self): - return [Value(map_value={"fields": {m[0]: m[1] for m in [self.condition._to_map()]}})] + return [self.condition._to_pb()] diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index eab5b1be2..c5ca0d840 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -130,9 +130,9 @@ tests: pipeline: - Collection: books - Aggregate: - - - ExprWithAlias: - - Count - - "count" + - ExprWithAlias: + - Count + - "count" results: - count: 10 - description: "testAggregates - avg, count, max" @@ -143,15 +143,15 @@ tests: - Field: genre - Constant: Science Fiction - Aggregate: - - - ExprWithAlias: - - Count - - "count" - - ExprWithAlias: - - Avg: rating - - "avg_rating" - - ExprWithAlias: - - Max: rating - - "max_rating" + - ExprWithAlias: + - Count + - "count" + - ExprWithAlias: + - Avg: rating + - "avg_rating" + - ExprWithAlias: + - Max: rating + - "max_rating" results: - count: 2 avg_rating: 4.4 @@ -164,8 +164,8 @@ tests: - Field: published - Constant: 1900 - Aggregate: - - [] - - - genre + accumulators: [] + groups: [genre] error: "Cannot groupBy without accumulators" - description: testDistinct pipeline: @@ -190,10 +190,11 @@ tests: - Field: published - Constant: 1984 - Aggregate: - - - ExprWithAlias: - - Avg: rating - - "avg_rating" - - - genre + accumulators: + - ExprWithAlias: + - Avg: rating + - "avg_rating" + groups: [genre] - Where: - Gt: - Field: avg_rating @@ -209,15 +210,15 @@ tests: pipeline: - Collection: books - Aggregate: - - - ExprWithAlias: - - Count - - "count" - - ExprWithAlias: - - Max: rating - - "max_rating" - - ExprWithAlias: - - Min: published - - "min_published" + - ExprWithAlias: + - Count + - "count" + - ExprWithAlias: + - Max: rating + - "max_rating" + - ExprWithAlias: + - Min: published + - "min_published" results: - count: 10 max_rating: 4.7 @@ -715,8 +716,8 @@ tests: - title - Sort: - Ordering: - - Field: title - - direction: ASCENDING + - title + - ASCENDING results: - rating: 4.3 title: Crime and Punishment @@ -914,7 +915,7 @@ tests: - Sample: - SampleOptions: - 60 - - PERCENTAGE + - percent results_num: 6 # Results will vary due to randomness - description: testUnion pipeline: diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 1da0558ee..8ce5d63aa 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -49,15 +49,16 @@ def loader(): def _apply_yaml_args(cls, yaml_args): if isinstance(yaml_args, dict): - # reject mapping arguments: use only positional arguments in yaml - # for cross-language simplicity - raise ValueError(f"found kwargs for class: {cls}") + return cls(**parse_expressions(yaml_args)) elif isinstance(yaml_args, list): # yaml has an array of arguments. 
Treat as args return cls(*parse_expressions(yaml_args)) - else: + elif yaml_args is not None: # yaml has a single argument return cls(parse_expressions(yaml_args)) + else: + # no arguments + return cls() def parse_pipeline(pipeline: list[dict[str, Any], str]): @@ -112,7 +113,7 @@ def parse_expressions(yaml_element: Any): ) def test_e2e_scenario(test_dict): pipeline = parse_pipeline(test_dict["pipeline"]) - print(pipeline) + print(pipeline._to_pb()) # before_ast = ast.parse(test_dict["before"]) # got_ast = before_ast From 69d4b76a32bcc8ed857780485422f404306626d9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 27 Feb 2025 16:41:42 -0800 Subject: [PATCH 042/131] added headers --- google/cloud/firestore_v1/pipeline.py | 14 ++++++++++++++ .../cloud/firestore_v1/pipeline_expressions.py | 17 ++++++++++++++++- google/cloud/firestore_v1/pipeline_stages.py | 14 ++++++++++++++ tests/system/test_pipeline_acceptance.py | 14 ++++++++++++++ 4 files changed, 58 insertions(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 73da00201..e19acde82 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -1,3 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations from typing import Any, Dict, Iterable, List, Optional from google.cloud.firestore_v1 import pipeline_stages as stages diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index f996f5604..21a63d953 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1,4 +1,19 @@ -from typing import Any, Iterable, List, Mapping, Union, Generic, TypeVar +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Any, Iterable, List, Mapping, Union, Generic, TypeVar, List, Dict, Tuple from enum import Enum from enum import auto import datetime diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index c3cb71c0f..8256bfd4a 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -1,3 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations from typing import Any, Dict, Iterable, List, Optional, Sequence from enum import Enum diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 8ce5d63aa..ab24977f2 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -1,3 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations import sys import os From 6d966a8e97db28a95231f3571b07dd470d3ae9d5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 27 Feb 2025 17:04:47 -0800 Subject: [PATCH 043/131] fixed some types --- google/cloud/firestore_v1/pipeline_expressions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 21a63d953..6a53d7e22 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -117,22 +117,22 @@ def lte(self, other: Any) -> "Lte": return Lte(self, self._cast_to_expr_or_convert_to_constant(other)) def in_(self, *others: Any) -> "In": - return In(self, ListOfExprs([self._cast_to_expr_or_convert_to_constant(o) for o in others])) + return In(self, [self._cast_to_expr_or_convert_to_constant(o) for o in others]) def not_in(self, *others: Any) -> "Not": return Not(self.in_(*others)) def array_concat(self, array: List[Any]) -> "ArrayConcat": - return ArrayConcat(self, ListOfExprs([self._cast_to_expr_or_convert_to_constant(o) for o in array])) + return ArrayConcat(self, [self._cast_to_expr_or_convert_to_constant(o) for o in array]) def array_contains(self, element: Any) -> "ArrayContains": return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) def array_contains_all(self, elements: List[Any]) -> "ArrayContainsAll": - return ArrayContainsAll(self, ListOfExprs([self._cast_to_expr_or_convert_to_constant(e) for e in elements])) + return ArrayContainsAll(self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements]) def array_contains_any(self, elements: List[Any]) -> "ArrayContainsAny": - return ArrayContainsAny(self, ListOfExprs([self._cast_to_expr_or_convert_to_constant(e) for e in elements])) + return ArrayContainsAny(self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements]) def array_length(self) -> "ArrayLength": return ArrayLength(self) @@ -246,10 +246,10 @@ def timestamp_sub(self, unit: Any, amount: Any) -> "TimestampSub": return TimestampSub(self, 
self._cast_to_expr_or_convert_to_constant(unit), self._cast_to_expr_or_convert_to_constant(amount)) def ascending(self) -> Ordering: - return Ordering.ascending(self) + return Ordering(self, Ordering.Direction.ASCENDING) def descending(self) -> Ordering: - return Ordering.descending(self) + return Ordering(self, Ordering.Direction.DESCENDING) def as_(self, alias: str) -> "ExprWithAlias": return ExprWithAlias(self, alias) From 33fee52d6c0d90dfa4a9df7cc2eb5dca19469b33 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Feb 2025 14:30:59 -0800 Subject: [PATCH 044/131] fixed expression typing --- .../firestore_v1/pipeline_expressions.py | 126 +++++++----------- tests/system/pipeline_e2e.yaml | 60 +++++---- tests/system/test_pipeline_acceptance.py | 6 +- 3 files changed, 85 insertions(+), 107 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 6a53d7e22..75f7e985e 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -13,15 +13,18 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Iterable, List, Mapping, Union, Generic, TypeVar, List, Dict, Tuple +from typing import Any, Iterable, List, Mapping, Union, Generic, TypeVar, List, Dict, Tuple, Sequence from enum import Enum from enum import auto import datetime from dataclasses import dataclass from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1._helpers import GeoPoint +from google.cloud.firestore_v1._helpers import encode_value + +CONSTANT_TYPE = TypeVar('CONSTANT_TYPE', str, int, float, bool, datetime.datetime, bytes, GeoPoint, Vector, list, Dict[str, Any], None) -CONSTANT_ELEMENTS = Union[str, int, float, bool, datetime.datetime, bytes, tuple[float, float], None] -CONSTANT_TYPES = Union[CONSTANT_ELEMENTS, list[Union[CONSTANT_ELEMENTS, list, dict]], dict[str, Union[CONSTANT_ELEMENTS, list, dict]]] class Ordering: @@ -77,61 +80,61 @@ def _to_pb(self) -> Value: def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": return o if isinstance(o, Expr) else Constant(o) - def add(self, other: Any) -> "Add": + def add(self, other: Expr | float) -> "Add": return Add(self, self._cast_to_expr_or_convert_to_constant(other)) - def subtract(self, other: Any) -> "Subtract": + def subtract(self, other: Expr | float) -> "Subtract": return Subtract(self, self._cast_to_expr_or_convert_to_constant(other)) - def multiply(self, other: Any) -> "Multiply": + def multiply(self, other: Expr | float) -> "Multiply": return Multiply(self, self._cast_to_expr_or_convert_to_constant(other)) - def divide(self, other: Any) -> "Divide": + def divide(self, other: Expr | float) -> "Divide": return Divide(self, self._cast_to_expr_or_convert_to_constant(other)) - def mod(self, other: Any) -> "Mod": + def mod(self, other: Expr | float) -> "Mod": return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) - def logical_max(self, other: Any) -> "LogicalMax": + def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": return LogicalMax(self, self._cast_to_expr_or_convert_to_constant(other)) - def logical_min(self, other: Any) -> "LogicalMin": + def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": return LogicalMin(self, self._cast_to_expr_or_convert_to_constant(other)) - def eq(self, other: Any) -> "Eq": + def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": return 
Eq(self, self._cast_to_expr_or_convert_to_constant(other)) - def neq(self, other: Any) -> "Neq": + def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": return Neq(self, self._cast_to_expr_or_convert_to_constant(other)) - def gt(self, other: Any) -> "Gt": + def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": return Gt(self, self._cast_to_expr_or_convert_to_constant(other)) - def gte(self, other: Any) -> "Gte": + def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": return Gte(self, self._cast_to_expr_or_convert_to_constant(other)) - def lt(self, other: Any) -> "Lt": + def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": return Lt(self, self._cast_to_expr_or_convert_to_constant(other)) - def lte(self, other: Any) -> "Lte": + def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": return Lte(self, self._cast_to_expr_or_convert_to_constant(other)) - def in_(self, *others: Any) -> "In": + def in_any(self, *others: Expr | CONSTANT_TYPE) -> "In": return In(self, [self._cast_to_expr_or_convert_to_constant(o) for o in others]) - def not_in(self, *others: Any) -> "Not": - return Not(self.in_(*others)) + def not_in_any(self, *others: Expr | CONSTANT_TYPE) -> "Not": + return Not(self.in_any(*others)) - def array_concat(self, array: List[Any]) -> "ArrayConcat": + def array_concat(self, array: List[Expr | CONSTANT_TYPE]) -> "ArrayConcat": return ArrayConcat(self, [self._cast_to_expr_or_convert_to_constant(o) for o in array]) - def array_contains(self, element: Any) -> "ArrayContains": + def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) - def array_contains_all(self, elements: List[Any]) -> "ArrayContainsAll": + def array_contains_all(self, elements: List[Expr | CONSTANT_TYPE]) -> "ArrayContainsAll": return ArrayContainsAll(self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements]) - def array_contains_any(self, elements: List[Any]) -> "ArrayContainsAny": + def array_contains_any(self, elements: List[Expr | CONSTANT_TYPE]) -> "ArrayContainsAny": return ArrayContainsAny(self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements]) def array_length(self) -> "ArrayLength": @@ -167,25 +170,25 @@ def char_length(self) -> "CharLength": def byte_length(self) -> "ByteLength": return ByteLength(self) - def like(self, pattern: Any) -> "Like": + def like(self, pattern: Expr | str) -> "Like": return Like(self, self._cast_to_expr_or_convert_to_constant(pattern)) - def regex_contains(self, regex: Any) -> "RegexContains": + def regex_contains(self, regex: Expr | str) -> "RegexContains": return RegexContains(self, self._cast_to_expr_or_convert_to_constant(regex)) - def regex_matches(self, regex: Any) -> "RegexMatch": + def regex_matches(self, regex: Expr | str) -> "RegexMatch": return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) - def str_contains(self, substring: Any) -> "StrContains": + def str_contains(self, substring: Expr | str) -> "StrContains": return StrContains(self, self._cast_to_expr_or_convert_to_constant(substring)) - def starts_with(self, prefix: Any) -> "StartsWith": + def starts_with(self, prefix: Expr | str) -> "StartsWith": return StartsWith(self, self._cast_to_expr_or_convert_to_constant(prefix)) - def ends_with(self, postfix: Any) -> "EndsWith": + def ends_with(self, postfix: Expr | str) -> "EndsWith": return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) - def str_concat(self, *elements: Any) -> "StrConcat": + def str_concat(self, 
*elements: Expr | CONSTANT_TYPE) -> "StrConcat": return StrConcat(*[self._cast_to_expr_or_convert_to_constant(el) for el in elements]) def to_lower(self) -> "ToLower": @@ -200,22 +203,22 @@ def trim(self) -> "Trim": def reverse(self) -> "Reverse": return Reverse(self) - def replace_first(self, find: Any, replace: Any) -> "ReplaceFirst": + def replace_first(self, find: Expr | str, replace: Expr | str) -> "ReplaceFirst": return ReplaceFirst(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) - def replace_all(self, find: Any, replace: Any) -> "ReplaceAll": + def replace_all(self, find: Expr | str, replace: Expr | str) -> "ReplaceAll": return ReplaceAll(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) def map_get(self, key: str) -> "MapGet": return MapGet(self, key) - def cosine_distance(self, other: Any) -> "CosineDistance": + def cosine_distance(self, other: Expr | list[float] | Vector) -> "CosineDistance": return CosineDistance(self, self._cast_to_expr_or_convert_to_constant(other)) - def euclidean_distance(self, other: Any) -> "EuclideanDistance": + def euclidean_distance(self, other: Expr | list[float] | Vector) -> "EuclideanDistance": return EuclideanDistance(self, self._cast_to_expr_or_convert_to_constant(other)) - def dot_product(self, other: Any) -> "DotProduct": + def dot_product(self, other: Expr | list[float] | Vector) -> "DotProduct": return DotProduct(self, self._cast_to_expr_or_convert_to_constant(other)) def vector_length(self) -> "VectorLength": @@ -239,10 +242,10 @@ def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": return UnixSecondsToTimestamp(self) - def timestamp_add(self, unit: Any, amount: Any) -> "TimestampAdd": + def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd": return TimestampAdd(self, self._cast_to_expr_or_convert_to_constant(unit), self._cast_to_expr_or_convert_to_constant(amount)) - def timestamp_sub(self, unit: Any, amount: Any) -> "TimestampSub": + def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub": return TimestampSub(self, self._cast_to_expr_or_convert_to_constant(unit), self._cast_to_expr_or_convert_to_constant(amount)) def ascending(self) -> Ordering: @@ -254,46 +257,23 @@ def descending(self) -> Ordering: def as_(self, alias: str) -> "ExprWithAlias": return ExprWithAlias(self, alias) - -class Constant(Expr): - def __init__(self, value: CONSTANT_TYPES=None): - self.value = value +class Constant(Expr, Generic[CONSTANT_TYPE]): + def __init__(self, value: CONSTANT_TYPE): + self.value: CONSTANT_TYPE = value @staticmethod - def of(value:CONSTANT_TYPES): + def of(value:CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: return Constant(value) def __repr__(self): return f"Constant.of({self.value!r})" - def _to_pb(self): - if self.value is None: - return Value(null_value=0) - elif isinstance(self.value, bool): - return Value(boolean_value=self.value) - elif isinstance(self.value, int): - return Value(integer_value=self.value) - elif isinstance(self.value, float): - return Value(double_value=self.value) - elif isinstance(self.value, datetime.datetime): - return Value(timestamp_value=self.value.timestamp()) - elif isinstance(self.value, str): - return Value(string_value=self.value) - elif isinstance(self.value, bytes): - return Value(bytes_value=self.value) - elif isinstance(self.value, tuple) and len(self.value) == 2 and 
isinstance(self.value[0], float) and isinstance(self.value[1], float): - return Value(geo_point_value=self.value) - elif isinstance(self.value, list): - return Value(array_value={"values":[Constant(v)._to_pb() for v in self.value]}) - elif isinstance(self.value, dict): - return Value(map_value={"fields": {k: Constant(v)._to_pb() for k, v in self.value.items()}}) - else: - raise ValueError(f"Unsupported type: {type(self.value)}") - + def _to_pb(self) -> Value: + return encode_value(self.value) class ListOfExprs(Expr): def __init__(self, exprs: List[Expr]): - self.exprs: list[Expr] = [Field.of(e) if isinstance(e, str) else e for e in exprs] + self.exprs: list[Expr] = exprs def _to_pb(self): return Value(array_value={"values": [e._to_pb() for e in self.exprs]}) @@ -302,9 +282,9 @@ def _to_pb(self): class Function(Expr): """A type of Expression that takes in inputs and gives outputs.""" - def __init__(self, name: str, params: List[Expr]): + def __init__(self, name: str, params: Sequence[Expr]): self.name = name - self.params = [Field.of(p) if isinstance(p, str) else p for p in params] + self.params = list(params) def __repr__(self): return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" @@ -323,15 +303,11 @@ def __init__(self, left: Expr, right: Expr): class DotProduct(Function): def __init__(self, vector1: Expr, vector2: Expr): - vector1 = Constant(vector1) if isinstance(vector1, list) else vector1 - vector2 = Constant(vector2) if isinstance(vector2, list) else vector2 super().__init__("dot_product", [vector1, vector2]) class EuclideanDistance(Function): def __init__(self, vector1: Expr, vector2: Expr): - vector1 = Constant(vector1) if isinstance(vector1, list) else vector1 - vector2 = Constant(vector2) if isinstance(vector2, list) else vector2 super().__init__("euclidean_distance", [vector1, vector2]) @@ -502,8 +478,6 @@ def __init__(self, value: Expr): class CosineDistance(Function): def __init__(self, vector1: Expr, vector2: Expr): - vector1 = Constant(vector1) if isinstance(vector1, list) else vector1 - vector2 = Constant(vector2) if isinstance(vector2, list) else vector2 super().__init__("cosine_distance", [vector1, vector2]) @@ -532,7 +506,7 @@ def __init__(self, value: Expr, distinct: bool=False): class Count(Accumulator): - def __init__(self, value: Expr = None): + def __init__(self, value: Expr | None = None): super().__init__("count", [value] if value else []) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index c5ca0d840..d88461add 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -147,10 +147,12 @@ tests: - Count - "count" - ExprWithAlias: - - Avg: rating + - Avg: + - Field: rating - "avg_rating" - ExprWithAlias: - - Max: rating + - Max: + - Field: rating - "max_rating" results: - count: 2 @@ -190,11 +192,12 @@ tests: - Field: published - Constant: 1984 - Aggregate: - accumulators: - - ExprWithAlias: - - Avg: rating - - "avg_rating" - groups: [genre] + accumulators: + - ExprWithAlias: + - Avg: + - Field: rating + - "avg_rating" + groups: [genre] - Where: - Gt: - Field: avg_rating @@ -214,10 +217,12 @@ tests: - Count - "count" - ExprWithAlias: - - Max: rating + - Max: + - Field: rating - "max_rating" - ExprWithAlias: - - Min: published + - Min: + - Field: published - "min_published" results: - count: 10 @@ -369,8 +374,8 @@ tests: - Collection: books - Where: - ArrayContains: - - tags - - comedy + - Constant: tags + - Constant: comedy results: - title: The Hitchhiker's Guide to the Galaxy 
author: Douglas Adams @@ -389,8 +394,9 @@ tests: - Collection: books - Where: - ArrayContainsAny: - - tags - - ["comedy", "classic"] + - Field: tags + - - Constant: comedy + - Constant: classic - Select: - title results: @@ -402,7 +408,8 @@ tests: - Where: - ArrayContainsAll: - Field: tags - - ["adventure", "magic"] + - - Constant: adventure + - Constant: magic - Select: - title results: @@ -437,7 +444,8 @@ tests: - ExprWithAlias: - ArrayConcat: - Field: tags - - ["newTag1", "newTag2"] + - - Constant: newTag1 + - Constant: newTag2 - "modifiedTags" - Limit: 1 results: @@ -533,8 +541,8 @@ tests: - ExprWithAlias: - ReplaceFirst: - Field: title - - The - - A + - Constant: The + - Constant: A - "replaced_title" - Where: - Eq: @@ -549,8 +557,8 @@ tests: - ExprWithAlias: - ReplaceAll: - Field: title - - " " - - "_" + - Constant: " " + - Constant: "_" - "replaced_title" - Where: - Eq: @@ -835,18 +843,18 @@ tests: - Select: - ExprWithAlias: - CosineDistance: - - [0.1, 0.1] - - [0.5, 0.8] + - Constant: [[0.1, 0.1]] + - Constant: [[0.5, 0.8]] - "cosineDistance" - ExprWithAlias: - DotProduct: - - [0.1, 0.1] - - [0.5, 0.8] + - Constant: [[0.1, 0.1]] + - Constant: [[0.5, 0.8]] - "dotProductDistance" - ExprWithAlias: - EuclideanDistance: - - [0.1, 0.1] - - [0.5, 0.8] + - Constant: [[0.1, 0.1]] + - Constant: [[0.5, 0.8]] - "euclideanDistance" - Limit: 1 results: @@ -890,7 +898,7 @@ tests: - Where: - Eq: - Field: title - - "The Hitchhiker's Guide to the Galaxy" + - Constant: "The Hitchhiker's Guide to the Galaxy" - Replace: awards results: - title: The Hitchhiker's Guide to the Galaxy diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index ab24977f2..2179e9f2b 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -67,13 +67,9 @@ def _apply_yaml_args(cls, yaml_args): elif isinstance(yaml_args, list): # yaml has an array of arguments. Treat as args return cls(*parse_expressions(yaml_args)) - elif yaml_args is not None: + else: # yaml has a single argument return cls(parse_expressions(yaml_args)) - else: - # no arguments - return cls() - def parse_pipeline(pipeline: list[dict[str, Any], str]): """ From cb8539df529c246a3c03ac277d9009bd4fc880b5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Feb 2025 15:17:32 -0800 Subject: [PATCH 045/131] added abstract to expr --- google/cloud/firestore_v1/pipeline_expressions.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 75f7e985e..7d4373891 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -14,6 +14,8 @@ from __future__ import annotations from typing import Any, Iterable, List, Mapping, Union, Generic, TypeVar, List, Dict, Tuple, Sequence +from abc import ABC +from abc import abstractmethod from enum import Enum from enum import auto import datetime @@ -65,7 +67,7 @@ class Mode(Enum): def __post_init__(self): self.mode = SampleOptions.Mode(self.mode) if isinstance(self.mode, str) else self.mode -class Expr: +class Expr(ABC): """Represents an expression that can be evaluated to a value within the execution of a pipeline. 
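    Concrete subclasses implement `_to_pb` to encode themselves as a `Value`
    proto, so a chained expression such as
    `Field.of("price").multiply(1.1).as_("price_with_tax")` builds a proto
    tree without evaluating anything client-side.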
""" @@ -73,6 +75,7 @@ class Expr: def __repr__(self): return f"{self.__class__.__name__}()" + @abstractmethod def _to_pb(self) -> Value: raise NotImplementedError @@ -518,6 +521,7 @@ def __init__(self, value: Expr, distinct: bool=False): class Selectable(Expr): """Points at something in the database?""" + @abstractmethod def _to_map(self): raise NotImplementedError From 380dce967c3486bc14706a1ffa64336ef68d505c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Feb 2025 15:38:02 -0800 Subject: [PATCH 046/131] fixed types in pipeline_stages --- google/cloud/firestore_v1/pipeline_stages.py | 26 +++++++++++--------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 8256bfd4a..c35ddc3b9 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -19,6 +19,9 @@ from google.cloud.firestore_v1.types.document import Pipeline from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.document import DocumentReference +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.pipeline_expressions import ( Accumulator, Expr, @@ -27,6 +30,7 @@ FilterCondition, Selectable, SampleOptions, + Ordering ) class FindNearestOptions: @@ -49,11 +53,11 @@ def __init__(self, custom_name: Optional[str] = None): self.name = custom_name or type(self).__name__.lower() def _to_pb(self) -> Pipeline.Stage: - return Pipeline.Stage(name=self.name, args=[*self._pb_args()], options=self._pb_options()) + return Pipeline.Stage(name=self.name, args=self._pb_args(), options=self._pb_options()) - def _pb_args(self) -> tuple[Value, ...]: + def _pb_args(self) -> list[Value]: """Return Ordered list of arguments the given stage expects""" - return () + return [] def _pb_options(self) -> dict[str, Value]: """Return optional named arguments that certain functions may support.""" @@ -69,7 +73,7 @@ def __init__(self, *fields: Selectable): super().__init__("add_fields") self.fields = list(fields) - def _pb_args(self) -> list[Value]: + def _pb_args(self): return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}})] class Aggregate(Stage): @@ -83,7 +87,7 @@ def __init__( self.groups: list[Selectable] = [Field(f) if isinstance(f, str) else f for f in groups] self.accumulators: list[ExprWithAlias[Accumulator]] = [*accumulators, *extra_accumulators] - def _pb_args(self) -> list[Value]: + def _pb_args(self): return [ Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.accumulators]}}), Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.groups]}}) @@ -132,30 +136,30 @@ def _pb_args(self) -> list[Value]: class Documents(Stage): - def __init__(self, *documents: str): + def __init__(self, *paths: str): super().__init__() - self.documents = list(documents) + self.paths = paths @staticmethod def of(*documents: "DocumentReference") -> "Documents": doc_paths = ["/" + doc.path for doc in documents] - return Documents(doc_paths) + return Documents(*doc_paths) def _pb_args(self): - return [Value(list_value={"values": [Value(string_value=doc) for doc in self.documents]})] + return [Value(list_value={"values": [Value(string_value=path) for path in self.paths]})] class FindNearest(Stage): def __init__( self, property: Expr, - vector: List[float], + vector: Sequence[float] | Vector, 
distance_measure: "DistanceMeasure", options: Optional["FindNearestOptions"] = None, ): super().__init__("find_nearest") self.property = property - self.vector = vector + self.vector: Vector = vector if isinstance(vector, Vector) else Vector(vector) self.distance_measure = distance_measure self.options = options or FindNearestOptions() From 831886f7d7605361add9c97a6d02600a483dbecf Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Feb 2025 16:29:10 -0800 Subject: [PATCH 047/131] align typing --- google/cloud/firestore_v1/pipeline.py | 37 +++++++++++--------- google/cloud/firestore_v1/pipeline_stages.py | 23 ++++++------ 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index e19acde82..106963c0b 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -13,9 +13,11 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Sequence from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.pipeline_expressions import ( Accumulator, Expr, @@ -43,16 +45,16 @@ def __repr__(self): def _to_pb(self) -> Pipeline_pb: return Pipeline_pb(stages=[s._to_pb() for s in self.stages]) - def add_fields(self, fields: Dict[str, Expr]) -> Pipeline: - self.stages.append(stages.AddFields(fields)) + def add_fields(self, *fields: Selectable) -> Pipeline: + self.stages.append(stages.AddFields(*fields)) return self - def remove_fields(self, fields: List[Field]) -> Pipeline: - self.stages.append(stages.RemoveFields(fields)) + def remove_fields(self, *fields: Field | str) -> Pipeline: + self.stages.append(stages.RemoveFields(*fields)) return self - def select(self, projections: Dict[str, Expr]) -> Pipeline: - self.stages.append(stages.Select(projections)) + def select(self, *selections: str | Selectable) -> Pipeline: + self.stages.append(stages.Select(*selections)) return self def where(self, condition: FilterCondition) -> Pipeline: @@ -62,16 +64,16 @@ def where(self, condition: FilterCondition) -> Pipeline: def find_nearest( self, field: str | Expr, - vector: "Vector", - distance_measure: "FindNearest.DistanceMeasure", + vector: Sequence[float] | "Vector", + distance_measure: "DistanceMeasure", limit: int | None, options: Optional[stages.FindNearestOptions] = None, ) -> Pipeline: self.stages.append(stages.FindNearest(field, vector, distance_measure, options)) return self - def sort(self, orders: List[stages.Ordering]) -> Pipeline: - self.stages.append(stages.Sort(orders)) + def sort(self, *orders: stages.Ordering) -> Pipeline: + self.stages.append(stages.Sort(*orders)) return self def replace( @@ -98,8 +100,8 @@ def unnest( self.stages.append(stages.Unnest(field_name, options)) return self - def generic_stage(self, name: str, params: List[Any]) -> Pipeline: - self.stages.append(stages.GenericStage(name, params)) + def generic_stage(self, name: str, *params: Expr) -> Pipeline: + self.stages.append(stages.GenericStage(name, *params)) return self def offset(self, offset: int) -> Pipeline: @@ -112,13 +114,14 @@ def limit(self, limit: int) -> Pipeline: def aggregate( self, - accumulators: Optional[Dict[str, Accumulator]] = 
None, + *accumulators: ExprWithAlias[Accumulator], + groups: Sequence[str | Selectable] = (), ) -> Pipeline: - self.stages.append(stages.Aggregate(accumulators=accumulators)) + self.stages.append(stages.Aggregate(*accumulators, groups=groups)) return self - def distinct(self, fields: Dict[str, Expr]) -> Pipeline: - self.stages.append(stages.Distinct(fields)) + def distinct(self, *fields: str | Selectable) -> Pipeline: + self.stages.append(stages.Distinct(*fields)) return self def execute(self) -> list["PipelineResult"]: diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index c35ddc3b9..929996fba 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -17,11 +17,12 @@ from enum import Enum from enum import auto -from google.cloud.firestore_v1.types.document import Pipeline +from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure +from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1.pipeline_expressions import ( Accumulator, Expr, @@ -52,8 +53,8 @@ class Stage: def __init__(self, custom_name: Optional[str] = None): self.name = custom_name or type(self).__name__.lower() - def _to_pb(self) -> Pipeline.Stage: - return Pipeline.Stage(name=self.name, args=self._pb_args(), options=self._pb_options()) + def _to_pb(self) -> Pipeline_pb.Stage: + return Pipeline_pb.Stage(name=self.name, args=self._pb_args(), options=self._pb_options()) def _pb_args(self) -> list[Value]: """Return Ordered list of arguments the given stage expects""" @@ -152,20 +153,20 @@ def _pb_args(self): class FindNearest(Stage): def __init__( self, - property: Expr, + field: str | Expr, vector: Sequence[float] | Vector, distance_measure: "DistanceMeasure", options: Optional["FindNearestOptions"] = None, ): super().__init__("find_nearest") - self.property = property + self.field: Expr = Field(field) if isinstance(field, str) else field self.vector: Vector = vector if isinstance(vector, Vector) else Vector(vector) self.distance_measure = distance_measure self.options = options or FindNearestOptions() def _pb_args(self): return [ - self.property._to_pb(), + self.field._to_pb(), Value(array_value={"values": self.vector}), Value(string_value=self.distance_measure.value), ] @@ -179,9 +180,9 @@ def _pb_options(self) -> dict[str, Value]: return options class GenericStage(Stage): - def __init__(self, name: str, *params: Value): + def __init__(self, name: str, *params: Expr | Value): super().__init__(name) - self.params = list(params) + self.params: list[Value] = [p._to_pb() if isinstance(p, Expr) else p for p in params] def _pb_args(self): return self.params @@ -244,9 +245,9 @@ def _pb_args(self): class Select(Stage): - def __init__(self, *fields: str | Selectable): + def __init__(self, *selections: str | Selectable): super().__init__() - self.projections = [Field(f) if isinstance(f, str) else f for f in fields] + self.projections = [Field(s) if isinstance(s, str) else s for s in selections] def _pb_args(self) -> list[Value]: return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.projections]}})] @@ -262,7 +263,7 @@ def _pb_args(self): class Union(Stage): - def __init__(self, other: "Pipeline"): + def 
__init__(self, other: Pipeline): super().__init__() self.other = other From 79df2057d2f30e0d423f33203ef33d66bb7a4a22 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Feb 2025 17:23:06 -0800 Subject: [PATCH 048/131] added stubs for execute_pipeline --- google/cloud/firestore_v1/base_collection.py | 7 +++++ google/cloud/firestore_v1/pipeline.py | 33 ++++++++++++++------ google/cloud/firestore_v1/pipeline_stages.py | 7 +++-- tests/system/test_pipeline_acceptance.py | 7 +++-- 4 files changed, 39 insertions(+), 15 deletions(-) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 1ac1ba318..dfcba7040 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -35,6 +35,8 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.base_query import QueryType +from google.cloud.firestore_v1.pipeline import Pipeline +from google.cloud.firestore_v1.pipeline_stages import Collection as CollectionStage if TYPE_CHECKING: # pragma: NO COVER # Types needed only for Type Hints @@ -590,6 +592,11 @@ def find_nearest( distance_threshold=distance_threshold, ) + def pipeline(self) -> Pipeline: + path_str = "/".join(self._path) + # TODO: add other query fields + return Pipeline(self._client, CollectionStage(path_str)) + def _auto_id() -> str: """Generate a "random" automatically generated ID. diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 106963c0b..a4ed69cd1 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -13,9 +13,11 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Dict, Iterable, List, Optional, Sequence +from typing import AsyncIterable, Any, Dict, Iterable, List, Optional, Sequence from google.cloud.firestore_v1 import pipeline_stages as stages -from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb +from google.cloud.firestore_v1.types.pipeline import StructuredPipeline as StructuredPipeline_pb +from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest +from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.pipeline_expressions import ( @@ -30,7 +32,8 @@ class Pipeline: - def __init__(self, *stages: stages.Stage): + def __init__(self, client, *stages: stages.Stage): + self._client = client self.stages = list(stages) def __repr__(self): @@ -42,8 +45,8 @@ def __repr__(self): stages_str = ",\n ".join([repr(s) for s in self.stages]) return f"Pipeline(\n {stages_str}\n)" - def _to_pb(self) -> Pipeline_pb: - return Pipeline_pb(stages=[s._to_pb() for s in self.stages]) + def _to_pb(self) -> StructuredPipeline_pb: + return StructuredPipeline_pb(pipeline={"stages":[s._to_pb() for s in self.stages]}) def add_fields(self, *fields: Selectable) -> Pipeline: self.stages.append(stages.AddFields(*fields)) @@ -124,8 +127,18 @@ def distinct(self, *fields: str | Selectable) -> Pipeline: self.stages.append(stages.Distinct(*fields)) return self - def execute(self) -> list["PipelineResult"]: - return [] - - async def execute_async(self) -> List["PipelineResult"]: - return [] + def execute(self) -> Iterable["ExecutePipelineResponse"]: + breakpoint() + request = ExecutePipelineRequest( + database=self._client._database, + 
structured_pipeline=self._to_pb(), + transaction=b"test", + ) + results = self._client._firestore_api.execute_pipeline(request) + print(request) + return results + + async def execute_async(self) -> AsyncIterable["ExecutePipelineResponse"]: + from google.cloud.firestore_v1.async_client import AsyncClient + if not isinstance(self._client, AsyncClient): + raise TypeError("execute_async requires AsyncClient") diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 929996fba..32cf5c382 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Dict, Iterable, List, Optional, Sequence +from typing import Any, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING from enum import Enum from enum import auto @@ -22,7 +22,6 @@ from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure -from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1.pipeline_expressions import ( Accumulator, Expr, @@ -34,6 +33,10 @@ Ordering ) +if TYPE_CHECKING: + from google.cloud.firestore_v1.pipeline import Pipeline + + class FindNearestOptions: def __init__( self, diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 2179e9f2b..900e45f46 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -71,7 +71,7 @@ def _apply_yaml_args(cls, yaml_args): # yaml has a single argument return cls(parse_expressions(yaml_args)) -def parse_pipeline(pipeline: list[dict[str, Any], str]): +def parse_pipeline(client, pipeline: list[dict[str, Any], str]): """ parse a yaml list of pipeline stages into firestore.pipeline_stages.Stage classes """ @@ -89,7 +89,7 @@ def parse_pipeline(pipeline: list[dict[str, Any], str]): # yaml has no arguments stage_obj = stage_cls() result_list.append(stage_obj) - return Pipeline(*result_list) + return Pipeline(client, *result_list) def _is_expr_string(yaml_str): return isinstance(yaml_str, str) and \ @@ -122,7 +122,8 @@ def parse_expressions(yaml_element: Any): "test_dict", loader(), ids=lambda x: f"{x.get('description', '')}" ) def test_e2e_scenario(test_dict): - pipeline = parse_pipeline(test_dict["pipeline"]) + client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_TEST_DB) + pipeline = parse_pipeline(client, test_dict["pipeline"]) print(pipeline._to_pb()) # before_ast = ast.parse(test_dict["before"]) From e760c443deefcfc984b78a84a55530aeac9c3b6f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 25 Mar 2025 16:02:18 -0700 Subject: [PATCH 049/131] fixed union stage proto --- google/cloud/firestore_v1/pipeline_stages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 32cf5c382..0ebf6aa8a 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -271,7 +271,7 @@ def __init__(self, other: Pipeline): self.other = other def _pb_args(self): - return [Value(pipeline_value=self.other._to_pb())] + return [Value(pipeline_value=self.other._to_pb().pipeline)] class Unnest(Stage): From ea1e2ba42136330b65a14b42792db42b53b70ad6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche 
Date: Tue, 25 Mar 2025 16:03:00 -0700 Subject: [PATCH 050/131] propagate client in test parsers --- tests/system/test_pipeline_acceptance.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 900e45f46..1f5f48cab 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -61,15 +61,15 @@ def loader(): document_ref.delete() -def _apply_yaml_args(cls, yaml_args): +def _apply_yaml_args(cls, client, yaml_args): if isinstance(yaml_args, dict): - return cls(**parse_expressions(yaml_args)) + return cls(**parse_expressions(client, yaml_args)) elif isinstance(yaml_args, list): # yaml has an array of arguments. Treat as args - return cls(*parse_expressions(yaml_args)) + return cls(*parse_expressions(client, yaml_args)) else: # yaml has a single argument - return cls(parse_expressions(yaml_args)) + return cls(parse_expressions(client, yaml_args)) def parse_pipeline(client, pipeline: list[dict[str, Any], str]): """ @@ -84,7 +84,7 @@ def parse_pipeline(client, pipeline: list[dict[str, Any], str]): # find arguments if given if isinstance(stage, dict): stage_yaml_args = stage[stage_name] - stage_obj = _apply_yaml_args(stage_cls, stage_yaml_args) + stage_obj = _apply_yaml_args(stage_cls, client, stage_yaml_args) else: # yaml has no arguments stage_obj = stage_cls() @@ -96,22 +96,23 @@ def _is_expr_string(yaml_str): yaml_str[0].isupper() and \ hasattr(pipeline_expressions, yaml_str) -def parse_expressions(yaml_element: Any): +def parse_expressions(client, yaml_element: Any): if isinstance(yaml_element, list): - return [parse_expressions(v) for v in yaml_element] + return [parse_expressions(client, v) for v in yaml_element] elif isinstance(yaml_element, dict): if len(yaml_element) == 1 and _is_expr_string(list(yaml_element)[0]): # build pipeline expressions if possible cls_str = list(yaml_element)[0] cls = getattr(pipeline_expressions, cls_str) yaml_args = yaml_element[cls_str] - return _apply_yaml_args(cls, yaml_args) + return _apply_yaml_args(cls, client, yaml_args) elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline": # find Pipeline objects for Union expressions - return parse_pipeline(yaml_element["Pipeline"]) + other_ppl = yaml_element["Pipeline"] + return parse_pipeline(client, other_ppl) else: # otherwise, return dict - return {parse_expressions(k): parse_expressions(v) for k,v in yaml_element.items()} + return {parse_expressions(client, k): parse_expressions(client, v) for k,v in yaml_element.items()} elif _is_expr_string(yaml_element): return getattr(pipeline_expressions, yaml_element)() else: From b6b082dc2aba88b306c5fa413d5139021e82c5ac Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 25 Mar 2025 16:05:25 -0700 Subject: [PATCH 051/131] enable mypy for pipeline code --- google/cloud/firestore_v1/pipeline.py | 2 ++ noxfile.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index a4ed69cd1..234d54a02 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -142,3 +142,5 @@ async def execute_async(self) -> AsyncIterable["ExecutePipelineResponse"]: from google.cloud.firestore_v1.async_client import AsyncClient if not isinstance(self._client, AsyncClient): raise TypeError("execute_async requires AsyncClient") + # TODO + raise NotImplementedError diff --git a/noxfile.py 
b/noxfile.py index 1a5625a30..503b049ac 100644 --- a/noxfile.py +++ b/noxfile.py @@ -150,7 +150,11 @@ def mypy(session): session.install("-e", ".") session.install("mypy", "types-setuptools") # TODO: also verify types on tests, all of google package - session.run("mypy", "-p", "google.cloud.firestore", "--no-incremental") + session.run("mypy", + "-p", "google.cloud.firestore_v1.pipeline_expressions", + "-p", "google.cloud.firestore_v1.pipeline_stages", + "-p", "google.cloud.firestore_v1.pipeline", + "--no-incremental") @nox.session(python=DEFAULT_PYTHON_VERSION) From 658964c191c995090c92effbd52eb7bb2f93f28b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 25 Mar 2025 16:18:19 -0700 Subject: [PATCH 052/131] added query.pipeline --- google/cloud/firestore_v1/base_collection.py | 5 -- google/cloud/firestore_v1/base_query.py | 65 ++++++++++++++++++++ 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index dfcba7040..506234507 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -592,11 +592,6 @@ def find_nearest( distance_threshold=distance_threshold, ) - def pipeline(self) -> Pipeline: - path_str = "/".join(self._path) - # TODO: add other query fields - return Pipeline(self._client, CollectionStage(path_str)) - def _auto_id() -> str: """Generate a "random" automatically generated ID. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 3a473094a..01dda6423 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -57,6 +57,7 @@ query, ) from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.pipeline import Pipeline if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -1103,6 +1104,70 @@ def recursive(self: QueryType) -> QueryType: return copied + def pipeline(self) -> Pipeline: + # TODO: add extensive tests + ppl = Pipeline(self._client) + if self._all_descendants: + ppl = ppl.collection_group(self._parent.id) + else: + ppl = ppl.collection("/".join(self._parent._path)) + + # Filters + for filter in self._field_filters: + ppl = ppl.where(filter) + + # Projections + if self._projection and self._projection.fields: + ppl = ppl.select( + *[ + field.field_path + for field in self._projection.fields + ] + ) + + # Orders + orders = self._normalize_orders() + if orders: + # Add exists filters to match Query's implicit orderby semantics. 
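+            # Classic queries implicitly exclude documents that are missing
+            # a field referenced in an order-by clause; pipelines keep such
+            # documents, so explicit exists() filters are added to preserve
+            # the original query semantics.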
+ exists = [] + for order in orders: + # skip __name__ + if order.field.field_path == "__name__": + continue + exists.append(field_path_module.FieldPath(order.field.field_path).exists()) + + if len(exists) > 1: + ppl = ppl.where(field_path_module.And(*exists)) + elif len(exists) == 1: + ppl = ppl.where(exists[0]) + + orderings = [] + for order in orders: + direction = ( + "asc" if order.direction == StructuredQuery.Direction.ASCENDING else "desc" + ) + orderings.append( + getattr(field_path_module.FieldPath(order.field.field_path), direction)() + ) + ppl = ppl.sort(*orderings) + + # Cursors, Limit and Offset + if self._start_at or self._end_at or self._limit_to_last: + ppl = ppl.paginate( + start_at=self._start_at, + end_at=self._end_at, + limit=self._limit, + limit_to_last=self._limit_to_last, + offset=self._offset, + ) + else: # Limit & Offset without cursors + if self._offset: + ppl = ppl.offset(self._offset) + if self._limit: + ppl = ppl.limit(self._limit) + + return ppl + def _comparator(self, doc1, doc2) -> int: _orders = self._orders From 7e12b74d839813beb266bb6ace80a46dd8050bb1 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 25 Mar 2025 16:42:26 -0700 Subject: [PATCH 053/131] got queries to execute --- google/cloud/firestore_v1/pipeline.py | 8 ++++---- tests/system/test_pipeline_acceptance.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 234d54a02..259e30aa5 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import datetime from typing import AsyncIterable, Any, Dict, Iterable, List, Optional, Sequence from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.types.pipeline import StructuredPipeline as StructuredPipeline_pb @@ -128,14 +129,13 @@ def distinct(self, *fields: str | Selectable) -> Pipeline: return self def execute(self) -> Iterable["ExecutePipelineResponse"]: - breakpoint() + database_name = f"projects/{self._client.project}/databases/{self._client._database}" request = ExecutePipelineRequest( - database=self._client._database, + database=database_name, structured_pipeline=self._to_pb(), - transaction=b"test", + read_time=datetime.datetime.now(), ) results = self._client._firestore_api.execute_pipeline(request) - print(request) return results async def execute_async(self) -> AsyncIterable["ExecutePipelineResponse"]: diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 1f5f48cab..c0a02abfb 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -126,6 +126,7 @@ def test_e2e_scenario(test_dict): client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_TEST_DB) pipeline = parse_pipeline(client, test_dict["pipeline"]) print(pipeline._to_pb()) + pipeline.execute() # before_ast = ast.parse(test_dict["before"]) # got_ast = before_ast From e463418280065fcc1cd9a0be814af95e48d1ff3b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 25 Mar 2025 16:59:51 -0700 Subject: [PATCH 054/131] fixed some encodings --- google/cloud/firestore_v1/pipeline_expressions.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 7d4373891..fc3604dbf 100644 --- 
a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -31,8 +31,8 @@ class Ordering: class Direction(Enum): - ASCENDING = auto() - DESCENDING = auto() + ASCENDING = "ascending" + DESCENDING = "descending" def __init__(self, expr, order_dir: Direction | str): self.expr = expr if isinstance(expr, Expr) else Field.of(expr) @@ -49,7 +49,7 @@ def _to_pb(self) -> Value: return Value( map_value={"fields": { - "direction": Value(string_value=self.order_dir.name), + "direction": Value(string_value=self.order_dir.value), "expression": self.expr._to_pb() } } @@ -316,12 +316,12 @@ def __init__(self, vector1: Expr, vector2: Expr): class LogicalMax(Function): def __init__(self, left: Expr, right: Expr): - super().__init__("logical_max", [left, right]) + super().__init__("logical_maximum", [left, right]) class LogicalMin(Function): def __init__(self, left: Expr, right: Expr): - super().__init__("logical_min", [left, right]) + super().__init__("logical_minimum", [left, right]) class MapGet(Function): @@ -490,12 +490,12 @@ class Accumulator(Function): class Max(Accumulator): def __init__(self, value: Expr, distinct: bool=False): - super().__init__("max", [value]) + super().__init__("maximum", [value]) class Min(Accumulator): def __init__(self, value: Expr, distinct: bool=False): - super().__init__("min", [value]) + super().__init__("minimum", [value]) class Sum(Accumulator): From 4f527c1f597c3cc3f4f2ffe17ed94b4db738c000 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 27 Mar 2025 14:15:24 -0700 Subject: [PATCH 055/131] broke pipelines into separate async/sync/base files --- google/cloud/firestore_v1/async_pipeline.py | 39 ++++++ google/cloud/firestore_v1/base_pipeline.py | 127 ++++++++++++++++++++ google/cloud/firestore_v1/pipeline.py | 115 ++---------------- 3 files changed, 174 insertions(+), 107 deletions(-) create mode 100644 google/cloud/firestore_v1/async_pipeline.py create mode 100644 google/cloud/firestore_v1/base_pipeline.py diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py new file mode 100644 index 000000000..34b9eb45b --- /dev/null +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -0,0 +1,39 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
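+# A minimal usage sketch for the async pipeline defined below (assumes an
+# AsyncClient named `client` and a "books" collection; illustrative only,
+# while this API is still under development):
+#
+#     ppl = Pipeline(client, stages.Collection("books"))
+#     async for response in await ppl.execute_async():
+#         ...  # each response is an ExecutePipelineResponse message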
+ +from __future__ import annotations +import datetime +from typing import AsyncIterable, TYPE_CHECKING +from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest +from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse +from google.cloud.firestore_v1.base_pipeline import _BasePipeline + +if TYPE_CHECKING: + from google.cloud.firestore_v1.async_client import AsyncClient + + +class Pipeline(_BasePipeline): + def __init__(self, client:AsyncClient, *stages: stages.Stage): + super().__init__(*stages) + self._client = client + + async def execute_async(self) -> AsyncIterable["ExecutePipelineResponse"]: + database_name = f"projects/{self._client.project}/databases/{self._client._database}" + request = ExecutePipelineRequest( + database=database_name, + structured_pipeline=self._to_pb(), + read_time=datetime.datetime.now(), + ) + return await self._client._firestore_api.execute_pipeline(request) diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py new file mode 100644 index 000000000..bdfb1a8ed --- /dev/null +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -0,0 +1,127 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
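+# Design note: every builder method below appends a stage and returns `self`,
+# so the sync and async Pipeline subclasses share one fluent interface and
+# differ only in how they execute. A hypothetical chain (expression names as
+# defined in pipeline_expressions; illustrative only):
+#
+#     ppl = ppl.where(Field.of("rating").gt(4))
+#     ppl = ppl.sort(Field.of("rating").descending()).limit(10)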
+ +from __future__ import annotations +from typing import Optional, Sequence, Self +from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1.types.pipeline import StructuredPipeline as StructuredPipeline_pb +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure +from google.cloud.firestore_v1.pipeline_expressions import ( + Accumulator, + Expr, + ExprWithAlias, + Field, + FilterCondition, + Selectable, + SampleOptions, +) + + +class _BasePipeline: + def __init__(self, *stages: stages.Stage): + self.stages = list(stages) + + def __repr__(self): + if not self.stages: + return "Pipeline()" + elif len(self.stages) == 1: + return f"Pipeline({self.stages[0]!r})" + else: + stages_str = ",\n ".join([repr(s) for s in self.stages]) + return f"Pipeline(\n {stages_str}\n)" + + def _to_pb(self) -> StructuredPipeline_pb: + return StructuredPipeline_pb(pipeline={"stages":[s._to_pb() for s in self.stages]}) + + def add_fields(self, *fields: Selectable) -> Self: + self.stages.append(stages.AddFields(*fields)) + return self + + def remove_fields(self, *fields: Field | str) -> Self: + self.stages.append(stages.RemoveFields(*fields)) + return self + + def select(self, *selections: str | Selectable) -> Self: + self.stages.append(stages.Select(*selections)) + return self + + def where(self, condition: FilterCondition) -> Self: + self.stages.append(stages.Where(condition)) + return self + + def find_nearest( + self, + field: str | Expr, + vector: Sequence[float] | "Vector", + distance_measure: "DistanceMeasure", + limit: int | None, + options: Optional[stages.FindNearestOptions] = None, + ) -> Self: + self.stages.append(stages.FindNearest(field, vector, distance_measure, options)) + return self + + def sort(self, *orders: stages.Ordering) -> Self: + self.stages.append(stages.Sort(*orders)) + return self + + def replace( + self, + field: Selectable, + mode: stages.Replace.Mode = stages.Replace.Mode.FULL_REPLACE, + ) -> Self: + self.stages.append(stages.Replace(field, mode)) + return self + + def sample(self, limit_or_options: int | SampleOptions) -> Self: + self.stages.append(stages.Sample(limit_or_options)) + return self + + def union(self, other: Self) -> Self: + self.stages.append(stages.Union(other)) + return self + + def unnest( + self, + field_name: str, + options: Optional[stages.UnnestOptions] = None, + ) -> Self: + self.stages.append(stages.Unnest(field_name, options)) + return self + + def generic_stage(self, name: str, *params: Expr) -> Self: + self.stages.append(stages.GenericStage(name, *params)) + return self + + def offset(self, offset: int) -> Self: + self.stages.append(stages.Offset(offset)) + return self + + def limit(self, limit: int) -> Self: + self.stages.append(stages.Limit(limit)) + return self + + def aggregate( + self, + *accumulators: ExprWithAlias[Accumulator], + groups: Sequence[str | Selectable] = (), + ) -> Self: + self.stages.append(stages.Aggregate(*accumulators, groups=groups)) + return self + + def distinct(self, *fields: str | Selectable) -> Self: + self.stages.append(stages.Distinct(*fields)) + return self + + diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 259e30aa5..b545fcc32 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -14,119 +14,20 @@ from __future__ import annotations import datetime -from typing import AsyncIterable, Any, Dict, Iterable, List, Optional, Sequence 
+from typing import AsyncIterable, Iterable, TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages -from google.cloud.firestore_v1.types.pipeline import StructuredPipeline as StructuredPipeline_pb from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse -from google.cloud.firestore_v1.vector import Vector -from google.cloud.firestore_v1.base_vector_query import DistanceMeasure -from google.cloud.firestore_v1.pipeline_expressions import ( - Accumulator, - Expr, - ExprWithAlias, - Field, - FilterCondition, - Selectable, - SampleOptions, -) +from google.cloud.firestore_v1.base_pipeline import _BasePipeline +if TYPE_CHECKING: + from google.cloud.firestore_v1.client import Client -class Pipeline: - def __init__(self, client, *stages: stages.Stage): - self._client = client - self.stages = list(stages) - - def __repr__(self): - if not self.stages: - return "Pipeline()" - elif len(self.stages) == 1: - return f"Pipeline({self.stages[0]!r})" - else: - stages_str = ",\n ".join([repr(s) for s in self.stages]) - return f"Pipeline(\n {stages_str}\n)" - - def _to_pb(self) -> StructuredPipeline_pb: - return StructuredPipeline_pb(pipeline={"stages":[s._to_pb() for s in self.stages]}) - - def add_fields(self, *fields: Selectable) -> Pipeline: - self.stages.append(stages.AddFields(*fields)) - return self - - def remove_fields(self, *fields: Field | str) -> Pipeline: - self.stages.append(stages.RemoveFields(*fields)) - return self - - def select(self, *selections: str | Selectable) -> Pipeline: - self.stages.append(stages.Select(*selections)) - return self - - def where(self, condition: FilterCondition) -> Pipeline: - self.stages.append(stages.Where(condition)) - return self - - def find_nearest( - self, - field: str | Expr, - vector: Sequence[float] | "Vector", - distance_measure: "DistanceMeasure", - limit: int | None, - options: Optional[stages.FindNearestOptions] = None, - ) -> Pipeline: - self.stages.append(stages.FindNearest(field, vector, distance_measure, options)) - return self - - def sort(self, *orders: stages.Ordering) -> Pipeline: - self.stages.append(stages.Sort(*orders)) - return self - def replace( - self, - field: Selectable, - mode: stages.Replace.Mode = stages.Replace.Mode.FULL_REPLACE, - ) -> Pipeline: - self.stages.append(stages.Replace(field, mode)) - return self - - def sample(self, limit_or_options: int | SampleOptions) -> Pipeline: - self.stages.append(stages.Sample(limit_or_options)) - return self - - def union(self, other: Pipeline) -> Pipeline: - self.stages.append(stages.Union(other)) - return self - - def unnest( - self, - field_name: str, - options: Optional[stages.UnnestOptions] = None, - ) -> Pipeline: - self.stages.append(stages.Unnest(field_name, options)) - return self - - def generic_stage(self, name: str, *params: Expr) -> Pipeline: - self.stages.append(stages.GenericStage(name, *params)) - return self - - def offset(self, offset: int) -> Pipeline: - self.stages.append(stages.Offset(offset)) - return self - - def limit(self, limit: int) -> Pipeline: - self.stages.append(stages.Limit(limit)) - return self - - def aggregate( - self, - *accumulators: ExprWithAlias[Accumulator], - groups: Sequence[str | Selectable] = (), - ) -> Pipeline: - self.stages.append(stages.Aggregate(*accumulators, groups=groups)) - return self - - def distinct(self, *fields: str | Selectable) -> Pipeline: - self.stages.append(stages.Distinct(*fields)) - return self +class 
Pipeline(_BasePipeline): + def __init__(self, client:Client, *stages: stages.Stage): + super().__init__(*stages) + self._client = client def execute(self) -> Iterable["ExecutePipelineResponse"]: database_name = f"projects/{self._client.project}/databases/{self._client._database}" From a323e5b935309a1b7114899ba060e2748fe1e36e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 27 Mar 2025 14:41:36 -0700 Subject: [PATCH 056/131] added docstrings to expressions --- .../firestore_v1/pipeline_expressions.py | 846 +++++++++++++++++- 1 file changed, 842 insertions(+), 4 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index fc3604dbf..9e22c73f2 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -29,6 +29,7 @@ class Ordering: + """Represents the direction for sorting results in a pipeline.""" class Direction(Enum): ASCENDING = "ascending" @@ -57,6 +58,7 @@ def _to_pb(self) -> Value: @dataclass class SampleOptions: + """Options for the 'sample' pipeline stage.""" class Mode(Enum): DOCUMENTS = "documents" PERCENTAGE = "percent" @@ -70,6 +72,17 @@ def __post_init__(self): class Expr(ABC): """Represents an expression that can be evaluated to a value within the execution of a pipeline. + + Expressions are the building blocks for creating complex queries and + transformations in Firestore pipelines. They can represent: + + - **Field references:** Access values from document fields. + - **Literals:** Represent constant values (strings, numbers, booleans). + - **Function calls:** Apply functions to one or more expressions. + - **Aggregations:** Calculate aggregate values (e.g., sum, average) over a set of documents. + + The `Expr` class provides a fluent API for building expressions. You can chain + together method calls to create complex expressions. """ def __repr__(self): @@ -84,188 +97,931 @@ def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": return o if isinstance(o, Expr) else Constant(o) def add(self, other: Expr | float) -> "Add": + """Creates an expression that adds this expression to another expression or constant. + + Example: + >>> # Add the value of the 'quantity' field and the 'reserve' field. + >>> Field.of("quantity").add(Field.of("reserve")) + >>> # Add 5 to the value of the 'age' field + >>> Field.of("age").add(5) + + Args: + other: The expression or constant value to add to this expression. + + Returns: + A new `Expr` representing the addition operation. + """ return Add(self, self._cast_to_expr_or_convert_to_constant(other)) def subtract(self, other: Expr | float) -> "Subtract": + """Creates an expression that subtracts another expression or constant from this expression. + + Example: + >>> # Subtract the 'discount' field from the 'price' field + >>> Field.of("price").subtract(Field.of("discount")) + >>> # Subtract 20 from the value of the 'total' field + >>> Field.of("total").subtract(20) + + Args: + other: The expression or constant value to subtract from this expression. + + Returns: + A new `Expr` representing the subtraction operation. + """ return Subtract(self, self._cast_to_expr_or_convert_to_constant(other)) def multiply(self, other: Expr | float) -> "Multiply": + """Creates an expression that multiplies this expression by another expression or constant. 
+ + Example: + >>> # Multiply the 'quantity' field by the 'price' field + >>> Field.of("quantity").multiply(Field.of("price")) + >>> # Multiply the 'value' field by 2 + >>> Field.of("value").multiply(2) + + Args: + other: The expression or constant value to multiply by. + + Returns: + A new `Expr` representing the multiplication operation. + """ return Multiply(self, self._cast_to_expr_or_convert_to_constant(other)) def divide(self, other: Expr | float) -> "Divide": + """Creates an expression that divides this expression by another expression or constant. + + Example: + >>> # Divide the 'total' field by the 'count' field + >>> Field.of("total").divide(Field.of("count")) + >>> # Divide the 'value' field by 10 + >>> Field.of("value").divide(10) + + Args: + other: The expression or constant value to divide by. + + Returns: + A new `Expr` representing the division operation. + """ return Divide(self, self._cast_to_expr_or_convert_to_constant(other)) def mod(self, other: Expr | float) -> "Mod": + """Creates an expression that calculates the modulo (remainder) to another expression or constant. + + Example: + >>> # Calculate the remainder of dividing the 'value' field by field 'divisor'. + >>> Field.of("value").mod(Field.of("divisor")) + >>> # Calculate the remainder of dividing the 'value' field by 5. + >>> Field.of("value").mod(5) + + Args: + other: The divisor expression or constant. + + Returns: + A new `Expr` representing the modulo operation. + """ return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": + """Creates an expression that returns the larger value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> # Returns the larger value between the 'discount' field and the 'cap' field. + >>> Field.of("discount").logical_max(Field.of("cap")) + >>> # Returns the larger value between the 'value' field and 10. + >>> Field.of("value").logical_max(10) + + Args: + other: The other expression or constant value to compare with. + + Returns: + A new `Expr` representing the logical max operation. + """ return LogicalMax(self, self._cast_to_expr_or_convert_to_constant(other)) def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": + """Creates an expression that returns the smaller value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> # Returns the smaller value between the 'discount' field and the 'floor' field. + >>> Field.of("discount").logical_min(Field.of("floor")) + >>> # Returns the smaller value between the 'value' field and 10. + >>> Field.of("value").logical_min(10) + + Args: + other: The other expression or constant value to compare with. + + Returns: + A new `Expr` representing the logical min operation. + """ return LogicalMin(self, self._cast_to_expr_or_convert_to_constant(other)) def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": + """Creates an expression that checks if this expression is equal to another + expression or constant value. 
+ + Example: + >>> # Check if the 'age' field is equal to 21 + >>> Field.of("age").eq(21) + >>> # Check if the 'city' field is equal to "London" + >>> Field.of("city").eq("London") + + Args: + other: The expression or constant value to compare for equality. + + Returns: + A new `Expr` representing the equality comparison. + """ return Eq(self, self._cast_to_expr_or_convert_to_constant(other)) def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": + """Creates an expression that checks if this expression is not equal to another + expression or constant value. + + Example: + >>> # Check if the 'status' field is not equal to "completed" + >>> Field.of("status").neq("completed") + >>> # Check if the 'country' field is not equal to "USA" + >>> Field.of("country").neq("USA") + + Args: + other: The expression or constant value to compare for inequality. + + Returns: + A new `Expr` representing the inequality comparison. + """ return Neq(self, self._cast_to_expr_or_convert_to_constant(other)) def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": + """Creates an expression that checks if this expression is greater than another + expression or constant value. + + Example: + >>> # Check if the 'age' field is greater than the 'limit' field + >>> Field.of("age").gt(Field.of("limit")) + >>> # Check if the 'price' field is greater than 100 + >>> Field.of("price").gt(100) + + Args: + other: The expression or constant value to compare for greater than. + + Returns: + A new `Expr` representing the greater than comparison. + """ return Gt(self, self._cast_to_expr_or_convert_to_constant(other)) def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": + """Creates an expression that checks if this expression is greater than or equal + to another expression or constant value. + + Example: + >>> # Check if the 'quantity' field is greater than or equal to field 'requirement' plus 1 + >>> Field.of("quantity").gte(Field.of('requirement').add(1)) + >>> # Check if the 'score' field is greater than or equal to 80 + >>> Field.of("score").gte(80) + + Args: + other: The expression or constant value to compare for greater than or equal to. + + Returns: + A new `Expr` representing the greater than or equal to comparison. + """ return Gte(self, self._cast_to_expr_or_convert_to_constant(other)) def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": + """Creates an expression that checks if this expression is less than another + expression or constant value. + + Example: + >>> # Check if the 'age' field is less than 'limit' + >>> Field.of("age").lt(Field.of('limit')) + >>> # Check if the 'price' field is less than 50 + >>> Field.of("price").lt(50) + + Args: + other: The expression or constant value to compare for less than. + + Returns: + A new `Expr` representing the less than comparison. + """ return Lt(self, self._cast_to_expr_or_convert_to_constant(other)) def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": + """Creates an expression that checks if this expression is less than or equal to + another expression or constant value. + + Example: + >>> # Check if the 'quantity' field is less than or equal to 20 + >>> Field.of("quantity").lte(Constant.of(20)) + >>> # Check if the 'score' field is less than or equal to 70 + >>> Field.of("score").lte(70) + + Args: + other: The expression or constant value to compare for less than or equal to. + + Returns: + A new `Expr` representing the less than or equal to comparison. 
+ """ return Lte(self, self._cast_to_expr_or_convert_to_constant(other)) def in_any(self, *others: Expr | CONSTANT_TYPE) -> "In": + """Creates an expression that checks if this expression is equal to any of the + provided values or expressions. + + Example: + >>> # Check if the 'category' field is either "Electronics" or value of field 'primaryType' + >>> Field.of("category").in_any("Electronics", Field.of("primaryType")) + + Args: + *others: The values or expressions to check against. + + Returns: + A new `Expr` representing the 'IN' comparison. + """ return In(self, [self._cast_to_expr_or_convert_to_constant(o) for o in others]) def not_in_any(self, *others: Expr | CONSTANT_TYPE) -> "Not": + """Creates an expression that checks if this expression is not equal to any of the + provided values or expressions. + + Example: + >>> # Check if the 'status' field is neither "pending" nor "cancelled" + >>> Field.of("status").not_in_any("pending", "cancelled") + + Args: + *others: The values or expressions to check against. + + Returns: + A new `Expr` representing the 'NOT IN' comparison. + """ return Not(self.in_any(*others)) def array_concat(self, array: List[Expr | CONSTANT_TYPE]) -> "ArrayConcat": + """Creates an expression that concatenates an array expression with another array. + + Example: + >>> # Combine the 'tags' array with a new array and an array field + >>> Field.of("tags").array_concat(["newTag1", "newTag2", Field.of("otherTag")]) + + Args: + array: The list of constants or expressions to concat with. + + Returns: + A new `Expr` representing the concatenated array. + """ return ArrayConcat(self, [self._cast_to_expr_or_convert_to_constant(o) for o in array]) def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": + """Creates an expression that checks if an array contains a specific element or value. + + Example: + >>> # Check if the 'sizes' array contains the value from the 'selectedSize' field + >>> Field.of("sizes").array_contains(Field.of("selectedSize")) + >>> # Check if the 'colors' array contains "red" + >>> Field.of("colors").array_contains("red") + + Args: + element: The element (expression or constant) to search for in the array. + + Returns: + A new `Expr` representing the 'array_contains' comparison. + """ return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) def array_contains_all(self, elements: List[Expr | CONSTANT_TYPE]) -> "ArrayContainsAll": + """Creates an expression that checks if an array contains all the specified elements. + + Example: + >>> # Check if the 'tags' array contains both "news" and "sports" + >>> Field.of("tags").array_contains_all(["news", "sports"]) + >>> # Check if the 'tags' array contains both of the values from field 'tag1' and "tag2" + >>> Field.of("tags").array_contains_all([Field.of("tag1"), "tag2"]) + + Args: + elements: The list of elements (expressions or constants) to check for in the array. + + Returns: + A new `Expr` representing the 'array_contains_all' comparison. + """ return ArrayContainsAll(self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements]) def array_contains_any(self, elements: List[Expr | CONSTANT_TYPE]) -> "ArrayContainsAny": + """Creates an expression that checks if an array contains any of the specified elements. 
+ + Example: + >>> # Check if the 'categories' array contains either values from field "cate1" or "cate2" + >>> Field.of("categories").array_contains_any([Field.of("cate1"), Field.of("cate2")]) + >>> # Check if the 'groups' array contains either the value from the 'userGroup' field + >>> # or the value "guest" + >>> Field.of("groups").array_contains_any([Field.of("userGroup"), "guest"]) + + Args: + elements: The list of elements (expressions or constants) to check for in the array. + + Returns: + A new `Expr` representing the 'array_contains_any' comparison. + """ return ArrayContainsAny(self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements]) def array_length(self) -> "ArrayLength": + """Creates an expression that calculates the length of an array. + + Example: + >>> # Get the number of items in the 'cart' array + >>> Field.of("cart").array_length() + + Returns: + A new `Expr` representing the length of the array. + """ return ArrayLength(self) def array_reverse(self) -> "ArrayReverse": + """Creates an expression that returns the reversed content of an array. + + Example: + >>> # Get the 'preferences' array in reversed order. + >>> Field.of("preferences").array_reverse() + + Returns: + A new `Expr` representing the reversed array. + """ return ArrayReverse(self) def is_nan(self) -> "IsNaN": + """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). + + Example: + >>> # Check if the result of a calculation is NaN + >>> Field.of("value").divide(0).is_nan() + + Returns: + A new `Expr` representing the 'isNaN' check. + """ return IsNaN(self) def exists(self) -> "Exists": + """Creates an expression that checks if a field exists in the document. + + Example: + >>> # Check if the document has a field named "phoneNumber" + >>> Field.of("phoneNumber").exists() + + Returns: + A new `Expr` representing the 'exists' check. + """ return Exists(self) def sum(self) -> "Sum": + """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. + + Example: + >>> # Calculate the total revenue from a set of orders + >>> Field.of("orderAmount").sum().as_("totalRevenue") + + Returns: + A new `Accumulator` representing the 'sum' aggregation. + """ return Sum(self, False) def avg(self) -> "Avg": + """Creates an aggregation that calculates the average (mean) of a numeric field across multiple + stage inputs. + + Example: + >>> # Calculate the average age of users + >>> Field.of("age").avg().as_("averageAge") + + Returns: + A new `Accumulator` representing the 'avg' aggregation. + """ return Avg(self, False) def count(self) -> "Count": + """Creates an aggregation that counts the number of stage inputs with valid evaluations of the + expression or field. + + Example: + >>> # Count the total number of products + >>> Field.of("productId").count().as_("totalProducts") + + Returns: + A new `Accumulator` representing the 'count' aggregation. + """ return Count(self) def min(self) -> "Min": + """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. + + Example: + >>> # Find the lowest price of all products + >>> Field.of("price").min().as_("lowestPrice") + + Returns: + A new `Accumulator` representing the 'min' aggregation. + """ return Min(self, False) def max(self) -> "Max": + """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. 
+ + Example: + >>> # Find the highest score in a leaderboard + >>> Field.of("score").max().as_("highestScore") + + Returns: + A new `Accumulator` representing the 'max' aggregation. + """ return Max(self, False) def char_length(self) -> "CharLength": + """Creates an expression that calculates the character length of a string. + + Example: + >>> # Get the character length of the 'name' field + >>> Field.of("name").char_length() + + Returns: + A new `Expr` representing the length of the string. + """ return CharLength(self) def byte_length(self) -> "ByteLength": + """Creates an expression that calculates the byte length of a string in its UTF-8 form. + + Example: + >>> # Get the byte length of the 'name' field + >>> Field.of("name").byte_length() + + Returns: + A new `Expr` representing the byte length of the string. + """ return ByteLength(self) def like(self, pattern: Expr | str) -> "Like": + """Creates an expression that performs a case-sensitive string comparison. + + Example: + >>> # Check if the 'title' field contains the word "guide" (case-sensitive) + >>> Field.of("title").like("%guide%") + >>> # Check if the 'title' field matches the pattern specified in field 'pattern'. + >>> Field.of("title").like(Field.of("pattern")) + + Args: + pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. + + Returns: + A new `Expr` representing the 'like' comparison. + """ return Like(self, self._cast_to_expr_or_convert_to_constant(pattern)) def regex_contains(self, regex: Expr | str) -> "RegexContains": + """Creates an expression that checks if a string contains a specified regular expression as a + substring. + + Example: + >>> # Check if the 'description' field contains "example" (case-insensitive) + >>> Field.of("description").regex_contains("(?i)example") + >>> # Check if the 'description' field contains the regular expression stored in field 'regex' + >>> Field.of("description").regex_contains(Field.of("regex")) + + Args: + regex: The regular expression (string or expression) to use for the search. + + Returns: + A new `Expr` representing the 'contains' comparison. + """ return RegexContains(self, self._cast_to_expr_or_convert_to_constant(regex)) def regex_matches(self, regex: Expr | str) -> "RegexMatch": + """Creates an expression that checks if a string matches a specified regular expression. + + Example: + >>> # Check if the 'email' field matches a valid email pattern + >>> Field.of("email").regex_matches("[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") + >>> # Check if the 'email' field matches a regular expression stored in field 'regex' + >>> Field.of("email").regex_matches(Field.of("regex")) + + Args: + regex: The regular expression (string or expression) to use for the match. + + Returns: + A new `Expr` representing the regular expression match. + """ return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) def str_contains(self, substring: Expr | str) -> "StrContains": + """Creates an expression that checks if this string expression contains a specified substring. + + Example: + >>> # Check if the 'description' field contains "example". + >>> Field.of("description").str_contains("example") + >>> # Check if the 'description' field contains the value of the 'keyword' field. + >>> Field.of("description").str_contains(Field.of("keyword")) + + Args: + substring: The substring (string or expression) to use for the search. + + Returns: + A new `Expr` representing the 'contains' comparison. 
+ """ return StrContains(self, self._cast_to_expr_or_convert_to_constant(substring)) def starts_with(self, prefix: Expr | str) -> "StartsWith": + """Creates an expression that checks if a string starts with a given prefix. + + Example: + >>> # Check if the 'name' field starts with "Mr." + >>> Field.of("name").starts_with("Mr.") + >>> # Check if the 'fullName' field starts with the value of the 'firstName' field + >>> Field.of("fullName").starts_with(Field.of("firstName")) + + Args: + prefix: The prefix (string or expression) to check for. + + Returns: + A new `Expr` representing the 'starts with' comparison. + """ return StartsWith(self, self._cast_to_expr_or_convert_to_constant(prefix)) def ends_with(self, postfix: Expr | str) -> "EndsWith": + """Creates an expression that checks if a string ends with a given postfix. + + Example: + >>> # Check if the 'filename' field ends with ".txt" + >>> Field.of("filename").ends_with(".txt") + >>> # Check if the 'url' field ends with the value of the 'extension' field + >>> Field.of("url").ends_with(Field.of("extension")) + + Args: + postfix: The postfix (string or expression) to check for. + + Returns: + A new `Expr` representing the 'ends with' comparison. + """ return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": + """Creates an expression that concatenates string expressions, fields or constants together. + + Example: + >>> # Combine the 'firstName', " ", and 'lastName' fields into a single string + >>> Field.of("firstName").str_concat(" ", Field.of("lastName")) + + Args: + *elements: The expressions or constants (typically strings) to concatenate. + + Returns: + A new `Expr` representing the concatenated string. + """ return StrConcat(*[self._cast_to_expr_or_convert_to_constant(el) for el in elements]) def to_lower(self) -> "ToLower": + """Creates an expression that converts a string to lowercase. + + Example: + >>> # Convert the 'name' field to lowercase + >>> Field.of("name").to_lower() + + Returns: + A new `Expr` representing the lowercase string. + """ return ToLower(self) def to_upper(self) -> "ToUpper": + """Creates an expression that converts a string to uppercase. + + Example: + >>> # Convert the 'title' field to uppercase + >>> Field.of("title").to_upper() + + Returns: + A new `Expr` representing the uppercase string. + """ return ToUpper(self) def trim(self) -> "Trim": + """Creates an expression that removes leading and trailing whitespace from a string. + + Example: + >>> # Trim whitespace from the 'userInput' field + >>> Field.of("userInput").trim() + + Returns: + A new `Expr` representing the trimmed string. + """ return Trim(self) def reverse(self) -> "Reverse": + """Creates an expression that reverses a string. + + Example: + >>> # Reverse the 'userInput' field + >>> Field.of("userInput").reverse() + + Returns: + A new `Expr` representing the reversed string. + """ return Reverse(self) def replace_first(self, find: Expr | str, replace: Expr | str) -> "ReplaceFirst": + """Creates an expression that replaces the first occurrence of a substring within a string with + another substring. 
+ + Example: + >>> # Replace the first occurrence of "hello" with "hi" in the 'message' field + >>> Field.of("message").replace_first("hello", "hi") + >>> # Replace the first occurrence of the value in 'findField' with the value in 'replaceField' in the 'message' field + >>> Field.of("message").replace_first(Field.of("findField"), Field.of("replaceField")) + + Args: + find: The substring (string or expression) to search for. + replace: The substring (string or expression) to replace the first occurrence of 'find' with. + + Returns: + A new `Expr` representing the string with the first occurrence replaced. + """ return ReplaceFirst(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) def replace_all(self, find: Expr | str, replace: Expr | str) -> "ReplaceAll": + """Creates an expression that replaces all occurrences of a substring within a string with another + substring. + + Example: + >>> # Replace all occurrences of "hello" with "hi" in the 'message' field + >>> Field.of("message").replace_all("hello", "hi") + >>> # Replace all occurrences of the value in 'findField' with the value in 'replaceField' in the 'message' field + >>> Field.of("message").replace_all(Field.of("findField"), Field.of("replaceField")) + + Args: + find: The substring (string or expression) to search for. + replace: The substring (string or expression) to replace all occurrences of 'find' with. + + Returns: + A new `Expr` representing the string with all occurrences replaced. + """ return ReplaceAll(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) def map_get(self, key: str) -> "MapGet": + """Accesses a value from a map (object) field using the provided key. + + Example: + >>> # Get the 'city' value from + >>> # the 'address' map field + >>> Field.of("address").map_get("city") + + Args: + key: The key to access in the map. + + Returns: + A new `Expr` representing the value associated with the given key in the map. + """ return MapGet(self, key) def cosine_distance(self, other: Expr | list[float] | Vector) -> "CosineDistance": + """Calculates the cosine distance between two vectors. + + Example: + >>> # Calculate the cosine distance between the 'userVector' field and the 'itemVector' field + >>> Field.of("userVector").cosine_distance(Field.of("itemVector")) + >>> # Calculate the Cosine distance between the 'location' field and a target location + >>> Field.of("location").cosine_distance([37.7749, -122.4194]) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the cosine distance between the two vectors. + """ return CosineDistance(self, self._cast_to_expr_or_convert_to_constant(other)) def euclidean_distance(self, other: Expr | list[float] | Vector) -> "EuclideanDistance": + """Calculates the Euclidean distance between two vectors. + + Example: + >>> # Calculate the Euclidean distance between the 'location' field and a target location + >>> Field.of("location").euclidean_distance([37.7749, -122.4194]) + >>> # Calculate the Euclidean distance between two vector fields: 'pointA' and 'pointB' + >>> Field.of("pointA").euclidean_distance(Field.of("pointB")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the Euclidean distance between the two vectors. 
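+
+        Note: for unit-length vectors, ranking by Euclidean distance and ranking by
+        cosine distance select the same nearest neighbors, since
+        ||a - b||^2 = 2 - 2*cos(a, b) when both vectors have norm 1.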
+ """ return EuclideanDistance(self, self._cast_to_expr_or_convert_to_constant(other)) def dot_product(self, other: Expr | list[float] | Vector) -> "DotProduct": + """Calculates the dot product between two vectors. + + Example: + >>> # Calculate the dot product between a feature vector and a target vector + >>> Field.of("features").dot_product([0.5, 0.8, 0.2]) + >>> # Calculate the dot product between two document vectors: 'docVector1' and 'docVector2' + >>> Field.of("docVector1").dot_product(Field.of("docVector2")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to calculate dot product with. + + Returns: + A new `Expr` representing the dot product between the two vectors. + """ return DotProduct(self, self._cast_to_expr_or_convert_to_constant(other)) def vector_length(self) -> "VectorLength": + """Creates an expression that calculates the length (dimension) of a Firestore Vector. + + Example: + >>> # Get the vector length (dimension) of the field 'embedding'. + >>> Field.of("embedding").vector_length() + + Returns: + A new `Expr` representing the length of the vector. + """ return VectorLength(self) def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": + """Creates an expression that converts a timestamp to the number of microseconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the microsecond. + + Example: + >>> # Convert the 'timestamp' field to microseconds since the epoch. + >>> Field.of("timestamp").timestamp_to_unix_micros() + + Returns: + A new `Expr` representing the number of microseconds since the epoch. + """ return TimestampToUnixMicros(self) def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": + """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> # Convert the 'microseconds' field to a timestamp. + >>> Field.of("microseconds").unix_micros_to_timestamp() + + Returns: + A new `Expr` representing the timestamp. + """ return UnixMicrosToTimestamp(self) def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": + """Creates an expression that converts a timestamp to the number of milliseconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the millisecond. + + Example: + >>> # Convert the 'timestamp' field to milliseconds since the epoch. + >>> Field.of("timestamp").timestamp_to_unix_millis() + + Returns: + A new `Expr` representing the number of milliseconds since the epoch. + """ return TimestampToUnixMillis(self) def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": + """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> # Convert the 'milliseconds' field to a timestamp. + >>> Field.of("milliseconds").unix_millis_to_timestamp() + + Returns: + A new `Expr` representing the timestamp. + """ return UnixMillisToTimestamp(self) def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": + """Creates an expression that converts a timestamp to the number of seconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the second. + + Example: + >>> # Convert the 'timestamp' field to seconds since the epoch. 
+ >>> Field.of("timestamp").timestamp_to_unix_seconds() + + Returns: + A new `Expr` representing the number of seconds since the epoch. + """ return TimestampToUnixSeconds(self) def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": + """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 + UTC) to a timestamp. + + Example: + >>> # Convert the 'seconds' field to a timestamp. + >>> Field.of("seconds").unix_seconds_to_timestamp() + + Returns: + A new `Expr` representing the timestamp. + """ return UnixSecondsToTimestamp(self) def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd": + """Creates an expression that adds a specified amount of time to this timestamp expression. + + Example: + >>> # Add a duration specified by the 'unit' and 'amount' fields to the 'timestamp' field. + >>> Field.of("timestamp").timestamp_add(Field.of("unit"), Field.of("amount")) + >>> # Add 1.5 days to the 'timestamp' field. + >>> Field.of("timestamp").timestamp_add("day", 1.5) + + Args: + unit: The expression or string evaluating to the unit of time to add, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to add. + + Returns: + A new `Expr` representing the resulting timestamp. + """ return TimestampAdd(self, self._cast_to_expr_or_convert_to_constant(unit), self._cast_to_expr_or_convert_to_constant(amount)) def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub": + """Creates an expression that subtracts a specified amount of time from this timestamp expression. + + Example: + >>> # Subtract a duration specified by the 'unit' and 'amount' fields from the 'timestamp' field. + >>> Field.of("timestamp").timestamp_sub(Field.of("unit"), Field.of("amount")) + >>> # Subtract 2.5 hours from the 'timestamp' field. + >>> Field.of("timestamp").timestamp_sub("hour", 2.5) + + Args: + unit: The expression or string evaluating to the unit of time to subtract, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to subtract. + + Returns: + A new `Expr` representing the resulting timestamp. + """ return TimestampSub(self, self._cast_to_expr_or_convert_to_constant(unit), self._cast_to_expr_or_convert_to_constant(amount)) def ascending(self) -> Ordering: + """Creates an `Ordering` that sorts documents in ascending order based on this expression. + + Example: + >>> # Sort documents by the 'name' field in ascending order + >>> firestore.pipeline().collection("users").sort(Field.of("name").ascending()) + + Returns: + A new `Ordering` for ascending sorting. + """ return Ordering(self, Ordering.Direction.ASCENDING) def descending(self) -> Ordering: + """Creates an `Ordering` that sorts documents in descending order based on this expression. + + Example: + >>> # Sort documents by the 'createdAt' field in descending order + >>> firestore.pipeline().collection("users").sort(Field.of("createdAt").descending()) + + Returns: + A new `Ordering` for descending sorting. + """ return Ordering(self, Ordering.Direction.DESCENDING) def as_(self, alias: str) -> "ExprWithAlias": + """Assigns an alias to this expression. + + Aliases are useful for renaming fields in the output of a stage or for giving meaningful + names to calculated values. + + Example: + >>> # Calculate the total price and assign it the alias "totalPrice" and add it to the output. 
+ >>> firestore.pipeline().collection("items").add_fields( + ... Field.of("price").multiply(Field.of("quantity")).as_("totalPrice") + ... ) + + Args: + alias: The alias to assign to this expression. + + Returns: + A new `Selectable` (typically an `ExprWithAlias`) that wraps this + expression and associates it with the provided alias. + """ return ExprWithAlias(self, alias) class Constant(Expr, Generic[CONSTANT_TYPE]): + """Represents a constant literal value in an expression.""" def __init__(self, value: CONSTANT_TYPE): self.value: CONSTANT_TYPE = value @staticmethod def of(value:CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: + """Creates a constant expression from a Python value.""" return Constant(value) def __repr__(self): @@ -275,6 +1031,7 @@ def _to_pb(self) -> Value: return encode_value(self.value) class ListOfExprs(Expr): + """Represents a list of expressions, typically used as an argument to functions like 'in' or array functions.""" def __init__(self, exprs: List[Expr]): self.exprs: list[Expr] = exprs @@ -283,7 +1040,7 @@ def _to_pb(self): class Function(Expr): - """A type of Expression that takes in inputs and gives outputs.""" + """A base class for expressions that represent function calls.""" def __init__(self, name: str, params: Sequence[Expr]): self.name = name @@ -300,226 +1057,269 @@ def _to_pb(self): ) class Divide(Function): + """Represents the division function.""" def __init__(self, left: Expr, right: Expr): super().__init__("divide", [left, right]) class DotProduct(Function): + """Represents the vector dot product function.""" def __init__(self, vector1: Expr, vector2: Expr): super().__init__("dot_product", [vector1, vector2]) class EuclideanDistance(Function): + """Represents the vector Euclidean distance function.""" def __init__(self, vector1: Expr, vector2: Expr): super().__init__("euclidean_distance", [vector1, vector2]) class LogicalMax(Function): + """Represents the logical maximum function based on Firestore type ordering.""" def __init__(self, left: Expr, right: Expr): super().__init__("logical_maximum", [left, right]) class LogicalMin(Function): + """Represents the logical minimum function based on Firestore type ordering.""" def __init__(self, left: Expr, right: Expr): super().__init__("logical_minimum", [left, right]) class MapGet(Function): + """Represents accessing a value within a map by key.""" def __init__(self, map_: Expr, key: str): super().__init__("map_get", [map_, Constant(key)]) class Mod(Function): + """Represents the modulo function.""" def __init__(self, left: Expr, right: Expr): super().__init__("mod", [left, right]) class Multiply(Function): + """Represents the multiplication function.""" def __init__(self, left: Expr, right: Expr): super().__init__("multiply", [left, right]) class Parent(Function): + """Represents getting the parent document reference.""" def __init__(self, value: Expr): super().__init__("parent", [value]) class ReplaceAll(Function): + """Represents replacing all occurrences of a substring.""" def __init__(self, value: Expr, pattern: Expr, replacement: Expr): super().__init__("replace_all", [value, pattern, replacement]) class ReplaceFirst(Function): + """Represents replacing the first occurrence of a substring.""" def __init__(self, value: Expr, pattern: Expr, replacement: Expr): super().__init__("replace_first", [value, pattern, replacement]) class Reverse(Function): + """Represents reversing a string.""" def __init__(self, expr: Expr): super().__init__("reverse", [expr]) class StrConcat(Function): + """Represents 
concatenating multiple strings.""" def __init__(self, *exprs: Expr): super().__init__("str_concat", exprs) class Subtract(Function): + """Represents the subtraction function.""" def __init__(self, left: Expr, right: Expr): super().__init__("subtract", [left, right]) class TimestampAdd(Function): + """Represents adding a duration to a timestamp.""" def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): super().__init__("timestamp_add", [timestamp, unit, amount]) class TimestampSub(Function): + """Represents subtracting a duration from a timestamp.""" def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): super().__init__("timestamp_sub", [timestamp, unit, amount]) class TimestampToUnixMicros(Function): + """Represents converting a timestamp to microseconds since epoch.""" def __init__(self, input: Expr): super().__init__("timestamp_to_unix_micros", [input]) class TimestampToUnixMillis(Function): + """Represents converting a timestamp to milliseconds since epoch.""" def __init__(self, input: Expr): super().__init__("timestamp_to_unix_millis", [input]) class TimestampToUnixSeconds(Function): + """Represents converting a timestamp to seconds since epoch.""" def __init__(self, input: Expr): super().__init__("timestamp_to_unix_seconds", [input]) class ToLower(Function): + """Represents converting a string to lowercase.""" def __init__(self, value: Expr): super().__init__("to_lower", [value]) class ToUpper(Function): + """Represents converting a string to uppercase.""" def __init__(self, value: Expr): super().__init__("to_upper", [value]) class Trim(Function): + """Represents trimming whitespace from a string.""" def __init__(self, expr: Expr): super().__init__("trim", [expr]) class UnixMicrosToTimestamp(Function): + """Represents converting microseconds since epoch to a timestamp.""" def __init__(self, input: Expr): super().__init__("unix_micros_to_timestamp", [input]) class UnixMillisToTimestamp(Function): + """Represents converting milliseconds since epoch to a timestamp.""" def __init__(self, input: Expr): super().__init__("unix_millis_to_timestamp", [input]) class UnixSecondsToTimestamp(Function): + """Represents converting seconds since epoch to a timestamp.""" def __init__(self, input: Expr): super().__init__("unix_seconds_to_timestamp", [input]) class VectorLength(Function): + """Represents getting the length (dimension) of a vector.""" def __init__(self, array: Expr): super().__init__("vector_length", [array]) class Add(Function): + """Represents the addition function.""" def __init__(self, left: Expr, right: Expr): super().__init__("add", [left, right]) class ArrayConcat(Function): + """Represents concatenating multiple arrays.""" def __init__(self, array: Expr, rest: List[Expr]): super().__init__("array_concat", [array] + rest) class ArrayElement(Function): + """Represents accessing an element within an array""" def __init__(self): super().__init__("array_element", []) class ArrayFilter(Function): + """Represents filtering elements from an array based on a condition.""" def __init__(self, array: Expr, filter: "FilterCondition"): super().__init__("array_filter", [array, filter]) class ArrayLength(Function): + """Represents getting the length of an array.""" def __init__(self, array: Expr): super().__init__("array_length", [array]) class ArrayReverse(Function): + """Represents reversing the elements of an array.""" def __init__(self, array: Expr): super().__init__("array_reverse", [array]) class ArrayTransform(Function): + """Represents applying a transformation function 
to each element of an array."""
     def __init__(self, array: Expr, transform: Function):
         super().__init__("array_transform", [array, transform])
 
 
 class ByteLength(Function):
+    """Represents getting the byte length of a string (UTF-8)."""
     def __init__(self, expr: Expr):
         super().__init__("byte_length", [expr])
 
 
 class CharLength(Function):
+    """Represents getting the character length of a string."""
     def __init__(self, expr: Expr):
         super().__init__("char_length", [expr])
 
 
 class CollectionId(Function):
+    """Represents getting the collection ID from a document reference."""
     def __init__(self, value: Expr):
         super().__init__("collection_id", [value])
 
 
 class CosineDistance(Function):
+    """Represents the vector cosine distance function."""
     def __init__(self, vector1: Expr, vector2: Expr):
         super().__init__("cosine_distance", [vector1, vector2])
 
 
 class Accumulator(Function):
-    """A type of expression that takes in many, and results in one value."""
+    """A base class for aggregation functions that operate across multiple inputs."""
 
 
 class Max(Accumulator):
+    """Represents the maximum aggregation function."""
     def __init__(self, value: Expr, distinct: bool=False):
         super().__init__("maximum", [value])
 
 
 class Min(Accumulator):
+    """Represents the minimum aggregation function."""
     def __init__(self, value: Expr, distinct: bool=False):
         super().__init__("minimum", [value])
 
 
 class Sum(Accumulator):
+    """Represents the sum aggregation function."""
     def __init__(self, value: Expr, distinct: bool=False):
         super().__init__("sum", [value])
 
 
 class Avg(Accumulator):
+    """Represents the average aggregation function."""
     def __init__(self, value: Expr, distinct: bool=False):
         super().__init__("avg", [value])
 
 
 class Count(Accumulator):
+    """Represents the count aggregation function."""
     def __init__(self, value: Expr | None = None):
         super().__init__("count", [value] if value else [])
 
 
 class CountIf(Function):
+    """Represents counting the inputs for which a given condition evaluates to true."""
     def __init__(self, value: Expr, distinct: bool=False):
         super().__init__("countif", [value] if value else [])
 
 
 class Selectable(Expr):
-    """Points at something in the database?"""
+    """Base class for expressions that can be selected or aliased in projection stages."""
 
     @abstractmethod
     def _to_map(self):
@@ -528,6 +1328,7 @@ def _to_map(self):
 T = TypeVar('T', bound=Expr)
 
 class ExprWithAlias(Selectable, Generic[T]):
+    """Wraps an expression with an alias."""
     def __init__(self, expr: T, alias: str):
         self.expr = expr
         self.alias = alias
@@ -536,7 +1337,7 @@ def _to_map(self):
         return self.alias, self.expr._to_pb()
 
     def __repr__(self):
-        return f"{self.expr}.as('{self.alias}')"
+        return f"{self.expr}.as_('{self.alias}')"
 
     def _to_pb(self):
         return Value(
@@ -545,13 +1346,29 @@ def _to_pb(self):
 
 class Field(Selectable):
+    """Represents a reference to a field within a document."""
     DOCUMENT_ID = "__name__"
 
     def __init__(self, path: str):
+        """Initializes a Field reference.
+
+        Args:
+            path: The dot-separated path to the field (e.g., "address.city").
+                Use Field.DOCUMENT_ID for the document ID.
+        """
         self.path = path
 
     @staticmethod
     def of(path: str):
+        """Creates a Field reference.
+
+        Args:
+            path: The dot-separated path to the field (e.g., "address.city").
+                Use Field.DOCUMENT_ID for the document ID.
+
+        Returns:
+            A new Field instance.
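+
+        Example (illustration; uses the path format documented above):
+            >>> Field.of("address.city")     # nested map field
+            >>> Field.of(Field.DOCUMENT_ID)  # refer to the document ID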
+ """ return Field(path) def _to_map(self): @@ -581,41 +1398,49 @@ def __init__(self, array: Expr, element: Expr): class ArrayContainsAll(FilterCondition): + """Represents checking if an array contains all specified elements.""" def __init__(self, array: Expr, elements: List[Expr]): super().__init__("array_contains_all", [array, ListOfExprs(elements)]) class ArrayContainsAny(FilterCondition): + """Represents checking if an array contains any of the specified elements.""" def __init__(self, array: Expr, elements: List[Expr]): super().__init__("array_contains_any", [array, ListOfExprs(elements)]) class EndsWith(FilterCondition): + """Represents checking if a string ends with a specific postfix.""" def __init__(self, expr: Expr, postfix: Expr): super().__init__("ends_with", [expr, postfix]) class Eq(FilterCondition): + """Represents the equality comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("eq", [left, right if right else Constant(None)]) class Exists(FilterCondition): + """Represents checking if a field exists.""" def __init__(self, expr: Expr): super().__init__("exists", [expr]) class Gt(FilterCondition): + """Represents the greater than comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("gt", [left, right if right else Constant(None)]) class Gte(FilterCondition): + """Represents the greater than or equal to comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("gte", [left, right if right else Constant(None)]) class If(FilterCondition): + """Represents a conditional expression (if-then-else).""" def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Expr): super().__init__( "if", [condition, true_expr, false_expr if false_expr else Constant(None)] @@ -623,65 +1448,78 @@ def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Ex class In(FilterCondition): + """Represents checking if an expression's value is within a list of values.""" def __init__(self, left: Expr, others: List[Expr]): super().__init__("in", [left, ListOfExprs(others)]) class IsNaN(FilterCondition): + """Represents checking if a numeric value is NaN.""" def __init__(self, value: Expr): super().__init__("is_nan", [value]) class Like(FilterCondition): + """Represents a case-sensitive wildcard string comparison.""" def __init__(self, expr: Expr, pattern: Expr): super().__init__("like", [expr, pattern]) class Lt(FilterCondition): + """Represents the less than comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("lt", [left, right if right else Constant(None)]) class Lte(FilterCondition): + """Represents the less than or equal to comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("lte", [left, right if right else Constant(None)]) class Neq(FilterCondition): + """Represents the inequality comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("neq", [left, right if right else Constant(None)]) class Not(FilterCondition): + """Represents the logical NOT of a filter condition.""" def __init__(self, condition: Expr): super().__init__("not", [condition]) class Or(FilterCondition): + """Represents the logical OR of multiple filter conditions.""" def __init__(self, *conditions: "FilterCondition"): super().__init__("or", conditions) class RegexContains(FilterCondition): + """Represents checking if a string contains a substring matching a regex.""" def __init__(self, expr: Expr, regex: Expr): super().__init__("regex_contains", [expr, regex]) 
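
As a quick illustration (a sketch, assuming the module import path used in the docstring
examples above), the `FilterCondition` classes in this hunk are normally built through the
helper methods on `Expr` rather than constructed directly:

```python
from google.cloud.firestore_v1.pipeline_expressions import Field, Or

# Helper methods on Expr build the FilterCondition tree without naming
# the classes directly:
condition = Or(
    Field.of("genre").eq("Science Fiction"),     # builds an Eq condition
    Field.of("tags").array_contains("classic"),  # builds an ArrayContains condition
)

# A pipeline stage can then consume the condition, e.g.:
# client.pipeline().collection("books").where(condition)
```
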
class RegexMatch(FilterCondition):
+    """Represents checking if a string fully matches a regex."""
     def __init__(self, expr: Expr, regex: Expr):
         super().__init__("regex_match", [expr, regex])
 
 
 class StartsWith(FilterCondition):
+    """Represents checking if a string starts with a specific prefix."""
     def __init__(self, expr: Expr, prefix: Expr):
         super().__init__("starts_with", [expr, prefix])
 
 
 class StrContains(FilterCondition):
+    """Represents checking if a string contains a specific substring."""
     def __init__(self, expr: Expr, substring: Expr):
         super().__init__("str_contains", [expr, substring])
 
 
 class Xor(FilterCondition):
+    """Represents the logical XOR of multiple filter conditions."""
     def __init__(self, conditions: List["FilterCondition"]):
         super().__init__("xor", conditions)

From 7e70022af9afad038c76dcd563dec0ce45cf1d37 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 27 Mar 2025 15:11:40 -0700
Subject: [PATCH 057/131] improved ordering

---
 google/cloud/firestore_v1/pipeline_expressions.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index 9e22c73f2..325a16843 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -35,9 +35,18 @@ class Direction(Enum):
         ASCENDING = "ascending"
         DESCENDING = "descending"
 
-    def __init__(self, expr, order_dir: Direction | str):
+    def __init__(self, expr, order_dir: Direction | str=Direction.ASCENDING):
+        """
+        Initializes an Ordering instance.
+
+        Args:
+            expr (Expr | str): The expression or field path string to sort by.
+                If a string is provided, it's treated as a field path.
+            order_dir (Direction | str): The direction to sort in.
+                Defaults to ascending.
+        """
         self.expr = expr if isinstance(expr, Expr) else Field.of(expr)
-        self.order_dir = Ordering.Direction[order_dir] if isinstance(order_dir, str) else order_dir
+        self.order_dir = Ordering.Direction[order_dir.upper()] if isinstance(order_dir, str) else order_dir
 
     def __repr__(self):
         if self.order_dir is Ordering.Direction.ASCENDING:

From 98887a0b45562f273acb8deb69defe1526c46ccf Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 27 Mar 2025 15:30:51 -0700
Subject: [PATCH 058/131] added docstrings to pipeline classes

---
 google/cloud/firestore_v1/async_pipeline.py |  31 ++
 google/cloud/firestore_v1/base_pipeline.py  | 460 +++++++++++++++++++-
 google/cloud/firestore_v1/pipeline.py       |  36 +-
 3 files changed, 516 insertions(+), 11 deletions(-)

diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py
index 34b9eb45b..7338ae971 100644
--- a/google/cloud/firestore_v1/async_pipeline.py
+++ b/google/cloud/firestore_v1/async_pipeline.py
@@ -25,7 +25,38 @@
 
 
 class Pipeline(_BasePipeline):
+    """
+    Pipelines allow for complex data transformations and queries involving
+    multiple stages like filtering, projection, aggregation, and vector search.
+
+    This class extends `_BasePipeline` and provides methods to execute the
+    defined pipeline stages using an asynchronous `AsyncClient`.
+
+    Usage Example:
+        >>> import asyncio
+        >>> from google.cloud.firestore_v1.pipeline_expressions import Field, gt
+        >>>
+        >>> async def run_pipeline():
+        ...     client = AsyncClient(...)
+        ...     pipeline = (
+        ...         client.pipeline()
+        ...         .collection("books")
+        ...         .where(gt(Field.of("published"), 1980))
+        ...         .select("title", "author")
+        ...     )
+        ...     async for result in pipeline.execute_async():
+        ...
print(result) + >>> + >>> asyncio.run(run_pipeline()) + + Use `AsyncClient.pipeline()` to create instances of this class. + """ def __init__(self, client:AsyncClient, *stages: stages.Stage): + """ + Initializes an asynchronous Pipeline. + + Args: + client: The asynchronous `AsyncClient` instance to use for execution. + *stages: Initial stages for the pipeline. + """ super().__init__(*stages) self._client = client diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index bdfb1a8ed..112c579ed 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -30,7 +30,19 @@ class _BasePipeline: + """ + Base class for building Firestore data transformation and query pipelines. + + This class is not intended to be instantiated directly. Use `Client.pipeline()` + or `AsyncClient.pipeline()` to create pipeline instances. + """ def __init__(self, *stages: stages.Stage): + """ + Initializes a new pipeline with the given stages. + + Args: + *stages: Initial stages for the pipeline. + """ self.stages = list(stages) def __repr__(self): @@ -46,18 +58,130 @@ def _to_pb(self) -> StructuredPipeline_pb: return StructuredPipeline_pb(pipeline={"stages":[s._to_pb() for s in self.stages]}) def add_fields(self, *fields: Selectable) -> Self: + """ + Adds new fields to outputs from previous stages. + + This stage allows you to compute values on-the-fly based on existing data + from previous stages or constants. You can use this to create new fields + or overwrite existing ones (if there is name overlap). + + The added fields are defined using `Selectable` expressions, which can be: + - `Field`: References an existing document field. + - `Function`: Performs a calculation using functions like `add`, + `multiply` with assigned aliases using `Expr.as_()`. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field, add + >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = pipeline.add_fields( + ... Field.of("rating").as_("bookRating"), # Rename 'rating' to 'bookRating' + ... add(5, Field.of("quantity")).as_("totalCost") # Calculate 'totalCost' + ... ) + + Args: + *fields: The fields to add to the documents, specified as `Selectable` + expressions. + + Returns: + A reference to this pipeline instance. Used for method chaining + """ self.stages.append(stages.AddFields(*fields)) return self def remove_fields(self, *fields: Field | str) -> Self: + """ + Removes fields from outputs of previous stages. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> pipeline = firestore.pipeline().collection("books") + >>> # Remove by name + >>> pipeline = pipeline.remove_fields("rating", "cost") + >>> # Remove by Field object + >>> pipeline = pipeline.remove_fields(Field.of("rating"), Field.of("cost")) + + + Args: + *fields: The fields to remove, specified as field names (str) or + `Field` objects. + + Returns: + A reference to this pipeline instance. Used for method chaining + """ self.stages.append(stages.RemoveFields(*fields)) return self def select(self, *selections: str | Selectable) -> Self: + """ + Selects or creates a set of fields from the outputs of previous stages. + + The selected fields are defined using `Selectable` expressions or field names: + - `Field`: References an existing document field. + - `Function`: Represents the result of a function with an assigned alias + name using `Expr.as_()`. + - `str`: The name of an existing field. 
+ + If no selections are provided, the output of this stage is empty. Use + `add_fields()` instead if only additions are desired. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field, to_upper + >>> pipeline = firestore.pipeline().collection("books") + >>> # Select by name + >>> pipeline = pipeline.select("name", "address") + >>> # Select using Field and Function expressions + >>> pipeline = pipeline.select( + ... Field.of("name"), + ... Field.of("address").to_upper().as_("upperAddress"), + ... ) + + Args: + *selections: The fields to include in the output documents, specified as + field names (str) or `Selectable` expressions. + + Returns: + A reference to this pipeline instance. Used for method chaining + """ self.stages.append(stages.Select(*selections)) return self def where(self, condition: FilterCondition) -> Self: + """ + Filters the documents from previous stages to only include those matching + the specified `FilterCondition`. + + This stage allows you to apply conditions to the data, similar to a "WHERE" + clause in SQL. You can filter documents based on their field values, using + implementations of `FilterCondition`, typically including but not limited to: + - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc. + - logical operators: `and_`, `or_`, `not_`, etc. + - advanced functions: `regex_match`, `array_contains`, etc. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field, and_, gt, eq + >>> pipeline = firestore.pipeline().collection("books") + >>> # Using static functions + >>> pipeline = pipeline.where( + ... and_( + ... gt(Field.of("rating"), 4.0), # Filter for ratings > 4.0 + ... eq(Field.of("genre"), "Science Fiction") # Filter for genre + ... ) + ... ) + >>> # Using methods on expressions + >>> pipeline = pipeline.where( + ... and_( + ... Field.of("rating").gt(4.0), + ... Field.of("genre").eq("Science Fiction") + ... ) + ... ) + + + Args: + condition: The `FilterCondition` to apply. + + Returns: + A reference to this pipeline instance. Used for method chaining + """ self.stages.append(stages.Where(condition)) return self @@ -66,13 +190,82 @@ def find_nearest( field: str | Expr, vector: Sequence[float] | "Vector", distance_measure: "DistanceMeasure", - limit: int | None, options: Optional[stages.FindNearestOptions] = None, ) -> Self: + """ + Performs vector distance (similarity) search with given parameters on the + stage inputs. + + This stage adds a "nearest neighbor search" capability to your pipelines. + Given a field or expression that evaluates to a vector and a target vector, + this stage will identify and return the inputs whose vector is closest to + the target vector, using the specified distance measure and options. + + Example: + >>> from google.cloud.firestore_v1.base_vector_query import DistanceMeasure + >>> from google.cloud.firestore_v1.pipeline_stages import FindNearestOptions + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> + >>> target_vector = [0.1, 0.2, 0.3] + >>> pipeline = firestore.pipeline().collection("books") + >>> # Find using field name + >>> pipeline = pipeline.find_nearest( + ... "topicVectors", + ... target_vector, + ... DistanceMeasure.COSINE, + ... options=FindNearestOptions(limit=10, distance_field="distance") + ... ) + >>> # Find using Field expression + >>> pipeline = pipeline.find_nearest( + ... Field.of("topicVectors"), + ... target_vector, + ... DistanceMeasure.COSINE, + ... 
options=FindNearestOptions(limit=10, distance_field="distance")
+            ... )
+
+        Args:
+            field: The name of the field (str) or an expression (`Expr`) that
+                evaluates to the vector data. This field should store vector values.
+            vector: The target vector (sequence of floats or `Vector` object) to
+                compare against.
+            distance_measure: The distance measure (`DistanceMeasure`) to use
+                (e.g., `DistanceMeasure.COSINE`, `DistanceMeasure.EUCLIDEAN`).
+            options: Configuration options (`FindNearestOptions`) for the search,
+                such as the limit on the number of neighbors to return and the
+                name of the output distance field.
+
+        Returns:
+            A reference to this pipeline instance. Used for method chaining
+        """
         self.stages.append(stages.FindNearest(field, vector, distance_measure, options))
         return self
 
     def sort(self, *orders: stages.Ordering) -> Self:
+        """
+        Sorts the documents from previous stages based on one or more `Ordering` criteria.
+
+        This stage allows you to order the results of your pipeline. You can specify
+        multiple `Ordering` instances to sort by multiple fields or expressions in
+        ascending or descending order. If documents have the same value for a sorting
+        criterion, the next specified ordering will be used. If all orderings result
+        in equal comparison, the documents are considered equal and the relative order
+        is unspecified.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = firestore.pipeline().collection("books")
+            >>> # Sort books by rating descending, then title ascending
+            >>> pipeline = pipeline.sort(
+            ...     Field.of("rating").descending(),
+            ...     Field.of("title").ascending()
+            ... )
+
+        Args:
+            *orders: One or more `Ordering` instances specifying the sorting criteria.
+
+        Returns:
+            A reference to this pipeline instance. Used for method chaining
+        """
         self.stages.append(stages.Sort(*orders))
         return self
 
     def replace(
@@ -81,14 +274,100 @@ def replace(
         field: Selectable,
         mode: stages.Replace.Mode = stages.Replace.Mode.FULL_REPLACE,
     ) -> Self:
+        """
+        Replaces the entire document content with the value of a specified field,
+        typically a map.
+
+        This stage allows you to emit a map value as the new document structure.
+        Each key of the map becomes a field in the output document, containing the
+        corresponding value.
+
+        Example:
+            Input document:
+            ```json
+            {
+                "name": "John Doe Jr.",
+                "parents": {
+                    "father": "John Doe Sr.",
+                    "mother": "Jane Doe"
+                }
+            }
+            ```
+
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = firestore.pipeline().collection("people")
+            >>> # Emit the 'parents' map as the document
+            >>> pipeline = pipeline.replace(Field.of("parents"))
+
+            Output document:
+            ```json
+            {
+                "father": "John Doe Sr.",
+                "mother": "Jane Doe"
+            }
+            ```
+
+        Args:
+            field: The `Selectable` field containing the map whose content will
+                replace the document.
+            mode: The replacement mode (defaults to `stages.Replace.Mode.FULL_REPLACE`).
+
+        Returns:
+            A reference to this pipeline instance. Used for method chaining
+        """
         self.stages.append(stages.Replace(field, mode))
         return self
 
     def sample(self, limit_or_options: int | SampleOptions) -> Self:
+        """
+        Performs a pseudo-random sampling of the documents from the previous stage.
+
+        This stage filters documents pseudo-randomly.
+        - If an `int` limit is provided, it specifies the maximum number of documents
+          to emit. If fewer documents are available, all are passed through.
+        - If `SampleOptions` are provided, they specify how sampling is performed
+          (e.g., by document count or percentage).
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_stages import SampleOptions
+            >>> pipeline = firestore.pipeline().collection("books")
+            >>> # Sample 10 books, if available.
+            >>> pipeline = pipeline.sample(10)
+            >>> # Sample 50% of books.
+            >>> pipeline = pipeline.sample(SampleOptions(n=50, mode=SampleOptions.Mode.PERCENTAGE))
+
+
+        Args:
+            limit_or_options: Either an integer specifying the maximum number of
+                documents to sample, or a `SampleOptions` object.
+
+        Returns:
+            A reference to this pipeline instance. Used for method chaining
+        """
         self.stages.append(stages.Sample(limit_or_options))
         return self
 
     def union(self, other: Self) -> Self:
+        """
+        Performs a union of all documents from this pipeline and another pipeline,
+        including duplicates.
+
+        This stage passes through documents from the previous stage of this pipeline,
+        and also passes through documents from the previous stage of the `other`
+        pipeline provided. The order of documents emitted from this stage is undefined.
+
+        Example:
+            >>> books_pipeline = firestore.pipeline().collection("books")
+            >>> magazines_pipeline = firestore.pipeline().collection("magazines")
+            >>> # Emit documents from both collections
+            >>> combined_pipeline = books_pipeline.union(magazines_pipeline)
+
+        Args:
+            other: The other `Pipeline` whose results will be unioned with this one.
+
+        Returns:
+            A reference to this pipeline instance. Used for method chaining
+        """
         self.stages.append(stages.Union(other))
         return self
 
     def unnest(
@@ -97,18 +376,124 @@ def unnest(
         field_name: str,
         options: Optional[stages.UnnestOptions] = None,
     ) -> Self:
+        """
+        Produces a document for each element in an array field from the previous stage document.
+
+        For each input document, this stage emits zero or more augmented documents.
+        It takes an array field specified by `field_name`. For each element in that
+        array, it produces an output document where the original array field is
+        replaced by the current element's value.
+
+        Optionally, `UnnestOptions` can specify a field to store the original index
+        of the element within the array.
+
+        Example:
+            Input document:
+            ```json
+            { "title": "The Hitchhiker's Guide", "tags": [ "comedy", "sci-fi" ], ... }
+            ```
+
+            >>> from google.cloud.firestore_v1.pipeline_stages import UnnestOptions
+            >>> pipeline = firestore.pipeline().collection("books")
+            >>> # Emit a document for each tag
+            >>> pipeline = pipeline.unnest("tags")
+            >>> # Emit a document for each tag, including the index
+            >>> pipeline = pipeline.unnest("tags", options=UnnestOptions(index_field="tagIndex"))
+
+
+            Output documents (without options):
+            ```json
+            { "title": "The Hitchhiker's Guide", "tags": "comedy", ... }
+            { "title": "The Hitchhiker's Guide", "tags": "sci-fi", ... }
+            ```
+
+            Output documents (with index_field="tagIndex"):
+            ```json
+            { "title": "The Hitchhiker's Guide", "tags": "comedy", "tagIndex": 0, ... }
+            { "title": "The Hitchhiker's Guide", "tags": "sci-fi", "tagIndex": 1, ... }
+            ```
+
+        Args:
+            field_name: The name of the field containing the array to unnest.
+            options: Optional `UnnestOptions` to configure behavior, like adding an index field.
+
+        Returns:
+            A reference to this pipeline instance. Used for method chaining
+        """
         self.stages.append(stages.Unnest(field_name, options))
         return self
 
     def generic_stage(self, name: str, *params: Expr) -> Self:
+        """
+        Adds a generic, named stage to the pipeline with specified parameters.
+
+        This method provides a flexible way to extend the pipeline's functionality
+        by adding custom stages. 
Each generic stage is defined by a unique `name`
+        and a set of `params` that control its behavior.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> # Assume we don't have a built-in "where" stage
+            >>> pipeline = firestore.pipeline().collection("books")
+            >>> pipeline = pipeline.generic_stage("where", Field.of("published").lt(900))
+            >>> pipeline = pipeline.select("title", "author")
+
+        Args:
+            name: The name of the generic stage.
+            *params: A sequence of `Expr` objects representing the parameters for the stage.
+
+        Returns:
+            A reference to this pipeline instance. Used for method chaining
+        """
         self.stages.append(stages.GenericStage(name, *params))
         return self
 
     def offset(self, offset: int) -> Self:
+        """
+        Skips the first `offset` number of documents from the results of previous stages.
+
+        This stage is useful for implementing pagination, allowing you to retrieve
+        results in chunks. It is typically used in conjunction with `limit()` to
+        control the size of each page.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = firestore.pipeline().collection("books")
+            >>> # Retrieve the second page of 20 results (assuming sorted)
+            >>> pipeline = pipeline.sort(Field.of("published").descending())
+            >>> pipeline = pipeline.offset(20) # Skip the first 20 results
+            >>> pipeline = pipeline.limit(20) # Take the next 20 results
+
+        Args:
+            offset: The non-negative number of documents to skip.
+
+        Returns:
+            A reference to this pipeline instance. Used for method chaining
+        """
         self.stages.append(stages.Offset(offset))
         return self
 
     def limit(self, limit: int) -> Self:
+        """
+        Limits the maximum number of documents returned by previous stages to `limit`.
+
+        This stage is useful for controlling the size of the result set, often used for:
+        - **Pagination:** In combination with `offset()` to retrieve specific pages.
+        - **Top-N queries:** To get a limited number of results after sorting.
+        - **Performance:** To prevent excessive data transfer.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = firestore.pipeline().collection("books")
+            >>> # Limit the results to the top 10 highest-rated books
+            >>> pipeline = pipeline.sort(Field.of("rating").descending())
+            >>> pipeline = pipeline.limit(10)
+
+        Args:
+            limit: The non-negative maximum number of documents to return.
+
+        Returns:
+            A reference to this pipeline instance. Used for method chaining
+        """
         self.stages.append(stages.Limit(limit))
         return self
 
     def aggregate(
@@ -117,11 +502,80 @@ def aggregate(
         self,
         *accumulators: ExprWithAlias[Accumulator],
         groups: Sequence[str | Selectable] = (),
     ) -> Self:
+        """
+        Performs aggregation operations on the documents from previous stages,
+        optionally grouped by specified fields or expressions.
+
+        This stage allows you to calculate aggregate values (like sum, average, count,
+        min, max) over a set of documents.
+
+        - **Accumulators:** Define the aggregation calculations using `Accumulator`
+          expressions (e.g., `sum()`, `avg()`, `count()`, `min()`, `max()`) combined
+          with `as_()` to name the result field.
+        - **Groups:** Optionally specify fields (by name or `Selectable`) to group
+          the documents by. Aggregations are then performed within each distinct group.
+          If no groups are provided, the aggregation is performed over the entire input.
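+
+        In SQL terms, the accumulators play the role of aggregate functions and
+        `groups` plays the role of a GROUP BY clause.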
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field, avg, count_all
+            >>> pipeline = firestore.pipeline().collection("books")
+            >>> # Calculate the average rating and total count for all books
+            >>> pipeline = pipeline.aggregate(
+            ...     avg(Field.of("rating")).as_("averageRating"),
+            ...     count_all().as_("totalBooks")
+            ... )
+            >>> # Calculate the average rating for each genre
+            >>> pipeline = pipeline.aggregate(
+            ...     avg(Field.of("rating")).as_("avg_rating"),
+            ...     groups=["genre"] # Group by the 'genre' field
+            ... )
+            >>> # Calculate the count for each author, grouping by Field object
+            >>> pipeline = pipeline.aggregate(
+            ...     count_all().as_("bookCount"),
+            ...     groups=[Field.of("author")]
+            ... )
+
+
+        Args:
+            *accumulators: One or more `ExprWithAlias[Accumulator]` expressions defining
+                the aggregations to perform and their output names.
+            groups: An optional sequence of field names (str) or `Selectable`
+                expressions to group by before aggregating.
+
+        Returns:
+            A reference to this pipeline instance. Used for method chaining
+        """
         self.stages.append(stages.Aggregate(*accumulators, groups=groups))
         return self
 
     def distinct(self, *fields: str | Selectable) -> Self:
-        self.stages.append(stages.Distinct(*fields))
-        return self
+        """
+        Returns documents with distinct combinations of values for the specified
+        fields or expressions.
+
+        This stage filters the results from previous stages to include only one
+        document for each unique combination of values in the specified `fields`.
+        The output documents contain only the fields specified in the `distinct` call.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = firestore.pipeline().collection("books")
+            >>> # Get a list of unique genres (output has only 'genre' field)
+            >>> pipeline = pipeline.distinct("genre")
+            >>> # Get unique combinations of author (uppercase) and genre
+            >>> pipeline = pipeline.distinct(
+            ...     Field.of("author").to_upper().as_("authorUpper"),
+            ...     Field.of("genre")
+            ... )
+
+        Args:
+            *fields: Field names (str) or `Selectable` expressions to consider when
+                determining distinct value combinations. The output will only
+                contain these fields/expressions.
+
+        Returns:
+            A reference to this pipeline instance. Used for method chaining
+        """
+        self.stages.append(stages.Distinct(*fields))
+        return self

diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py
index b545fcc32..92b00d96c 100644
--- a/google/cloud/firestore_v1/pipeline.py
+++ b/google/cloud/firestore_v1/pipeline.py
@@ -25,7 +25,34 @@
 
 
 class Pipeline(_BasePipeline):
+    """
+    Pipelines allow for complex data transformations and queries involving
+    multiple stages like filtering, projection, aggregation, and vector search.
+
+    Usage Example:
+        >>> from google.cloud.firestore_v1.pipeline_expressions import Field, gt
+        >>>
+        >>> client = Client(...)
+        >>> pipeline = (
+        ...     client.pipeline()
+        ...     .collection("books")
+        ...     .where(gt(Field.of("published"), 1980))
+        ...     .select("title", "author")
+        ... )
+        >>> for result in pipeline.execute():
+        ...     print(result)
+
+    Use `Client.pipeline()` to create instances of this class.
+    """
     def __init__(self, client:Client, *stages: stages.Stage):
+        """
+        Initializes a Pipeline.
+
+        Args:
+            client: The `Client` instance to use for execution.
+            *stages: Initial stages for the pipeline.
+        """
         super().__init__(*stages)
         self._client = client
 
@@ -37,11 +64,4 @@ def execute(self) -> Iterable["ExecutePipelineResponse"]:
             read_time=datetime.datetime.now(),
         )
         results = self._client._firestore_api.execute_pipeline(request)
-        return results
-
-    async def execute_async(self) -> AsyncIterable["ExecutePipelineResponse"]:
-        from google.cloud.firestore_v1.async_client import AsyncClient
-        if not isinstance(self._client, AsyncClient):
-            raise TypeError("execute_async requires AsyncClient")
-        # TODO
-        raise NotImplementedError
+        return results
\ No newline at end of file

From 549c59064c7a79b3085fc145af745c60354a02ff Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 27 Mar 2025 15:44:22 -0700
Subject: [PATCH 059/131] added docstrings

---
 google/cloud/firestore_v1/base_pipeline.py   | 20 +++++++---
 google/cloud/firestore_v1/pipeline_stages.py | 39 +++++++++++++++++++-
 2 files changed, 53 insertions(+), 6 deletions(-)

diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py
index 112c579ed..7e48ef44e 100644
--- a/google/cloud/firestore_v1/base_pipeline.py
+++ b/google/cloud/firestore_v1/base_pipeline.py
@@ -384,8 +384,7 @@ def unnest(
         array, it produces an output document where the original array field is
         replaced by the current element's value.
 
-        Optionally, `UnnestOptions` can specify a field to store the original index
-        of the element within the array.
+
 
         Example:
             Input document:
@@ -397,9 +396,6 @@ def unnest(
         >>> pipeline = firestore.pipeline().collection("books")
         >>> # Emit a document for each tag
         >>> pipeline = pipeline.unnest("tags")
-        >>> # Emit a document for each tag, including the index
-        >>> pipeline = pipeline.unnest("tags", options=UnnestOptions(index_field="tagIndex"))
-
 
         Output documents (without options):
         ```json
@@ -407,6 +403,20 @@ def unnest(
         { "title": "The Hitchhiker's Guide", "tags": "sci-fi", ... }
         ```
 
+        Optionally, `UnnestOptions` can specify a field to store the original index
+        of the element within the array.
+
+        Example:
+            Input document:
+            ```json
+            { "title": "The Hitchhiker's Guide", "tags": [ "comedy", "sci-fi" ], ... }
+            ```
+
+            >>> from google.cloud.firestore_v1.pipeline_stages import UnnestOptions
+            >>> pipeline = firestore.pipeline().collection("books")
+            >>> # Emit a document for each tag, including the index
+            >>> pipeline = pipeline.unnest("tags", options=UnnestOptions(index_field="tagIndex"))
+
         Output documents (with index_field="tagIndex"):
         ```json
         { "title": "The Hitchhiker's Guide", "tags": "comedy", "tagIndex": 0, ... }

diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py
index 0ebf6aa8a..8886955e8 100644
--- a/google/cloud/firestore_v1/pipeline_stages.py
+++ b/google/cloud/firestore_v1/pipeline_stages.py
@@ -38,6 +38,13 @@
 
 
 class FindNearestOptions:
+    """Options for configuring the `FindNearest` pipeline stage.
+
+    Attributes:
+        limit (Optional[int]): The maximum number of nearest neighbors to return.
+        distance_field (Optional[Field]): An optional field to store the calculated
+            distance in the output documents.
+    """
     def __init__(
         self,
         limit: Optional[int] = None,
@@ -48,11 +55,23 @@ def __init__(
 
 
 class UnnestOptions:
+    """Options for configuring the `Unnest` pipeline stage.
+
+    Attributes:
+        index_field (str): The name of the field to add to each output document,
+            storing the original 0-based index of the element within the array.
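+
+    Example (mirrors the `unnest` usage shown in base_pipeline.py):
+        UnnestOptions(index_field="tagIndex")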
+ """ def __init__(self, index_field: str): self.index_field = index_field class Stage: + """Base class for all pipeline stages. + + Each stage represents a specific operation (e.g., filtering, sorting, + transforming) within a Firestore pipeline. Subclasses define the specific + arguments and behavior for each operation. + """ def __init__(self, custom_name: Optional[str] = None): self.name = custom_name or type(self).__name__.lower() @@ -73,6 +92,7 @@ def __repr__(self): class AddFields(Stage): + """Adds new fields to outputs from previous stages.""" def __init__(self, *fields: Selectable): super().__init__("add_fields") self.fields = list(fields) @@ -81,6 +101,7 @@ def _pb_args(self): return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}})] class Aggregate(Stage): + """Performs aggregation operations, optionally grouped.""" def __init__( self, *extra_accumulators: ExprWithAlias[Accumulator], @@ -108,6 +129,7 @@ def __repr__(self): class Collection(Stage): + """Specifies a collection as the initial data source.""" def __init__(self, path: str): super().__init__() if not path.startswith("/"): @@ -118,6 +140,7 @@ def _pb_args(self): return [Value(reference_value=self.path)] class CollectionGroup(Stage): + """Specifies a collection group as the initial data source.""" def __init__(self, collection_id: str): super().__init__("collection_group") self.collection_id = collection_id @@ -127,10 +150,12 @@ def _pb_args(self): class Database(Stage): + """Specifies the default database as the initial data source.""" def __init__(self): super().__init__() class Distinct(Stage): + """Returns documents with distinct combinations of specified field values.""" def __init__(self, *fields: str | Selectable): super().__init__() self.fields: list[Selectable] = [Field(f) if isinstance(f, str) else f for f in fields] @@ -140,6 +165,7 @@ def _pb_args(self) -> list[Value]: class Documents(Stage): + """Specifies specific documents as the initial data source.""" def __init__(self, *paths: str): super().__init__() self.paths = paths @@ -154,6 +180,7 @@ def _pb_args(self): class FindNearest(Stage): + """Performs vector distance (similarity) search.""" def __init__( self, field: str | Expr, @@ -183,6 +210,7 @@ def _pb_options(self) -> dict[str, Value]: return options class GenericStage(Stage): + """Represents a generic, named stage with parameters.""" def __init__(self, name: str, *params: Expr | Value): super().__init__(name) self.params: list[Value] = [p._to_pb() if isinstance(p, Expr) else p for p in params] @@ -192,6 +220,7 @@ def _pb_args(self): class Limit(Stage): + """Limits the maximum number of documents returned.""" def __init__(self, limit: int): super().__init__() self.limit = limit @@ -201,6 +230,7 @@ def _pb_args(self): class Offset(Stage): + """Skips a specified number of documents.""" def __init__(self, offset: int): super().__init__() self.offset = offset @@ -210,6 +240,7 @@ def _pb_args(self): class RemoveFields(Stage): + """Removes specified fields from outputs.""" def __init__(self, *fields: str | Field): super().__init__("remove_fields") self.fields = [Field(f) if isinstance(f, str) else f for f in fields] @@ -219,6 +250,7 @@ def _pb_args(self) -> list[Value]: class Replace(Stage): + """Replaces the document content with the value of a specified field.""" class Mode(Enum): FULL_REPLACE = "full_replace" MERGE_PREFER_NEXT = "merge_prefer_nest" @@ -234,7 +266,7 @@ def _pb_args(self): class Sample(Stage): - + """Performs pseudo-random sampling of documents.""" 
def __init__(self, limit_or_options: int | SampleOptions): super().__init__() if isinstance(limit_or_options, int): @@ -248,6 +280,7 @@ def _pb_args(self): class Select(Stage): + """Selects or creates a set of fields.""" def __init__(self, *selections: str | Selectable): super().__init__() self.projections = [Field(s) if isinstance(s, str) else s for s in selections] @@ -257,6 +290,7 @@ def _pb_args(self) -> list[Value]: class Sort(Stage): + """Sorts documents based on specified criteria.""" def __init__(self, *orders: "Ordering"): super().__init__() self.orders = list(orders) @@ -266,6 +300,7 @@ def _pb_args(self): class Union(Stage): + """Performs a union of documents from two pipelines.""" def __init__(self, other: Pipeline): super().__init__() self.other = other @@ -275,6 +310,7 @@ def _pb_args(self): class Unnest(Stage): + """Produces a document for each element in an array field.""" def __init__(self, field: Field | str, options: Optional["UnnestOptions"] = None): super().__init__() self.field: Field = Field(field) if isinstance(field, str) else field @@ -291,6 +327,7 @@ def _pb_options(self): class Where(Stage): + """Filters documents based on a specified condition.""" def __init__(self, condition: FilterCondition): super().__init__() self.condition = condition From 376015c1bb57325baa5cf1a5f4832a92d819566b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 27 Mar 2025 16:12:33 -0700 Subject: [PATCH 060/131] fixed sample_options --- google/cloud/firestore_v1/base_pipeline.py | 3 +- .../firestore_v1/pipeline_expressions.py | 36 +++++++++++++++---- google/cloud/firestore_v1/pipeline_stages.py | 7 ++-- tests/system/pipeline_e2e.yaml | 2 +- 4 files changed, 38 insertions(+), 10 deletions(-) diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 7e48ef44e..c20b715e9 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -333,8 +333,9 @@ def sample(self, limit_or_options: int | SampleOptions) -> Self: >>> pipeline = firestore.pipeline().collection("books") >>> # Sample 10 books, if available. >>> pipeline = pipeline.sample(10) + >>> pipeline = pipeline.sample(SampleOptions.doc_limit(10)) >>> # Sample 50% of books. 
- >>> pipeline = pipeline.sample(SampleOptions(n=50, mode=SampleOptions.Mode.PERCENTAGE))
+ >>> pipeline = pipeline.sample(SampleOptions.percentage(0.5))


 Args:
diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index 325a16843..92bc5d55a 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -65,18 +65,42 @@ def _to_pb(self) -> Value:
 }
 )

-@dataclass
 class SampleOptions:
 """Options for the 'sample' pipeline stage."""

 class Mode(Enum):
 DOCUMENTS = "documents"
- PERCENTAGE = "percent"
+ PERCENT = "percent"

- n: int
- mode: Mode
+ def __init__(self, value: int | float, mode:Mode | str):
+ self.value = value
+ self.mode = SampleOptions.Mode[mode.upper()] if isinstance(mode, str) else mode

- def __post_init__(self):
- self.mode = SampleOptions.Mode(self.mode) if isinstance(self.mode, str) else self.mode
+ def __repr__(self):
+ if self.mode == SampleOptions.Mode.DOCUMENTS:
+ mode_str = "doc_limit"
+ else:
+ mode_str = "percentage"
+ return f"SampleOptions.{mode_str}({self.value})"
+
+ @staticmethod
+ def doc_limit(value:int):
+ """
+ Sample a set number of documents
+
+ Args:
+ value: number of documents to sample
+ """
+ return SampleOptions(value, mode=SampleOptions.Mode.DOCUMENTS)
+
+ @staticmethod
+ def percentage(value:float):
+ """
+ Sample a percentage of documents
+
+ Args:
+ value: percentage of documents to return
+ """
+ return SampleOptions(value, mode=SampleOptions.Mode.PERCENT)

 class Expr(ABC):
 """Represents an expression that can be evaluated to a value within the
diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py
index 8886955e8..0de1814d9 100644
--- a/google/cloud/firestore_v1/pipeline_stages.py
+++ b/google/cloud/firestore_v1/pipeline_stages.py
@@ -270,13 +270,16 @@
 def __init__(self, limit_or_options: int | SampleOptions):
 super().__init__()
 if isinstance(limit_or_options, int):
- options = SampleOptions(limit_or_options, SampleOptions.Mode.DOCUMENTS)
+ options = SampleOptions.doc_limit(limit_or_options)
 else:
 options = limit_or_options
 self.options: SampleOptions = options

 def _pb_args(self):
- return [Value(integer_value=self.options.n), Value(string_value=self.options.mode.value)]
+ if self.options.mode == SampleOptions.Mode.DOCUMENTS:
+ return [Value(integer_value=self.options.value), Value(string_value="documents")]
+ else:
+ return [Value(double_value=self.options.value), Value(string_value="percent")]


 class Select(Stage):
diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml
index d88461add..f3162d540 100644
--- a/tests/system/pipeline_e2e.yaml
+++ b/tests/system/pipeline_e2e.yaml
@@ -922,7 +922,7 @@ tests:
 - Collection: books
 - Sample:
 - SampleOptions:
- - 60
+ - 0.6
 - percent
 results_num: 6 # Results will vary due to randomness
- description: testUnion
From e3de1b8c1b2cdcfd532a862b84593467d8b98ced Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 27 Mar 2025 16:33:13 -0700
Subject: [PATCH 061/131] catch expected errors in test
---
 tests/system/pipeline_e2e.yaml | 2 +-
 tests/system/test_pipeline_acceptance.py | 16 +++++++++++++---
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml
index f3162d540..414e49144 100644
--- a/tests/system/pipeline_e2e.yaml
+++ b/tests/system/pipeline_e2e.yaml
@@ -168,7 +168,7 @@ tests:
 - Aggregate:
 accumulators: []
 groups: [genre]
- 
error: "Cannot groupBy without accumulators" + error: ".* requires at least one accumulator" - description: testDistinct pipeline: - Collection: books diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index c0a02abfb..772698095 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -17,12 +17,15 @@ import os import pytest import yaml +import re from typing import Any +from contextlib import nullcontext # from google.cloud.firestore_v1.pipeline_stages import * from google.cloud.firestore_v1 import pipeline_stages from google.cloud.firestore_v1 import pipeline_expressions from google.cloud.firestore_v1.pipeline import Pipeline +from google.api_core.exceptions import GoogleAPIError from google.cloud.firestore import Client @@ -124,9 +127,16 @@ def parse_expressions(client, yaml_element: Any): ) def test_e2e_scenario(test_dict): client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_TEST_DB) - pipeline = parse_pipeline(client, test_dict["pipeline"]) - print(pipeline._to_pb()) - pipeline.execute() + error_regex = test_dict.get("error", None) + with pytest.raises(GoogleAPIError) if error_regex else nullcontext() as ctx: + pipeline = parse_pipeline(client, test_dict["pipeline"]) + print(pipeline._to_pb()) + pipeline.execute() + # check for error message if expected + if error_regex: + found_error = str(ctx.value) + match = re.search(error_regex, found_error) + assert match, f"error '{found_error}' does not match '{error_regex}'" # before_ast = ast.parse(test_dict["before"]) # got_ast = before_ast From f725bc7738f31ba8aeba1ddb94d5e356154bf707 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 27 Mar 2025 17:15:43 -0700 Subject: [PATCH 062/131] added alias to unnest --- google/cloud/firestore_v1/base_pipeline.py | 28 +++++++++++--------- google/cloud/firestore_v1/pipeline_stages.py | 10 +++++-- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index c20b715e9..7d622642a 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -374,18 +374,18 @@ def union(self, other: Self) -> Self: def unnest( self, - field_name: str, + field: str | Selectable, + alias: str | Field | None = None, options: Optional[stages.UnnestOptions] = None, ) -> Self: """ Produces a document for each element in an array field from the previous stage document. - For each input document, this stage emits zero or more augmented documents. - It takes an array field specified by `field_name`. For each element in that - array, it produces an output document where the original array field is - replaced by the current element's value. - -. + For each previous stage document, this stage will emit zero or more augmented documents. The + input array found in the previous stage document field specified by the `fieldName` parameter, + will emit an augmented document for each input array element. The input array element will + augment the previous stage document by setting the `alias` field with the array element value. + If `alias` is unset, the data in `field` will be overwritten. 
Example:
 Input document:
 ```json
 { "title": "The Hitchhiker's Guide", "tags": [ "comedy", "sci-fi" ], ... }
 ```

 >>> from google.cloud.firestore_v1.pipeline_stages import UnnestOptions
 >>> pipeline = firestore.pipeline().collection("books")
 >>> # Emit a document for each tag
- >>> pipeline = pipeline.unnest("tags")
+ >>> pipeline = pipeline.unnest("tags", alias="tag")

 Output documents (without options):
 ```json
- { "title": "The Hitchhiker's Guide", "tags": "comedy", ... }
- { "title": "The Hitchhiker's Guide", "tags": "sci-fi", ... }
+ { "title": "The Hitchhiker's Guide", "tag": "comedy", ... }
+ { "title": "The Hitchhiker's Guide", "tag": "sci-fi", ... }
 ```

 Optionally, `UnnestOptions` can specify a field to store the original index
@@ -425,13 +425,15 @@
 ```

 Args:
- field_name: The name of the field containing the array to unnest.
- options: Optional `UnnestOptions` to configure behavior, like adding an index field.
+ field: The name of the field containing the array to unnest.
+ alias: The field name to use for each array element in the output document.
+ If unset, or if `alias` matches `field`, the element value overwrites the original field.
+ options: Optional `UnnestOptions` to configure additional behavior, like adding an index field.

 Returns:
 A reference to this pipeline instance. Used for method chaining
 """
- self.stages.append(stages.Unnest(field_name, options))
+ self.stages.append(stages.Unnest(field, alias, options))
 return self
diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py
index 0de1814d9..bce9bc58c 100644
--- a/google/cloud/firestore_v1/pipeline_stages.py
+++ b/google/cloud/firestore_v1/pipeline_stages.py
@@ -314,13 +314,19 @@

 class Unnest(Stage):
 """Produces a document for each element in an array field."""
- def __init__(self, field: Field | str, options: Optional["UnnestOptions"] = None):
+ def __init__(self, field: Selectable | str, alias: Field | str | None=None, options: UnnestOptions|None=None):
 super().__init__()
 self.field: Field = Field(field) if isinstance(field, str) else field
+ if alias is None:
+ self.alias = self.field
+ elif isinstance(alias, str):
+ self.alias = Field(alias)
+ else:
+ self.alias = alias
 self.options = options

 def _pb_args(self):
- return [self.field._to_pb()]
+ return [self.field._to_pb(), self.alias._to_pb()]

 def _pb_options(self):
 options = {}
From 128ab1c9d6c395953f5372d8a96cc9075aebdf86 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Thu, 27 Mar 2025 17:28:05 -0700
Subject: [PATCH 063/131] fixed replace
---
 google/cloud/firestore_v1/pipeline_stages.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py
index bce9bc58c..01df06c2a 100644
--- a/google/cloud/firestore_v1/pipeline_stages.py
+++ b/google/cloud/firestore_v1/pipeline_stages.py
@@ -246,7 +246,7 @@ def __init__(self, *fields: str | Field):
 self.fields = [Field(f) if isinstance(f, str) else f for f in fields]

 def _pb_args(self) -> list[Value]:
- return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}})]
+ return [f._to_pb() for f in self.fields]


class Replace(Stage):
From 64a8bda41729d6c582a5104d45dc537af0ffaab3 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Fri, 28 Mar 2025 15:17:07 -0700
Subject: [PATCH 064/131] added pipeline to client
---
 google/cloud/firestore_v1/async_client.py | 4 ++++
google/cloud/firestore_v1/async_pipeline.py | 2 +- google/cloud/firestore_v1/base_client.py | 3 +++ google/cloud/firestore_v1/base_query.py | 6 +++--- google/cloud/firestore_v1/client.py | 4 ++++ google/cloud/firestore_v1/pipeline.py | 1 - 6 files changed, 15 insertions(+), 5 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index f14ec6573..5fbc642ce 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -46,6 +46,7 @@ from google.cloud.firestore_v1.services.firestore.transports import ( grpc_asyncio as firestore_grpc_transport, ) +from google.cloud.firestore_v1.async_pipeline import AsyncPipeline if TYPE_CHECKING: from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER @@ -412,3 +413,6 @@ def transaction(self, **kwargs) -> AsyncTransaction: A transaction attached to this client. """ return AsyncTransaction(self, **kwargs) + + def pipeline(self) -> AsyncPipeline: + return AsyncPipeline(self) \ No newline at end of file diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 7338ae971..98fd75552 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -24,7 +24,7 @@ from google.cloud.firestore_v1.async_client import AsyncClient -class Pipeline(_BasePipeline): +class AsyncPipeline(_BasePipeline): """ Pipelines allow for complex data transformations and queries involving multiple stages like filtering, projection, aggregation, and vector search. diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index f36ff357b..a3cd55fe8 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -475,6 +475,9 @@ def batch(self) -> BaseWriteBatch: def transaction(self, **kwargs) -> BaseTransaction: raise NotImplementedError + def pipeline(self): + raise NotImplementedError + def _reference_info(references: list) -> Tuple[list, dict]: """Get information about document references. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 01dda6423..13218b9b8 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -57,7 +57,7 @@ query, ) from google.cloud.firestore_v1.vector import Vector -from google.cloud.firestore_v1.pipeline import Pipeline +from google.cloud.firestore_v1 import pipeline_expressions if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -1104,9 +1104,9 @@ def recursive(self: QueryType) -> QueryType: return copied - def pipeline(self) -> Pipeline: + def pipeline(self): # TODO: add extensive tests - ppl = Pipeline(self._client) + ppl = self._client.pipeline() if self._all_descendants: ppl = ppl.collection_group(self._parent.id) else: diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 8bdaf7f81..885843b02 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -48,6 +48,7 @@ grpc as firestore_grpc_transport, ) from google.cloud.firestore_v1.transaction import Transaction +from google.cloud.firestore_v1.pipeline import Pipeline if TYPE_CHECKING: from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER @@ -404,3 +405,6 @@ def transaction(self, **kwargs) -> Transaction: A transaction attached to this client. 
""" return Transaction(self, **kwargs) + + def pipeline(self) -> Pipeline: + return Pipeline(self) \ No newline at end of file diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 92b00d96c..bbead6bf2 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -61,7 +61,6 @@ def execute(self) -> Iterable["ExecutePipelineResponse"]: request = ExecutePipelineRequest( database=database_name, structured_pipeline=self._to_pb(), - read_time=datetime.datetime.now(), ) results = self._client._firestore_api.execute_pipeline(request) return results \ No newline at end of file From ec060802efaede7bb5637ed52b2b10d7791dd897 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Mar 2025 15:17:24 -0700 Subject: [PATCH 065/131] improved query.pipeline logic --- google/cloud/firestore_v1/base_query.py | 30 +++++++++---------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 13218b9b8..1eb18e749 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1128,38 +1128,28 @@ def pipeline(self): # Orders orders = self._normalize_orders() if orders: - # Add exists filters to match Query's implicit orderby semantics. exists = [] + orderings = [] for order in orders: - # skip __name__ - if order.field.field_path == "__name__": - continue - exists.append(field_path_module.FieldPath(order.field.field_path).exists()) + field = pipeline_expressions.Field.of(order.field.field_path) + exists.append(field.exists()) + direction = ( + "ascending" if order.direction == BaseQuery.Direction.ASCENDING else "descending" + ) + orderings.append(pipeline_expressions.Ordering(field, direction)) + # Add exists filters to match Query's implicit orderby semantics. 
if len(exists) > 1: ppl = ppl.where(field_path_module.And(*exists)) elif len(exists) == 1: ppl = ppl.where(exists[0]) - orderings = [] - for order in orders: - direction = ( - "asc" if order.direction == StructuredQuery.Direction.ASCENDING else "desc" - ) - orderings.append( - getattr(field_path_module.FieldPath(order.field.field_path), direction)() - ) + # Add sort orderings ppl = ppl.sort(*orderings) # Cursors, Limit and Offset if self._start_at or self._end_at or self._limit_to_last: - ppl = ppl.paginate( - start_at=self._start_at, - end_at=self._end_at, - limit=self._limit, - limit_to_last=self._limit_to_last, - offset=self._offset, - ) + raise NotImplementedError("Query to Pipeline conversion: cursors and limitToLast is not supported yet.") else: # Limit & Offset without cursors if self._offset: ppl = ppl.offset(self._offset) From 07035329482d1537db3e07e3975e37e1b3436e90 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Mar 2025 15:21:24 -0700 Subject: [PATCH 066/131] removed imports --- google/cloud/firestore_v1/base_collection.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 506234507..1ac1ba318 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -35,8 +35,6 @@ from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.base_query import QueryType -from google.cloud.firestore_v1.pipeline import Pipeline -from google.cloud.firestore_v1.pipeline_stages import Collection as CollectionStage if TYPE_CHECKING: # pragma: NO COVER # Types needed only for Type Hints From 246d1c8545351c3c9192fb5b078e50cf2a9b076e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Mar 2025 16:02:57 -0700 Subject: [PATCH 067/131] pass stages through --- google/cloud/firestore_v1/async_client.py | 4 ++-- google/cloud/firestore_v1/base_client.py | 2 +- google/cloud/firestore_v1/client.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 5fbc642ce..9fb6883bb 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -414,5 +414,5 @@ def transaction(self, **kwargs) -> AsyncTransaction: """ return AsyncTransaction(self, **kwargs) - def pipeline(self) -> AsyncPipeline: - return AsyncPipeline(self) \ No newline at end of file + def pipeline(self, *stages) -> AsyncPipeline: + return AsyncPipeline(self, *stages) \ No newline at end of file diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index a3cd55fe8..585de4ce2 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -475,7 +475,7 @@ def batch(self) -> BaseWriteBatch: def transaction(self, **kwargs) -> BaseTransaction: raise NotImplementedError - def pipeline(self): + def pipeline(self, *stages): raise NotImplementedError diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 885843b02..aa82c59b6 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -406,5 +406,5 @@ def transaction(self, **kwargs) -> Transaction: """ return Transaction(self, **kwargs) - def pipeline(self) -> Pipeline: - return Pipeline(self) \ No newline at end of file + def pipeline(self, *stages) -> Pipeline: + return Pipeline(self, *stages) \ No newline at end of file From 
b72d44a1a891973fff9338e7dce1db02aa05a9de Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Mar 2025 16:03:34 -0700 Subject: [PATCH 068/131] fixed collection docstrings --- google/cloud/firestore_v1/async_pipeline.py | 9 ++--- google/cloud/firestore_v1/base_pipeline.py | 38 ++++++++++----------- google/cloud/firestore_v1/pipeline.py | 14 ++++---- 3 files changed, 28 insertions(+), 33 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 98fd75552..2dfbaedda 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -33,21 +33,18 @@ class AsyncPipeline(_BasePipeline): defined pipeline stages using an asynchronous `AsyncClient`. Usage Example: - >>> import asyncio >>> from google.cloud.firestore_v1.pipeline_expressions import Field, gt >>> >>> async def run_pipeline(): ... client = AsyncClient(...) - ... pipeline = client.pipeline() - ... .collection("books") + ... pipeline = client.collection("books") + ... .pipeline() ... .where(gt(Field.of("published"), 1980)) ... .select("title", "author") ... async for result in pipeline.execute_async(): ... print(result) - >>> - >>> asyncio.run(run_pipeline()) - Use `AsyncClient.pipeline()` to create instances of this class. + Use `client.collection("...").pipeline()` to create instances of this class. """ def __init__(self, client:AsyncClient, *stages: stages.Stage): """ diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 7d622642a..6d6f8d5f6 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -33,8 +33,8 @@ class _BasePipeline: """ Base class for building Firestore data transformation and query pipelines. - This class is not intended to be instantiated directly. Use `Client.pipeline()` - or `AsyncClient.pipeline()` to create pipeline instances. + This class is not intended to be instantiated directly. + Use `client.collection.("...").pipeline()` to create pipeline instances. """ def __init__(self, *stages: stages.Stage): """ @@ -72,7 +72,7 @@ def add_fields(self, *fields: Selectable) -> Self: Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field, add - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> pipeline = pipeline.add_fields( ... Field.of("rating").as_("bookRating"), # Rename 'rating' to 'bookRating' ... 
add(5, Field.of("quantity")).as_("totalCost") # Calculate 'totalCost' @@ -94,7 +94,7 @@ def remove_fields(self, *fields: Field | str) -> Self: Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Remove by name >>> pipeline = pipeline.remove_fields("rating", "cost") >>> # Remove by Field object @@ -126,7 +126,7 @@ def select(self, *selections: str | Selectable) -> Self: Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field, to_upper - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Select by name >>> pipeline = pipeline.select("name", "address") >>> # Select using Field and Function expressions @@ -159,7 +159,7 @@ def where(self, condition: FilterCondition) -> Self: Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field, and_, gt, eq - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Using static functions >>> pipeline = pipeline.where( ... and_( @@ -207,7 +207,7 @@ def find_nearest( >>> from google.cloud.firestore_v1.pipeline_expressions import Field >>> >>> target_vector = [0.1, 0.2, 0.3] - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Find using field name >>> pipeline = pipeline.find_nearest( ... "topicVectors", @@ -253,7 +253,7 @@ def sort(self, *orders: stages.Ordering) -> Self: Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Sort books by rating descending, then title ascending >>> pipeline = pipeline.sort( ... Field.of("rating").descending(), @@ -295,7 +295,7 @@ def replace( ``` >>> from google.cloud.firestore_v1.pipeline_expressions import Field - >>> pipeline = firestore.pipeline().collection("people") + >>> pipeline = client.collection("people").pipeline() >>> # Emit the 'parents' map as the document >>> pipeline = pipeline.replace(Field.of("parents")) @@ -330,7 +330,7 @@ def sample(self, limit_or_options: int | SampleOptions) -> Self: Example: >>> from google.cloud.firestore_v1.pipeline_expressions import SampleOptions - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Sample 10 books, if available. >>> pipeline = pipeline.sample(10) >>> pipeline = pipeline.sample(SampleOptions.doc_limit(10)) @@ -358,8 +358,8 @@ def union(self, other: Self) -> Self: pipeline provided. The order of documents emitted from this stage is undefined. 
Example: - >>> books_pipeline = firestore.pipeline().collection("books") - >>> magazines_pipeline = firestore.pipeline().collection("magazines") + >>> books_pipeline = client.collection("books").pipeline() + >>> magazines_pipeline = client.collection("magazines").pipeline() >>> # Emit documents from both collections >>> combined_pipeline = books_pipeline.union(magazines_pipeline) @@ -394,7 +394,7 @@ def unnest( ``` >>> from google.cloud.firestore_v1.pipeline_stages import UnnestOptions - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Emit a document for each tag >>> pipeline = pipeline.unnest("tags", alias="tag") @@ -414,7 +414,7 @@ def unnest( ``` >>> from google.cloud.firestore_v1.pipeline_stages import UnnestOptions - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Emit a document for each tag, including the index >>> pipeline = pipeline.unnest("tags", options=UnnestOptions(index_field="tagIndex")) @@ -446,7 +446,7 @@ def generic_stage(self, name: str, *params: Expr) -> Self: Example: >>> # Assume we don't have a built-in "where" stage - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> pipeline = pipeline.generic_stage("where", [Field.of("published").lt(900)]) >>> pipeline = pipeline.select("title", "author") @@ -470,7 +470,7 @@ def offset(self, offset: int) -> Self: Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Retrieve the second page of 20 results (assuming sorted) >>> pipeline = pipeline.sort(Field.of("published").descending()) >>> pipeline = pipeline.offset(20) # Skip the first 20 results @@ -496,7 +496,7 @@ def limit(self, limit: int) -> Self: Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Limit the results to the top 10 highest-rated books >>> pipeline = pipeline.sort(Field.of("rating").descending()) >>> pipeline = pipeline.limit(10) @@ -531,7 +531,7 @@ def aggregate( Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field, avg, count_all - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Calculate the average rating and total count for all books >>> pipeline = pipeline.aggregate( ... avg(Field.of("rating")).as_("averageRating"), @@ -572,7 +572,7 @@ def distinct(self, *fields: str | Selectable) -> Self: Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field, to_upper - >>> pipeline = firestore.pipeline().collection("books") + >>> pipeline = client.collection("books").pipeline() >>> # Get a list of unique genres (output has only 'genre' field) >>> pipeline = pipeline.distinct("genre") >>> # Get unique combinations of author (uppercase) and genre diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index bbead6bf2..fc0ee428d 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -30,20 +30,18 @@ class Pipeline(_BasePipeline): multiple stages like filtering, projection, aggregation, and vector search. 
Usage Example:
- >>> from google.cloud.firestore_v1.pipeline_expressions import Field, gt
+ >>> from google.cloud.firestore_v1.pipeline_expressions import Field
 >>>
- >>> async def run_pipeline():
+ >>> def run_pipeline():
 ... client = Client(...)
- ... pipeline = client.pipeline()
- ... .collection("books")
- ... .where(gt(Field.of("published"), 1980))
+ ... pipeline = client.collection("books")
+ ... .pipeline()
+ ... .where(Field.of("published").gt(1980))
 ... .select("title", "author")
 ... for result in pipeline.execute():
 ... print(result)
- >>>
- >>> asyncio.run(run_pipeline())

- Use `Client.pipeline()` to create instances of this class.
+ Use `client.collection("...").pipeline()` to create instances of this class.
 """
 def __init__(self, client:Client, *stages: stages.Stage):
 """
From 810ccd2d48aa10efdf4862adc632d453f0887f8 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Fri, 28 Mar 2025 16:04:04 -0700
Subject: [PATCH 069/131] fixed collection setup
---
 google/cloud/firestore_v1/base_pipeline.py | 14 +++++++-------
 google/cloud/firestore_v1/base_query.py | 11 ++++++-----
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py
index 6d6f8d5f6..e5c7577aa 100644
--- a/google/cloud/firestore_v1/base_pipeline.py
+++ b/google/cloud/firestore_v1/base_pipeline.py
@@ -154,22 +154,22 @@ def where(self, condition: FilterCondition) -> Self:
 clause in SQL. You can filter documents based on their field values,
 using implementations of `FilterCondition`, typically including but not limited to:
 - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc.
- - logical operators: `and_`, `or_`, `not_`, etc.
- - advanced functions: `regex_match`, `array_contains`, etc.
+ - logical operators: `And`, `Or`, `Not`, etc.
+ - advanced functions: `regex_matches`, `array_contains`, etc.

 Example:
- >>> from google.cloud.firestore_v1.pipeline_expressions import Field, and_, gt, eq
+ >>> from google.cloud.firestore_v1.pipeline_expressions import Field, And
 >>> pipeline = client.collection("books").pipeline()
 >>> # Using static functions
 >>> pipeline = pipeline.where(
- ... and_(
- ... gt(Field.of("rating"), 4.0), # Filter for ratings > 4.0
- ... eq(Field.of("genre"), "Science Fiction") # Filter for genre
+ ... And(
+ ... Field.of("rating").gt(4.0), # Filter for ratings > 4.0
+ ... Field.of("genre").eq("Science Fiction") # Filter for genre
 ... )
 ... )
 >>> # Using methods on expressions
 >>> pipeline = pipeline.where(
- ... and_(
+ ... And(
 ... Field.of("rating").gt(4.0),
 ... Field.of("genre").eq("Science Fiction")
 ... 
) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 1eb18e749..f70fc66ea 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -58,6 +58,7 @@ ) from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1 import pipeline_expressions +from google.cloud.firestore_v1 import pipeline_stages if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -1106,15 +1107,15 @@ def recursive(self: QueryType) -> QueryType: def pipeline(self): # TODO: add extensive tests - ppl = self._client.pipeline() if self._all_descendants: - ppl = ppl.collection_group(self._parent.id) + base_stage = pipeline_stages.CollectionGroup(self._parent.id) else: - ppl = ppl.collection("/".join(self._parent._path)) + base_stage = pipeline_stages.Collection("/".join(self._parent._path)) + ppl = self._client.pipeline(base_stage) # Filters - for filter in self._field_filters: - ppl = ppl.where(filter) + for filter_ in self._field_filters: + ppl = ppl.where(filter_) # Projections if self._projection and self._projection.fields: From 78b5833000daa1bfc69fb1b956822a2816a32088 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Mar 2025 17:02:42 -0700 Subject: [PATCH 070/131] fixed some collection errors --- google/cloud/firestore_v1/base_collection.py | 3 +++ google/cloud/firestore_v1/base_query.py | 6 +++--- google/cloud/firestore_v1/pipeline_expressions.py | 4 ++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 1ac1ba318..f0e3a7b6a 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -590,6 +590,9 @@ def find_nearest( distance_threshold=distance_threshold, ) + def pipeline(self): + return self._query().pipeline() + def _auto_id() -> str: """Generate a "random" automatically generated ID. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index f70fc66ea..2c5e07721 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1115,7 +1115,7 @@ def pipeline(self): # Filters for filter_ in self._field_filters: - ppl = ppl.where(filter_) + ppl = ppl.where(pipeline_expressions.FilterCondition._from_pb(filter_)) # Projections if self._projection and self._projection.fields: @@ -1135,13 +1135,13 @@ def pipeline(self): field = pipeline_expressions.Field.of(order.field.field_path) exists.append(field.exists()) direction = ( - "ascending" if order.direction == BaseQuery.Direction.ASCENDING else "descending" + "ascending" if order.direction == StructuredQuery.Direction.ASCENDING else "descending" ) orderings.append(pipeline_expressions.Ordering(field, direction)) # Add exists filters to match Query's implicit orderby semantics. 
if len(exists) > 1:
- ppl = ppl.where(field_path_module.And(*exists))
+ ppl = ppl.where(pipeline_expressions.And(*exists))
 elif len(exists) == 1:
 ppl = ppl.where(exists[0])
diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index 92bc5d55a..a3401af20 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -1417,6 +1417,10 @@ def _to_pb(self):
 class FilterCondition(Function):
 """Filters the given data in some way."""

+ @staticmethod
+ def _from_pb(filter_pb):
+ raise NotImplementedError
+

class And(FilterCondition):
 def __init__(self, *conditions: "FilterCondition"):
From 36e0228e96e7b81aecee49460690ae870a1f159 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Fri, 28 Mar 2025 17:43:55 -0700
Subject: [PATCH 071/131] implemented FilterCondition._from_pb
---
 google/cloud/firestore_v1/base_query.py | 2 +-
 .../firestore_v1/pipeline_expressions.py | 56 ++++++++++++++++++-
 2 files changed, 55 insertions(+), 3 deletions(-)

diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py
index 2c5e07721..bddc423b8 100644
--- a/google/cloud/firestore_v1/base_query.py
+++ b/google/cloud/firestore_v1/base_query.py
@@ -1115,7 +1115,7 @@ def pipeline(self):
 # Filters
 for filter_ in self._field_filters:
- ppl = ppl.where(pipeline_expressions.FilterCondition._from_pb(filter_))
+ ppl = ppl.where(pipeline_expressions.FilterCondition._from_pb(filter_, self._client))
diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index a3401af20..833fdfcb4 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -21,9 +21,11 @@
 import datetime
 from dataclasses import dataclass
 from google.cloud.firestore_v1.types.document import Value
+from google.cloud.firestore_v1.types.query import StructuredQuery as Query_pb
 from google.cloud.firestore_v1.vector import Vector
 from google.cloud.firestore_v1._helpers import GeoPoint
 from google.cloud.firestore_v1._helpers import encode_value
+from google.cloud.firestore_v1._helpers import decode_value

 CONSTANT_TYPE = TypeVar('CONSTANT_TYPE', str, int, float, bool, datetime.datetime, bytes, GeoPoint, Vector, list, Dict[str, Any], None)

@@ -1418,8 +1420,58 @@ class FilterCondition(Function):
 """Filters the given data in some way."""

 @staticmethod
- def _from_pb(filter_pb):
- raise NotImplementedError
+ def _from_pb(filter_pb, client):
+ if isinstance(filter_pb, Query_pb.CompositeFilter):
+ sub_filters = [FilterCondition._from_pb(f, client) for f in filter_pb.filters]
+ if filter_pb.op == Query_pb.CompositeFilter.Operator.OR:
+ return Or(*sub_filters)
+ elif filter_pb.op == Query_pb.CompositeFilter.Operator.AND:
+ return And(*sub_filters)
+ else:
+ raise TypeError(f"Unexpected CompositeFilter operator type: {filter_pb.op}")
+ elif isinstance(filter_pb, Query_pb.UnaryFilter):
+ field = Field.of(filter_pb.field.field_path)
+ if filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NAN:
+ return And(field.exists(), field.is_nan())
+ elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN:
+ return And(field.exists(), Not(field.is_nan()))
+ elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL:
+ return And(field.exists(), field.eq(None))
+ elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL:
+ return And(field.exists(), 
Not(field.eq(None)))
+ else:
+ raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}")
+ elif isinstance(filter_pb, Query_pb.FieldFilter):
+ field = Field.of(filter_pb.field.field_path)
+ value = decode_value(filter_pb.value, client)
+ if filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN:
+ return And(field.exists(), field.lt(value))
+ elif filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN_OR_EQUAL:
+ return And(field.exists(), field.lte(value))
+ elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN:
+ return And(field.exists(), field.gt(value))
+ elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN_OR_EQUAL:
+ return And(field.exists(), field.gte(value))
+ elif filter_pb.op == Query_pb.FieldFilter.Operator.EQUAL:
+ return And(field.exists(), field.eq(value))
+ elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_EQUAL:
+ return And(field.exists(), field.neq(value))
+ elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS:
+ return And(field.exists(), field.array_contains(value))
+ elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS_ANY:
+ return And(field.exists(), field.array_contains_any(value))
+ elif filter_pb.op == Query_pb.FieldFilter.Operator.IN:
+ return And(field.exists(), field.in_any(value))
+ elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_IN:
+ return And(field.exists(), field.not_in_any(value))
+ else:
+ raise TypeError(f"Unexpected FieldFilter operator type: {filter_pb.op}")
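+ # Query_pb.Filter itself is a oneof wrapper around the three concrete
+ # filter messages; the branch below unwraps it and recurses on the
+ # populated variant.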
elif isinstance(filter_pb, Query_pb.Filter): # unwrap oneof f = filter_pb.composite_filter or filter_pb.field_filter or filter_pb.unary_filter - return FilterCondition._from_pb(f, client) + return FilterCondition._from_query_filter_pb(f, client) else: raise TypeError(f"Unexpected filter type: {type(filter_pb)}") From 51e9b25e37d51447771f35a64bfed01809f9ad68 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 28 Mar 2025 19:21:08 -0700 Subject: [PATCH 073/131] added tests for query.pipeline --- google/cloud/firestore_v1/base_query.py | 1 - tests/unit/v1/test_base_query.py | 167 ++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 74e2a0e84..100b137ef 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1106,7 +1106,6 @@ def recursive(self: QueryType) -> QueryType: return copied def pipeline(self): - # TODO: add extensive tests if self._all_descendants: base_stage = pipeline_stages.CollectionGroup(self._parent.id) else: diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 24caa5e40..517d176f7 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -1951,6 +1951,173 @@ def test__collection_group_query_response_to_snapshot_response(): assert snapshot.create_time == response_pb._pb.document.create_time assert snapshot.update_time == response_pb._pb.document.update_time +def test__query_pipeline_decendants(): + from google.cloud.firestore_v1 import pipeline_stages + + client = make_client() + query = client.collection_group("my_col") + pipeline = query.pipeline() + + assert len(pipeline.stages) == 1 + stage = pipeline.stages[0] + assert isinstance(stage, pipeline_stages.CollectionGroup) + assert stage.collection_id == "my_col" + + +@pytest.mark.parametrize("in_path,out_path",[ + ("my_col/doc/", "/my_col/doc/"), + ("/my_col/doc", "/my_col/doc"), + ("my_col/doc/sub_col", "/my_col/doc/sub_col"), +]) +def test__query_pipeline_no_decendants(in_path, out_path): + from google.cloud.firestore_v1 import pipeline_stages + + client = make_client() + query = client.collection(in_path) + pipeline = query.pipeline() + + assert len(pipeline.stages) == 1 + stage = pipeline.stages[0] + assert isinstance(stage, pipeline_stages.Collection) + assert stage.path == out_path + + +def test__query_pipeline_composite_filter(): + from google.cloud.firestore_v1 import FieldFilter + from google.cloud.firestore_v1 import pipeline_expressions as expr + from google.cloud.firestore_v1 import pipeline_stages + + client = make_client() + in_filter = FieldFilter("field_a", "==", "value_a") + query = client.collection("my_col").where(filter=in_filter) + with mock.patch.object(expr.FilterCondition, "_from_query_filter_pb") as convert_mock: + pipeline = query.pipeline() + convert_mock.assert_called_once_with(in_filter._to_pb(), client) + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, pipeline_stages.Where) + assert stage.condition == convert_mock.return_value + + +def test__query_pipeline_projections(): + from google.cloud.firestore_v1 import pipeline_stages + + client = make_client() + query = client.collection("my_col").select(["field_a", "field_b.c"]) + pipeline = query.pipeline() + + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, pipeline_stages.Select) + assert len(stage.projections) == 2 + assert 
stage.projections[0].path == "field_a" + assert stage.projections[1].path == "field_b.c" + + +def test__query_pipeline_order_exists_multiple(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + from google.cloud.firestore_v1 import pipeline_stages + + client = make_client() + query = client.collection("my_col").order_by("field_a").order_by("field_b") + pipeline = query.pipeline() + + # should have collection, where, and sort + # we're interested in where + assert len(pipeline.stages) == 3 + where_stage = pipeline.stages[1] + assert isinstance(where_stage, pipeline_stages.Where) + # should have and with both orderings + assert isinstance(where_stage.condition, expr.And) + assert len(where_stage.condition.params) == 2 + operands = [p for p in where_stage.condition.params] + assert isinstance(operands[0], expr.Exists) + assert operands[0].params[0].path == "field_a" + assert isinstance(operands[1], expr.Exists) + assert operands[1].params[0].path == "field_b" + +def test__query_pipeline_order_exists_single(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + from google.cloud.firestore_v1 import pipeline_stages + + client = make_client() + query_single = client.collection("my_col").order_by("field_c") + pipeline_single = query_single.pipeline() + + # should have collection, where, and sort + # we're interested in where + assert len(pipeline_single.stages) == 3 + where_stage_single = pipeline_single.stages[1] + assert isinstance(where_stage_single, pipeline_stages.Where) + assert isinstance(where_stage_single.condition, expr.Exists) + assert where_stage_single.condition.params[0].path == "field_c" + + +def test__query_pipeline_order_sorts(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + from google.cloud.firestore_v1 import pipeline_stages + from google.cloud.firestore_v1.base_query import BaseQuery + + client = make_client() + query = ( + client.collection("my_col") + .order_by("field_a", direction=BaseQuery.ASCENDING) + .order_by("field_b", direction=BaseQuery.DESCENDING) + ) + pipeline = query.pipeline() + + assert len(pipeline.stages) == 3 + sort_stage = pipeline.stages[2] + assert isinstance(sort_stage, pipeline_stages.Sort) + assert len(sort_stage.orders) == 2 + assert isinstance(sort_stage.orders[0], expr.Ordering) + assert sort_stage.orders[0].expr.path == "field_a" + assert sort_stage.orders[0].order_dir == expr.Ordering.Direction.ASCENDING + assert isinstance(sort_stage.orders[1], expr.Ordering) + assert sort_stage.orders[1].expr.path == "field_b" + assert sort_stage.orders[1].order_dir == expr.Ordering.Direction.DESCENDING + + +def test__query_pipeline_cursor(): + client = make_client() + query_start = client.collection("my_col").start_at({"field_a": "value"}) + with pytest.raises(NotImplementedError, match="cursors"): + query_start.pipeline() + + query_end = client.collection("my_col").end_at({"field_a": "value"}) + with pytest.raises(NotImplementedError, match="cursors"): + query_end.pipeline() + + query_limit_last = client.collection("my_col").limit_to_last(10) + with pytest.raises(NotImplementedError, match="limitToLast"): + query_limit_last.pipeline() + + +def test__query_pipeline_limit(): + from google.cloud.firestore_v1 import pipeline_stages + + client = make_client() + query = client.collection("my_col").limit(15) + pipeline = query.pipeline() + + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, pipeline_stages.Limit) + assert stage.limit == 15 + + +def 
test__query_pipeline_offset():
+ from google.cloud.firestore_v1 import pipeline_stages
+
+ client = make_client()
+ query = client.collection("my_col").offset(5)
+ pipeline = query.pipeline()
+
+ assert len(pipeline.stages) == 2
+ stage = pipeline.stages[1]
+ assert isinstance(stage, pipeline_stages.Offset)
+ assert stage.offset == 5
+

 def _make_order_pb(field_path, direction):
 from google.cloud.firestore_v1.types import query
From 35cb0bbc8db58fc0e00a15a435756104746a28 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Mon, 31 Mar 2025 15:28:43 -0700
Subject: [PATCH 074/131] added verify pipeline to system tests
---
 tests/system/test_system.py | 164 ++++++++++++++++++++++++++----------
 1 file changed, 118 insertions(+), 46 deletions(-)

diff --git a/tests/system/test_system.py b/tests/system/test_system.py
index ed525db57..48af63855 100644
--- a/tests/system/test_system.py
+++ b/tests/system/test_system.py
@@ -80,6 +80,31 @@ def cleanup():
 operation()


+@pytest.fixture
+def verify_pipeline():
+ """
+ This fixture provides a callable that ensures a pipeline produces
+ the same results as the query it is derived from
+
+ It can be attached to existing query tests to exercise
+ both modalities at the same time
+ """
+ def _verify(query):
+ query_exception = None
+ query_results = None
+ try:
+ query_results = query.get()
+ except Exception as e:
+ query_exception = e
+ pipeline = query.pipeline()
+ if query_exception:
+ # ensure that the pipeline raises the same error type as the query
+ with pytest.raises(type(query_exception)):
+ pipeline.execute()
+ else:
+ # ensure results match query
+ pipeline_results = pipeline.execute()
+ assert query_results == pipeline_results
+ return _verify

 @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
 def test_collections(client, database):
 collections = list(client.collections())
@@ -104,7 +129,7 @@ def test_collections_w_import(database):
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
 @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
-def test_collection_stream_or_get_w_no_explain_options(database, query_docs, method):
+def test_collection_stream_or_get_w_no_explain_options(database, query_docs, method, verify_pipeline):
 from google.cloud.firestore_v1.query_profile import QueryExplainError

 collection, _, _ = query_docs
@@ -119,7 +144,7 @@ def test_collection_stream_or_get_w_no_explain_options(database, query_docs, met
 match="explain_options not set on query.",
 ):
 results.get_explain_metrics()
-
+ verify_pipeline(collection)

 @pytest.mark.skipif(
 FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
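For reference, a minimal self-contained sketch of the factory-fixture pattern that `verify_pipeline` uses above: the fixture returns a callable, so each test can invoke it with its own query object (all names in this sketch are illustrative only):

```python
import pytest


@pytest.fixture
def verify():
    # The fixture yields a function rather than a value, so the test
    # controls what gets checked and when.
    def _verify(value):
        # stand-in for the query-vs-pipeline comparison performed above
        assert value is not None
    return _verify


def test_example(verify):
    verify(42)  # invoked like a plain function inside the test body
```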
@@ -127,7 +152,7 @@ def test_collection_stream_or_get_w_no_explain_options(database, query_docs, met @pytest.mark.parametrize("method", ["get", "stream"]) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_collection_stream_or_get_w_explain_options_analyze_false( - database, method, query_docs + database, method, query_docs, verify_pipeline ): from google.cloud.firestore_v1.query_profile import ( ExplainMetrics, @@ -157,6 +182,7 @@ def test_collection_stream_or_get_w_explain_options_analyze_false( match="execution_stats not available when explain_options.analyze=False", ): explain_metrics.execution_stats + verify_pipeline(collection) @pytest.mark.skipif( @@ -165,7 +191,7 @@ def test_collection_stream_or_get_w_explain_options_analyze_false( @pytest.mark.parametrize("method", ["get", "stream"]) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_collection_stream_or_get_w_explain_options_analyze_true( - database, method, query_docs + database, method, query_docs, verify_pipeline ): from google.cloud.firestore_v1.query_profile import ( ExecutionStats, @@ -215,6 +241,7 @@ def test_collection_stream_or_get_w_explain_options_analyze_true( assert "documents_scanned" in execution_stats.debug_stats assert "index_entries_scanned" in execution_stats.debug_stats assert len(execution_stats.debug_stats) > 0 + verify_pipeline(collection) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) @@ -1113,7 +1140,7 @@ def query(collection): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_legacy_where(query_docs, database): +def test_query_stream_legacy_where(query_docs, database, verify_pipeline): """Assert the legacy code still works and returns value""" collection, stored, allowed_vals = query_docs with pytest.warns( @@ -1126,10 +1153,11 @@ def test_query_stream_legacy_where(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_simple_field_eq_op(query_docs, database): +def test_query_stream_w_simple_field_eq_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1137,10 +1165,13 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_simple_field_array_contains_op(query_docs, database): +def test_query_stream_w_simple_field_array_contains_op( + query_docs, database, verify_pipeline +): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1148,10 +1179,11 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_simple_field_in_op(query_docs, database): +def test_query_stream_w_simple_field_in_op(query_docs, 
database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("a", "in", [1, num_vals + 100])) @@ -1160,10 +1192,11 @@ def test_query_stream_w_simple_field_in_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_not_eq_op(query_docs, database): +def test_query_stream_w_not_eq_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", "!=", 4)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1182,10 +1215,11 @@ def test_query_stream_w_not_eq_op(query_docs, database): ] ) assert expected_ab_pairs == ab_pairs2 + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_simple_not_in_op(query_docs, database): +def test_query_stream_w_simple_not_in_op(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1194,10 +1228,13 @@ def test_query_stream_w_simple_not_in_op(query_docs, database): values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} assert len(values) == 22 + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): +def test_query_stream_w_simple_field_array_contains_any_op( + query_docs, database, verify_pipeline +): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1208,10 +1245,11 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database) for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_order_by(query_docs, database): +def test_query_stream_w_order_by(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()] @@ -1222,10 +1260,11 @@ def test_query_stream_w_order_by(query_docs, database): b_vals.append(value["b"]) # Make sure the ``b``-values are in DESCENDING order. 
assert sorted(b_vals, reverse=True) == b_vals + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_field_path(query_docs, database): +def test_query_stream_w_field_path(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1247,7 +1286,7 @@ def test_query_stream_w_field_path(query_docs, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_start_end_cursor(query_docs, database): +def test_query_stream_w_start_end_cursor(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = ( @@ -1260,19 +1299,21 @@ def test_query_stream_w_start_end_cursor(query_docs, database): for key, value in values: assert stored[key] == value assert value["a"] == num_vals - 2 + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_wo_results(query_docs, database): +def test_query_stream_wo_results(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) values = list(query.stream()) assert len(values) == 0 + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_projection(query_docs, database): +def test_query_stream_w_projection(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "<=", 1)).select( @@ -1286,10 +1327,11 @@ def test_query_stream_w_projection(query_docs, database): "stats": {"product": stored[key]["stats"]["product"]}, } assert expected == value + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_multiple_filters(query_docs, database): +def test_query_stream_w_multiple_filters(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( filter=FieldFilter("stats.product", "<", 10) @@ -1306,10 +1348,11 @@ def test_query_stream_w_multiple_filters(query_docs, database): assert stored[key] == value pair = (value["a"], value["b"]) assert pair in matching_pairs + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_offset(query_docs, database): +def test_query_stream_w_offset(query_docs, database, verify_pipeline): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) offset = 3 @@ -1322,6 +1365,7 @@ def test_query_stream_w_offset(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["b"] == 2 + verify_pipeline(query) @pytest.mark.skipif( @@ -1451,7 +1495,7 @@ def test_query_stream_or_get_w_explain_options_analyze_false( @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_order_dot_key(client, cleanup, database): +def test_query_with_order_dot_key(client, cleanup, database, verify_pipeline): db = client collection_id = "collek" + UNIQUE_RESOURCE_ID collection = db.collection(collection_id) 
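The pattern above leans on the conversion exercised by the offset unit test at the top of this patch: Query.pipeline() rewrites a fluent query into an explicit stage list, which verify_pipeline then executes against the backend. A rough sketch of the expected shape (stage classes as named in pipeline_stages; the exact stage ordering is an assumption):

    from google.cloud.firestore_v1 import pipeline_stages

    query = collection.where(filter=FieldFilter("a", "==", 1)).limit(3)
    pipeline = query.pipeline()
    # assumed shape: Collection -> Where -> Limit
    assert isinstance(pipeline.stages[0], pipeline_stages.Collection)
    assert isinstance(pipeline.stages[1], pipeline_stages.Where)
    assert isinstance(pipeline.stages[-1], pipeline_stages.Limit)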
@@ -1463,33 +1507,38 @@ def test_query_with_order_dot_key(client, cleanup, database): query = collection.order_by("wordcount.page1").limit(3) data = [doc.to_dict()["wordcount"]["page1"] for doc in query.stream()] assert [100, 110, 120] == data - for snapshot in collection.order_by("wordcount.page1").limit(3).stream(): + verify_pipeline(query) + query2 = collection.order_by("wordcount.page1").limit(3) + for snapshot in query2.stream(): last_value = snapshot.get("wordcount.page1") + verify_pipeline(query2) cursor_with_nested_keys = {"wordcount": {"page1": last_value}} - found = list( + query3 = ( collection.order_by("wordcount.page1") .start_after(cursor_with_nested_keys) .limit(3) - .stream() ) + found = list(query3.stream()) found_data = [ {"count": 30, "wordcount": {"page1": 130}}, {"count": 40, "wordcount": {"page1": 140}}, {"count": 50, "wordcount": {"page1": 150}}, ] assert found_data == [snap.to_dict() for snap in found] + verify_pipeline(query3) cursor_with_dotted_paths = {"wordcount.page1": last_value} - cursor_with_key_data = list( + query4 = ( collection.order_by("wordcount.page1") .start_after(cursor_with_dotted_paths) .limit(3) - .stream() ) + cursor_with_key_data = list(query4.stream()) assert found_data == [snap.to_dict() for snap in cursor_with_key_data] + verify_pipeline(query4) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_unary(client, cleanup, database): +def test_query_unary(client, cleanup, database, verify_pipeline): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) field_name = "foo" @@ -1510,6 +1559,7 @@ def test_query_unary(client, cleanup, database): snapshot0 = values0[0] assert snapshot0.reference._path == document0._path assert snapshot0.to_dict() == {field_name: None} + verify_pipeline(query0) # 1. Query for a NAN. 
query1 = collection.where(filter=FieldFilter(field_name, "==", nan_val)) @@ -1520,10 +1570,11 @@ def test_query_unary(client, cleanup, database): data1 = snapshot1.to_dict() assert len(data1) == 1 assert math.isnan(data1[field_name]) + verify_pipeline(query1) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_collection_group_queries(client, cleanup, database): +def test_collection_group_queries(client, cleanup, database, verify_pipeline): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -1553,10 +1604,13 @@ def test_collection_group_queries(client, cleanup, database): found = [snapshot.id for snapshot in snapshots] expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"] assert found == expected + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_collection_group_queries_startat_endat(client, cleanup, database): +def test_collection_group_queries_startat_endat( + client, cleanup, database, verify_pipeline +): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -1586,6 +1640,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) + verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -1596,10 +1651,11 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2"]) + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_collection_group_queries_filters(client, cleanup, database): +def test_collection_group_queries_filters(client, cleanup, database, verify_pipeline): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -1641,6 +1697,7 @@ def test_collection_group_queries_filters(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) + verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -1662,6 +1719,7 @@ def test_collection_group_queries_filters(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2"]) + verify_pipeline(query) @pytest.mark.skipif( @@ -1927,7 +1985,7 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_watch_query(client, cleanup, database): +def test_watch_query(client, cleanup, database, verify_pipeline): db = client collection_ref = db.collection("wq-users" + UNIQUE_RESOURCE_ID) doc_ref = collection_ref.document("alovelace") @@ -1944,10 +2002,10 @@ def on_snapshot(docs, changes, read_time): on_snapshot.called_count += 1 # A snapshot should return the same thing as if a query ran now. 
- query_ran = collection_ref.where( - filter=FieldFilter("first", "==", "Ada") - ).stream() + query_ran_query = collection_ref.where(filter=FieldFilter("first", "==", "Ada")) + query_ran = query_ran_query.stream() assert len(docs) == len([i for i in query_ran]) + verify_pipeline(query_ran_query) on_snapshot.called_count = 0 @@ -2148,11 +2206,12 @@ def test_recursive_delete_serialized_empty(client, cleanup, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_recursive_query(client, cleanup, database): +def test_recursive_query(client, cleanup, database, verify_pipeline): col_id: str = f"philosophers-recursive-query{UNIQUE_RESOURCE_ID}" _persist_documents(client, col_id, philosophers_data_set, cleanup) - ids = [doc.id for doc in client.collection_group(col_id).recursive().get()] + query = client.collection_group(col_id).recursive() + ids = [doc.id for doc in query.get()] expected_ids = [ # Aristotle doc and subdocs @@ -2184,16 +2243,18 @@ def test_recursive_query(client, cleanup, database): f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" ) assert ids[index] == expected_ids[index], error_msg + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_nested_recursive_query(client, cleanup, database): +def test_nested_recursive_query(client, cleanup, database, verify_pipeline): col_id: str = f"philosophers-nested-recursive-query{UNIQUE_RESOURCE_ID}" _persist_documents(client, col_id, philosophers_data_set, cleanup) collection_ref = client.collection(col_id) aristotle = collection_ref.document("Aristotle") - ids = [doc.id for doc in aristotle.collection("pets").recursive().get()] + query = aristotle.collection("pets").recursive() + ids = [doc.id for doc in query.get()] expected_ids = [ # Aristotle pets @@ -2208,7 +2269,7 @@ def test_nested_recursive_query(client, cleanup, database): f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" ) assert ids[index] == expected_ids[index], error_msg - + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_chunked_query(client, cleanup, database): @@ -2287,7 +2348,7 @@ def test_chunked_and_recursive(client, cleanup, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_watch_query_order(client, cleanup, database): +def test_watch_query_order(client, cleanup, database, verify_pipeline): db = client collection_ref = db.collection("users") doc_ref1 = collection_ref.document("alovelace" + UNIQUE_RESOURCE_ID) @@ -2323,6 +2384,7 @@ def on_snapshot(docs, changes, read_time): ), "expect the sort order to match, born" on_snapshot.called_count += 1 on_snapshot.last_doc_count = len(docs) + verify_pipeline(query_ref) except Exception as e: on_snapshot.failed = e @@ -2363,7 +2425,7 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_repro_429(client, cleanup, database): +def test_repro_429(client, cleanup, database, verify_pipeline): # See: https://github.com/googleapis/python-firestore/issues/429 now = datetime.datetime.now(tz=datetime.timezone.utc) collection = client.collection("repro-429" + UNIQUE_RESOURCE_ID) @@ -2388,6 +2450,8 @@ def test_repro_429(client, cleanup, database): for snapshot in query2.stream(): print(f"id: {snapshot.id}") + verify_pipeline(query) + verify_pipeline(query2) @pytest.mark.parametrize("database", 
[None, FIRESTORE_OTHER_DB], indirect=True) @@ -2957,7 +3021,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_false( @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_and_composite_filter(collection, database): +def test_query_with_and_composite_filter(collection, database, verify_pipeline): and_filter = And( filters=[ FieldFilter("stats.product", ">", 5), @@ -2969,10 +3033,11 @@ def test_query_with_and_composite_filter(collection, database): for result in query.stream(): assert result.get("stats.product") > 5 assert result.get("stats.product") < 10 + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_or_composite_filter(collection, database): +def test_query_with_or_composite_filter(collection, database, verify_pipeline): or_filter = Or( filters=[ FieldFilter("stats.product", ">", 5), @@ -2992,10 +3057,11 @@ def test_query_with_or_composite_filter(collection, database): assert gt_5 > 0 assert lt_10 > 0 + verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_complex_composite_filter(collection, database): +def test_query_with_complex_composite_filter(collection, database, verify_pipeline): field_filter = FieldFilter("b", "==", 0) or_filter = Or( filters=[FieldFilter("stats.sum", "==", 0), FieldFilter("stats.sum", "==", 4)] @@ -3015,6 +3081,7 @@ def test_query_with_complex_composite_filter(collection, database): assert sum_0 > 0 assert sum_4 > 0 + verify_pipeline(query) # b == 3 || (stats.sum == 4 && a == 4) comp_filter = Or( @@ -3037,6 +3104,7 @@ def test_query_with_complex_composite_filter(collection, database): assert b_3 is True assert b_not_3 is True + verify_pipeline(query) @pytest.mark.parametrize( @@ -3045,7 +3113,7 @@ def test_query_with_complex_composite_filter(collection, database): ) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_aggregation_query_in_transaction( - client, cleanup, database, aggregation_type, aggregation_args, expected + client, cleanup, database, aggregation_type, aggregation_args, expected, verify_pipeline ): """ Test creating an aggregation query inside a transaction @@ -3079,6 +3147,7 @@ def in_transaction(transaction): assert len(result[0]) == 1 assert result[0][0].value == expected inner_fn_ran = True + verify_pipeline(aggregation_query) in_transaction(transaction) # make sure we didn't skip assertions in inner function @@ -3086,7 +3155,7 @@ def in_transaction(transaction): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_or_query_in_transaction(client, cleanup, database): +def test_or_query_in_transaction(client, cleanup, database, verify_pipeline): """ Test running or query inside a transaction. Should pass transaction id along with request """ @@ -3124,6 +3193,7 @@ def in_transaction(transaction): result[0].get("b") == 2 and result[1].get("b") == 1 ) inner_fn_ran = True + verify_pipeline(query) in_transaction(transaction) # make sure we didn't skip assertions in inner function @@ -3134,7 +3204,7 @@ def in_transaction(transaction): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
)
 @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
-def test_query_in_transaction_with_explain_options(client, cleanup, database):
+def test_query_in_transaction_with_explain_options(client, cleanup, database, verify_pipeline):
     """
     Test query profiling in transactions.
     """
@@ -3179,6 +3249,7 @@ def in_transaction(transaction):
         assert explain_metrics.execution_stats is not None

         inner_fn_ran = True
+        verify_pipeline(query)

     in_transaction(transaction)
     # make sure we didn't skip assertions in inner function
@@ -3187,7 +3258,7 @@

 @pytest.mark.parametrize("with_rollback,expected", [(True, 2), (False, 3)])
 @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
-def test_transaction_rollback(client, cleanup, database, with_rollback, expected):
+def test_transaction_rollback(client, cleanup, database, with_rollback, expected, verify_pipeline):
     """
     Create a document in a transaction that is rolled back
     Document should not show up in later queries
@@ -3231,3 +3302,4 @@ def in_transaction(transaction, rollback):
     assert len(result) == 1
     assert len(result[0]) == 1
     assert result[0][0].value == expected
+    verify_pipeline(query)

From 6283d1aa96ef2966e18e03af6c3bb617e512c11a Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Mon, 31 Mar 2025 15:29:02 -0700
Subject: [PATCH 075/131] fixed bug in filter conversion

---
 google/cloud/firestore_v1/pipeline_expressions.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index f04a18e4c..b80ea3258 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -1430,7 +1430,7 @@ def _from_query_filter_pb(filter_pb, client):
         else:
             raise TypeError(f"Unexpected CompositeFilter operator type: {filter_pb.op}")
     elif isinstance(filter_pb, Query_pb.UnaryFilter):
-        field = Field.of(filter_pb.field)
+        field = Field.of(filter_pb.field.field_path)
         if filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NAN:
             return And(field.exists(), field.is_nan())
         elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN:
@@ -1442,7 +1442,7 @@ def _from_query_filter_pb(filter_pb, client):
         else:
             raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}")
     elif isinstance(filter_pb, Query_pb.FieldFilter):
-        field = Field.of(filter_pb.field)
+        field = Field.of(filter_pb.field.field_path)
         value = decode_value(filter_pb.value, client)
         if filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN:
             return And(field.exists(), field.lt(value))
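The two-line fix here is easy to miss: in the query protos, FieldFilter.field and UnaryFilter.field are FieldReference messages rather than plain strings, so the dotted path must be read off the message before it can be handed to Field.of. A minimal sketch of the distinction (assuming Query_pb aliases StructuredQuery, as the hunk context suggests):

    from google.cloud.firestore_v1.types import StructuredQuery as Query_pb

    filter_pb = Query_pb.FieldFilter(
        field=Query_pb.FieldReference(field_path="stats.sum"),
        op=Query_pb.FieldFilter.Operator.GREATER_THAN,
    )
    # Field.of expects the dotted path string, not the wrapping message
    assert filter_pb.field.field_path == "stats.sum"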
From 080bf42a6f8b0aa9cb9827947a58bcdb95823aff Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Mon, 31 Mar 2025 15:48:11 -0700
Subject: [PATCH 076/131] return pipeline copy

---
 google/cloud/firestore_v1/base_pipeline.py | 86 ++++++++++------------
 1 file changed, 39 insertions(+), 47 deletions(-)

diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py
index e5c7577aa..c1666a240 100644
--- a/google/cloud/firestore_v1/base_pipeline.py
+++ b/google/cloud/firestore_v1/base_pipeline.py
@@ -13,7 +13,8 @@
 # limitations under the License.

 from __future__ import annotations
-from typing import Optional, Sequence, Self
+from typing import Optional, Sequence
+from typing_extensions import Self
 from google.cloud.firestore_v1 import pipeline_stages as stages
 from google.cloud.firestore_v1.types.pipeline import StructuredPipeline as StructuredPipeline_pb
 from google.cloud.firestore_v1.vector import Vector
@@ -43,7 +44,7 @@ def __init__(self, *stages: stages.Stage):
         Args:
             *stages: Initial stages for the pipeline.
         """
-        self.stages = list(stages)
+        self.stages = tuple(stages)

     def __repr__(self):
         if not self.stages:
@@ -57,6 +58,12 @@ def __repr__(self):
     def _to_pb(self) -> StructuredPipeline_pb:
         return StructuredPipeline_pb(pipeline={"stages":[s._to_pb() for s in self.stages]})

+    def _append(self, new_stage):
+        """
+        Create a new Pipeline object with a new stage appended
+        """
+        return self.__class__(*self.stages, new_stage)
+
     def add_fields(self, *fields: Selectable) -> Self:
         """
         Adds new fields to outputs from previous stages.
@@ -83,10 +90,9 @@ def add_fields(self, *fields: Selectable) -> Self:
                 expressions.

         Returns:
-            A reference to this pipeline instance. Used for method chaining
+            A new Pipeline object with this stage appended to the stage list
         """
-        self.stages.append(stages.AddFields(*fields))
-        return self
+        return self._append(stages.AddFields(*fields))

     def remove_fields(self, *fields: Field | str) -> Self:
         """
@@ -106,10 +112,9 @@ def remove_fields(self, *fields: Field | str) -> Self:
                 `Field` objects.

         Returns:
-            A reference to this pipeline instance. Used for method chaining
+            A new Pipeline object with this stage appended to the stage list
         """
-        self.stages.append(stages.RemoveFields(*fields))
-        return self
+        return self._append(stages.RemoveFields(*fields))

     def select(self, *selections: str | Selectable) -> Self:
         """
@@ -140,10 +145,9 @@ def select(self, *selections: str | Selectable) -> Self:
                 field names (str) or `Selectable` expressions.

         Returns:
-            A reference to this pipeline instance. Used for method chaining
+            A new Pipeline object with this stage appended to the stage list
         """
-        self.stages.append(stages.Select(*selections))
-        return self
+        return self._append(stages.Select(*selections))

     def where(self, condition: FilterCondition) -> Self:
         """
@@ -180,10 +184,9 @@ def where(self, condition: FilterCondition) -> Self:
             condition: The `FilterCondition` to apply.

         Returns:
-            A reference to this pipeline instance. Used for method chaining
+            A new Pipeline object with this stage appended to the stage list
         """
-        self.stages.append(stages.Where(condition))
-        return self
+        return self._append(stages.Where(condition))

     def find_nearest(
         self,
@@ -235,10 +238,9 @@ def find_nearest(
                 such as limit and output distance field name.

         Returns:
-            A reference to this pipeline instance. Used for method chaining
+            A new Pipeline object with this stage appended to the stage list
         """
-        self.stages.append(stages.FindNearest(field, vector, distance_measure, options))
-        return self
+        return self._append(stages.FindNearest(field, vector, distance_measure, options))

     def sort(self, *orders: stages.Ordering) -> Self:
         """
@@ -264,10 +266,9 @@ def sort(self, *orders: stages.Ordering) -> Self:
             *orders: One or more `Ordering` instances specifying the sorting criteria.

         Returns:
-            A reference to this pipeline instance.
Used for method chaining + A new Pipeline object with this stage appended to the stage list """ - self.stages.append(stages.Sort(*orders)) - return self + return self._append(stages.Sort(*orders)) def replace( self, @@ -313,10 +314,9 @@ def replace( mode: The replacement mode Returns: - A reference to this pipeline instance. Used for method chaining + A new Pipeline object with this stage appended to the stage list """ - self.stages.append(stages.Replace(field, mode)) - return self + return self._append(stages.Replace(field, mode)) def sample(self, limit_or_options: int | SampleOptions) -> Self: """ @@ -343,10 +343,9 @@ def sample(self, limit_or_options: int | SampleOptions) -> Self: documents to sample, or a `SampleOptions` object. Returns: - A reference to this pipeline instance. Used for method chaining + A new Pipeline object with this stage appended to the stage list """ - self.stages.append(stages.Sample(limit_or_options)) - return self + return self._append(stages.Sample(limit_or_options)) def union(self, other: Self) -> Self: """ @@ -367,10 +366,9 @@ def union(self, other: Self) -> Self: other: The other `Pipeline` whose results will be unioned with this one. Returns: - A reference to this pipeline instance. Used for method chaining + A new Pipeline object with this stage appended to the stage list """ - self.stages.append(stages.Union(other)) - return self + return self._append(stages.Union(other)) def unnest( self, @@ -431,10 +429,9 @@ def unnest( options: Optional `UnnestOptions` to configure additional behavior, like adding an index field. Returns: - A reference to this pipeline instance. Used for method chaining + A new Pipeline object with this stage appended to the stage list """ - self.stages.append(stages.Unnest(field, alias, options)) - return self + return self._append(stages.Unnest(field, alias, options)) def generic_stage(self, name: str, *params: Expr) -> Self: """ @@ -455,10 +452,9 @@ def generic_stage(self, name: str, *params: Expr) -> Self: *params: A sequence of `Expr` objects representing the parameters for the stage. Returns: - A reference to this pipeline instance. Used for method chaining + A new Pipeline object with this stage appended to the stage list """ - self.stages.append(stages.GenericStage(name, *params)) - return self + return self._append(stages.GenericStage(name, *params)) def offset(self, offset: int) -> Self: """ @@ -480,10 +476,9 @@ def offset(self, offset: int) -> Self: offset: The non-negative number of documents to skip. Returns: - A reference to this pipeline instance. Used for method chaining + A new Pipeline object with this stage appended to the stage list """ - self.stages.append(stages.Offset(offset)) - return self + return self._append(stages.Offset(offset)) def limit(self, limit: int) -> Self: """ @@ -505,10 +500,9 @@ def limit(self, limit: int) -> Self: limit: The non-negative maximum number of documents to return. Returns: - A reference to this pipeline instance. Used for method chaining + A new Pipeline object with this stage appended to the stage list """ - self.stages.append(stages.Limit(limit)) - return self + return self._append(stages.Limit(limit)) def aggregate( self, @@ -556,10 +550,9 @@ def aggregate( expressions to group by before aggregating. Returns: - A reference to this pipeline instance. 
Used for method chaining + A new Pipeline object with this stage appended to the stage list """ - self.stages.append(stages.Aggregate(*accumulators, groups=groups)) - return self + return self._append(stages.Aggregate(*accumulators, groups=groups)) def distinct(self, *fields: str | Selectable) -> Self: """ @@ -588,7 +581,6 @@ def distinct(self, *fields: str | Selectable) -> Self: contain these fields/expressions. Returns: - A reference to this pipeline instance. Used for method chaining + A new Pipeline object with this stage appended to the stage list """ - self.stages.append(stages.Distinct(*fields)) - return self + return self._append(stages.Distinct(*fields)) From 6cd5c63da88283abcb4e32c8a073acc973a98a96 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 31 Mar 2025 16:09:21 -0700 Subject: [PATCH 077/131] updated results format --- tests/system/pipeline_e2e.yaml | 92 +++++++++++++++++----------------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 414e49144..ad4b804a4 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -133,7 +133,7 @@ tests: - ExprWithAlias: - Count - "count" - results: + assert_results: - count: 10 - description: "testAggregates - avg, count, max" pipeline: @@ -154,7 +154,7 @@ tests: - Max: - Field: rating - "max_rating" - results: + assert_results: - count: 2 avg_rating: 4.4 max_rating: 4.6 @@ -168,7 +168,7 @@ tests: - Aggregate: accumulators: [] groups: [genre] - error: ".* requires at least one accumulator" + assert_error: ".* requires at least one accumulator" - description: testDistinct pipeline: - Collection: books @@ -181,7 +181,7 @@ tests: - ToLower: - Field: genre - "lower_genre" - results: + assert_results: - lower_genre: romance - lower_genre: psychological thriller - description: testGroupBysAndAggregate @@ -202,7 +202,7 @@ tests: - Gt: - Field: avg_rating - Constant: 4.3 - results: + assert_results: - avg_rating: 4.7 genre: Fantasy - avg_rating: 4.5 @@ -224,7 +224,7 @@ tests: - Min: - Field: published - "min_published" - results: + assert_results: - count: 10 max_rating: 4.7 min_published: 1813 @@ -238,7 +238,7 @@ tests: - Ordering: - Field: author - ASCENDING - results: + assert_results: - title: "The Hitchhiker's Guide to the Galaxy" author: "Douglas Adams" - title: "Pride and Prejudice" @@ -288,7 +288,7 @@ tests: - Ordering: - Field: author_title - ASCENDING - results: + assert_results: - author: Douglas Adams author_title: Douglas Adams_The Hitchhiker's Guide to the Galaxy - author: Jane Austen @@ -320,7 +320,7 @@ tests: - Eq: - Field: genre - Constant: Science Fiction - results: + assert_results: - title: Dune author: Frank Herbert genre: Science Fiction @@ -346,7 +346,7 @@ tests: - Constant: Dystopian - Select: - title - results: + assert_results: - title: Pride and Prejudice - title: The Handmaid's Tale - title: 1984 @@ -362,7 +362,7 @@ tests: - Select: - title - author - results: + assert_results: - title: 1984 author: George Orwell - title: To Kill a Mockingbird @@ -376,7 +376,7 @@ tests: - ArrayContains: - Constant: tags - Constant: comedy - results: + assert_results: - title: The Hitchhiker's Guide to the Galaxy author: Douglas Adams genre: Science Fiction @@ -399,7 +399,7 @@ tests: - Constant: classic - Select: - title - results: + assert_results: - title: The Hitchhiker's Guide to the Galaxy - title: Pride and Prejudice - description: testArrayContainsAll @@ -412,7 +412,7 @@ tests: - Constant: magic - Select: - title - 
results: + assert_results: - title: The Lord of the Rings - description: testArrayLength pipeline: @@ -426,7 +426,7 @@ tests: - Eq: - Field: tagsCount - Constant: 3 - results: # All documents have 3 tags + assert_results: # All documents have 3 tags - tagsCount: 3 - tagsCount: 3 - tagsCount: 3 @@ -448,7 +448,7 @@ tests: - Constant: newTag2 - "modifiedTags" - Limit: 1 - results: + assert_results: - modifiedTags: - comedy - space @@ -466,7 +466,7 @@ tests: - Field: title - "bookInfo" - Limit: 1 - results: + assert_results: - bookInfo: Douglas Adams - The Hitchhiker's Guide to the Galaxy - description: testStartsWith pipeline: @@ -481,7 +481,7 @@ tests: - Ordering: - Field: title - ASCENDING - results: + assert_results: - title: The Great Gatsby - title: The Handmaid's Tale - title: The Hitchhiker's Guide to the Galaxy @@ -499,7 +499,7 @@ tests: - Ordering: - Field: title - DESCENDING - results: + assert_results: - title: The Hitchhiker's Guide to the Galaxy - title: The Great Gatsby - description: testLength @@ -515,7 +515,7 @@ tests: - Gt: - Field: titleLength - Constant: 20 - results: + assert_results: - titleLength: 32 title: The Hitchhiker's Guide to the Galaxy - titleLength: 27 @@ -532,7 +532,7 @@ tests: - Eq: - Field: author - Constant: Douglas Adams - results: + assert_results: - reversed_title: yxalaG ot ediug s'reknhiHcH ehT - description: testStringFunctions - ReplaceFirst pipeline: @@ -548,7 +548,7 @@ tests: - Eq: - Field: author - Constant: Douglas Adams - results: + assert_results: - replaced_title: A Hitchhiker's Guide to the Galaxy - description: testStringFunctions - ReplaceAll pipeline: @@ -564,7 +564,7 @@ tests: - Eq: - Field: author - Constant: Douglas Adams - results: + assert_results: - replaced_title: The_Hitchhiker's_Guide_to_the_Galaxy - description: testStringFunctions - CharLength pipeline: @@ -578,7 +578,7 @@ tests: - Eq: - Field: author - Constant: Douglas Adams - results: + assert_results: - title_length: 30 - description: testStringFunctions - ByteLength pipeline: @@ -594,7 +594,7 @@ tests: - Eq: - Field: author - Constant: Douglas Adams - results: + assert_results: - title_byte_length: 42 - description: testToLowercase pipeline: @@ -605,7 +605,7 @@ tests: - Field: title - "lowercaseTitle" - Limit: 1 - results: + assert_results: - lowercaseTitle: the hitchhiker's guide to the galaxy - description: testToUppercase pipeline: @@ -616,7 +616,7 @@ tests: - Field: author - "uppercaseAuthor" - Limit: 1 - results: + assert_results: - uppercaseAuthor: DOUGLAS ADAMS - description: testTrim pipeline: @@ -635,7 +635,7 @@ tests: - "trimmedTitle" - spacedTitle - Limit: 1 - results: + assert_results: - trimmedTitle: The Hitchhiker's Guide to the Galaxy spacedTitle: " The Hitchhiker's Guide to the Galaxy " - description: testLike @@ -647,7 +647,7 @@ tests: - Constant: "%Guide%" - Select: - title - results: + assert_results: - title: The Hitchhiker's Guide to the Galaxy - description: testRegexContains pipeline: @@ -656,7 +656,7 @@ tests: - RegexContains: - Field: title - Constant: "(?i)(the|of)" - results: + assert_results: - title: The Hitchhiker's Guide to the Galaxy - title: One Hundred Years of Solitude - title: The Lord of the Rings @@ -669,7 +669,7 @@ tests: - RegexMatch: - Field: title - Constant: ".*(?i)(the|of).*" - results: + assert_results: - title: The Hitchhiker's Guide to the Galaxy - title: One Hundred Years of Solitude - title: The Lord of the Rings @@ -700,7 +700,7 @@ tests: - Constant: 2 - "ratingDividedByTwo" - Limit: 1 - results: + assert_results: - 
ratingPlusOne: 5.2 yearsSince1900: 79 ratingTimesTen: 42.0 @@ -726,7 +726,7 @@ tests: - Ordering: - title - ASCENDING - results: + assert_results: - rating: 4.3 title: Crime and Punishment - rating: 4.3 @@ -754,7 +754,7 @@ tests: - Ordering: - Field: title - ASCENDING - results: + assert_results: - title: Crime and Punishment - title: Dune - title: Pride and Prejudice @@ -777,7 +777,7 @@ tests: - Field: rating - "ratingIsNotNaN" - Limit: 1 - results: + assert_results: - ratingIsNull: false ratingIsNotNaN: true - description: testLogicalMinMax @@ -798,7 +798,7 @@ tests: - Field: published - Constant: 1900 - "max_published" - results: + assert_results: - max_rating: 4.5 max_published: 1979 - description: testLogicalMinMax - min @@ -815,7 +815,7 @@ tests: - Field: published - Constant: 1900 - "min_published" - results: + assert_results: - min_rating: 4.2 min_published: 1900 - description: testMapGet @@ -832,7 +832,7 @@ tests: - Eq: - Field: hugoAward - Constant: true - results: + assert_results: - hugoAward: true title: The Hitchhiker's Guide to the Galaxy - hugoAward: true @@ -857,7 +857,7 @@ tests: - Constant: [[0.5, 0.8]] - "euclideanDistance" - Limit: 1 - results: + assert_results: - cosineDistance: 0.02560880430538015 dotProductDistance: 0.13 euclideanDistance: 0.806225774829855 @@ -871,7 +871,7 @@ tests: - Select: - title - Field: awards.hugo - results: + assert_results: - title: The Hitchhiker's Guide to the Galaxy awards.hugo: true - title: Dune @@ -887,7 +887,7 @@ tests: - title - Field: awards.hugo - Field: "__name__" - results: + assert_results: - title: The Hitchhiker's Guide to the Galaxy awards.hugo: true - title: Dune @@ -900,7 +900,7 @@ tests: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Replace: awards - results: + assert_results: - title: The Hitchhiker's Guide to the Galaxy author: Douglas Adams genre: Science Fiction @@ -916,7 +916,7 @@ tests: pipeline: - Collection: books - Sample: 3 - results_num: 3 # Results will vary due to randomness + assert_count: 3 # Results will vary due to randomness - description: testSamplePercentage pipeline: - Collection: books @@ -924,14 +924,14 @@ tests: - SampleOptions: - 0.6 - percent - results_num: 6 # Results will vary due to randomness + assert_count: 6 # Results will vary due to randomness - description: testUnion pipeline: - Collection: books - Union: - Pipeline: - Collection: books - results_num: 20 # Results will be duplicated + assert_count: 20 # Results will be duplicated - description: testUnnest pipeline: - Collection: books @@ -940,7 +940,7 @@ tests: - Field: title - Constant: The Hitchhiker's Guide to the Galaxy - Unnest: tags - results: + assert_results: - tags: comedy - tags: space - tags: adventure From 7143247a66f172d1356ff5e3af3566e8420093bb Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 31 Mar 2025 16:39:23 -0700 Subject: [PATCH 078/131] added proto assertions to e2e tests --- tests/system/pipeline_e2e.yaml | 1085 ++++++++++++++++++++++ tests/system/test_pipeline_acceptance.py | 37 +- 2 files changed, 1096 insertions(+), 26 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index ad4b804a4..a12fd1759 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -135,6 +135,20 @@ tests: - "count" assert_results: - count: 10 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count: + functionValue: + name: count + - mapValue: {} + name: aggregate - 
description: "testAggregates - avg, count, max" pipeline: - Collection: books @@ -158,6 +172,37 @@ tests: - count: 2 avg_rating: 4.4 max_rating: 4.6 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: eq + name: where + - args: + - mapValue: + fields: + avg_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: avg + count: + functionValue: + name: count + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: maximum + - mapValue: {} + name: aggregate - description: testGroupBysWithoutAccumulators pipeline: - Collection: books @@ -184,6 +229,28 @@ tests: assert_results: - lower_genre: romance - lower_genre: psychological thriller + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: lt + name: where + - args: + - mapValue: + fields: + lower_genre: + functionValue: + args: + - fieldReferenceValue: genre + name: to_lower + name: distinct - description: testGroupBysAndAggregate pipeline: - Collection: books @@ -209,6 +276,39 @@ tests: genre: Romance - avg_rating: 4.4 genre: Science Fiction + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1984' + name: lt + name: where + - args: + - mapValue: + fields: + avg_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: avg + - mapValue: + fields: + genre: + fieldReferenceValue: genre + name: aggregate + - args: + - functionValue: + args: + - fieldReferenceValue: avg_rating + - doubleValue: 4.3 + name: gt + name: where - description: testMinMax pipeline: - Collection: books @@ -228,6 +328,30 @@ tests: - count: 10 max_rating: 4.7 min_published: 1813 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count: + functionValue: + name: count + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: maximum + min_published: + functionValue: + args: + - fieldReferenceValue: published + name: minimum + - mapValue: {} + name: aggregate - description: selectSpecificFields pipeline: - Collection: books @@ -259,6 +383,28 @@ tests: author: "George Orwell" - title: "The Lord of the Rings" author: "J.R.R. Tolkien" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort - description: addAndRemoveFields pipeline: - Collection: books @@ -309,6 +455,48 @@ tests: author_title: George Orwell_1984 - author: J.R.R. Tolkien author_title: J.R.R. 
Tolkien_The Lord of the Rings + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + author_title: + functionValue: + args: + - fieldReferenceValue: author + - stringValue: _ + - fieldReferenceValue: title + name: str_concat + title_author: + functionValue: + args: + - fieldReferenceValue: title + - stringValue: _ + - fieldReferenceValue: author + name: str_concat + name: add_fields + - args: + - fieldReferenceValue: title_author + - fieldReferenceValue: tags + - fieldReferenceValue: awards + - fieldReferenceValue: rating + - fieldReferenceValue: title + - fieldReferenceValue: published + - fieldReferenceValue: genre + - fieldReferenceValue: nestedField + name: remove_fields + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author_title + name: sort - description: whereByMultipleConditions pipeline: - Collection: books @@ -333,6 +521,27 @@ tests: awards: hugo: true nebula: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: gt + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: eq + name: and + name: where - description: whereByOrCondition pipeline: - Collection: books @@ -350,6 +559,33 @@ tests: - title: Pride and Prejudice - title: The Handmaid's Tale - title: 1984 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Romance + name: eq + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Dystopian + name: eq + name: or + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select - description: testPipelineWithOffsetAndLimit pipeline: - Collection: books @@ -369,6 +605,34 @@ tests: author: Harper Lee - title: The Lord of the Rings author: J.R.R. 
Tolkien + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - args: + - integerValue: '5' + name: offset + - args: + - integerValue: '3' + name: limit + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + title: + fieldReferenceValue: title + name: select - description: testArrayContains pipeline: - Collection: books @@ -378,6 +642,25 @@ tests: - Constant: comedy assert_results: - title: The Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: '%Guide%' + name: like + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select author: Douglas Adams genre: Science Fiction published: 1979 @@ -389,6 +672,19 @@ tests: awards: hugo: true nebula: false + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - stringValue: tags + - stringValue: comedy + name: array_contains + name: where - description: testArrayContainsAny pipeline: - Collection: books @@ -402,6 +698,28 @@ tests: assert_results: - title: The Hitchhiker's Guide to the Galaxy - title: Pride and Prejudice + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - arrayValue: + values: + - stringValue: comedy + - stringValue: classic + name: array_contains_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select - description: testArrayContainsAll pipeline: - Collection: books @@ -414,6 +732,28 @@ tests: - title assert_results: - title: The Lord of the Rings + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - arrayValue: + values: + - stringValue: adventure + - stringValue: magic + name: array_contains_all + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select - description: testArrayLength pipeline: - Collection: books @@ -437,6 +777,28 @@ tests: - tagsCount: 3 - tagsCount: 3 - tagsCount: 3 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + tagsCount: + functionValue: + args: + - fieldReferenceValue: tags + name: array_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: tagsCount + - integerValue: '3' + name: eq + name: where - description: testArrayConcat pipeline: - Collection: books @@ -455,6 +817,26 @@ tests: - adventure - newTag1 - newTag2 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + modifiedTags: + functionValue: + args: + - fieldReferenceValue: tags + - stringValue: newTag1 + - stringValue: newTag2 + name: array_concat + name: select + - args: + - integerValue: '1' + name: limit - description: testStrConcat pipeline: - Collection: books @@ -468,6 +850,26 @@ tests: - Limit: 1 assert_results: - bookInfo: Douglas Adams - The Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - 
mapValue: + fields: + bookInfo: + functionValue: + args: + - fieldReferenceValue: author + - stringValue: ' - ' + - fieldReferenceValue: title + name: str_concat + name: select + - args: + - integerValue: '1' + name: limit - description: testStartsWith pipeline: - Collection: books @@ -486,6 +888,33 @@ tests: - title: The Handmaid's Tale - title: The Hitchhiker's Guide to the Galaxy - title: The Lord of the Rings + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: the + name: starts_with + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort - description: testEndsWith pipeline: - Collection: books @@ -502,6 +931,33 @@ tests: assert_results: - title: The Hitchhiker's Guide to the Galaxy - title: The Great Gatsby + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: y + name: ends_with + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: title + name: sort - description: testLength pipeline: - Collection: books @@ -520,6 +976,30 @@ tests: title: The Hitchhiker's Guide to the Galaxy - titleLength: 27 title: One Hundred Years of Solitude + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + titleLength: + functionValue: + args: + - fieldReferenceValue: title + name: char_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: titleLength + - integerValue: '20' + name: gt + name: where - description: testStringFunctions - Reverse pipeline: - Collection: books @@ -534,6 +1014,28 @@ tests: - Constant: Douglas Adams assert_results: - reversed_title: yxalaG ot ediug s'reknhiHcH ehT + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + reversed_title: + functionValue: + args: + - fieldReferenceValue: title + name: reverse + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where - description: testStringFunctions - ReplaceFirst pipeline: - Collection: books @@ -550,6 +1052,30 @@ tests: - Constant: Douglas Adams assert_results: - replaced_title: A Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + replaced_title: + functionValue: + args: + - fieldReferenceValue: title + - stringValue: The + - stringValue: A + name: replace_first + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where - description: testStringFunctions - ReplaceAll pipeline: - Collection: books @@ -566,6 +1092,30 @@ tests: - Constant: Douglas Adams assert_results: - replaced_title: The_Hitchhiker's_Guide_to_the_Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + replaced_title: + 
functionValue: + args: + - fieldReferenceValue: title + - stringValue: ' ' + - stringValue: _ + name: replace_all + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where - description: testStringFunctions - CharLength pipeline: - Collection: books @@ -580,6 +1130,28 @@ tests: - Constant: Douglas Adams assert_results: - title_length: 30 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + title_length: + functionValue: + args: + - fieldReferenceValue: title + name: char_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where - description: testStringFunctions - ByteLength pipeline: - Collection: books @@ -596,6 +1168,32 @@ tests: - Constant: Douglas Adams assert_results: - title_byte_length: 42 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + title_byte_length: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "_\u94F6\u6CB3\u7CFB\u6F2B\u6E38\u6307\u5357" + name: str_concat + name: byte_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where - description: testToLowercase pipeline: - Collection: books @@ -607,6 +1205,24 @@ tests: - Limit: 1 assert_results: - lowercaseTitle: the hitchhiker's guide to the galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + lowercaseTitle: + functionValue: + args: + - fieldReferenceValue: title + name: to_lower + name: select + - args: + - integerValue: '1' + name: limit - description: testToUppercase pipeline: - Collection: books @@ -618,6 +1234,24 @@ tests: - Limit: 1 assert_results: - uppercaseAuthor: DOUGLAS ADAMS + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + uppercaseAuthor: + functionValue: + args: + - fieldReferenceValue: author + name: to_upper + name: select + - args: + - integerValue: '1' + name: limit - description: testTrim pipeline: - Collection: books @@ -638,6 +1272,37 @@ tests: assert_results: - trimmedTitle: The Hitchhiker's Guide to the Galaxy spacedTitle: " The Hitchhiker's Guide to the Galaxy " + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + spacedTitle: + functionValue: + args: + - stringValue: ' ' + - fieldReferenceValue: title + - stringValue: ' ' + name: str_concat + name: add_fields + - args: + - mapValue: + fields: + spacedTitle: + fieldReferenceValue: spacedTitle + trimmedTitle: + functionValue: + args: + - fieldReferenceValue: spacedTitle + name: trim + name: select + - args: + - integerValue: '1' + name: limit - description: testLike pipeline: - Collection: books @@ -662,6 +1327,32 @@ tests: - title: The Lord of the Rings - title: To Kill a Mockingbird - title: The Great Gatsby + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: .*(?i)(the|of).* + name: regex_match + name: where + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - 
functionValue: + args: + - fieldReferenceValue: title + - stringValue: (?i)(the|of) + name: regex_contains + name: where - description: testRegexMatches pipeline: - Collection: books @@ -705,6 +1396,43 @@ tests: yearsSince1900: 79 ratingTimesTen: 42.0 ratingDividedByTwo: 2.1 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + ratingDividedByTwo: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: divide + ratingPlusOne: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '1' + name: add + ratingTimesTen: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '10' + name: multiply + yearsSince1900: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: subtract + name: select + - args: + - integerValue: '1' + name: limit - description: testComparisonOperators pipeline: - Collection: books @@ -733,6 +1461,48 @@ tests: title: One Hundred Years of Solitude - rating: 4.5 title: Pride and Prejudice + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.2 + name: gt + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: lte + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: neq + name: and + name: where + - args: + - mapValue: + fields: + rating: + fieldReferenceValue: rating + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort - description: testLogicalOperators pipeline: - Collection: books @@ -758,6 +1528,49 @@ tests: - title: Crime and Punishment - title: Dune - title: Pride and Prejudice + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: gt + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: eq + name: and + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: lt + name: or + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort - description: testChecks pipeline: - Collection: books @@ -780,6 +1593,42 @@ tests: assert_results: - ratingIsNull: false ratingIsNotNaN: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_nan + name: not + name: where + - args: + - mapValue: + fields: + ratingIsNotNaN: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_nan + name: not + ratingIsNull: + functionValue: + args: + - fieldReferenceValue: rating + - nullValue: null + name: eq + name: select + - args: + - integerValue: '1' + name: limit - description: testLogicalMinMax pipeline: - Collection: books @@ -801,6 +1650,35 @@ tests: assert_results: - max_rating: 4.5 max_published: 1979 + assert_proto: + pipeline: + stages: 
+ - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where + - args: + - mapValue: + fields: + max_published: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: logical_maximum + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: logical_maximum + name: select - description: testLogicalMinMax - min pipeline: - Collection: books @@ -818,6 +1696,28 @@ tests: assert_results: - min_rating: 4.2 min_published: 1900 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + min_published: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: logical_minimum + min_rating: + functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: logical_minimum + name: select - description: testMapGet pipeline: - Collection: books @@ -837,6 +1737,31 @@ tests: title: The Hitchhiker's Guide to the Galaxy - hugoAward: true title: Dune + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + hugoAward: + functionValue: + args: + - fieldReferenceValue: awards + - stringValue: hugo + name: map_get + title: + fieldReferenceValue: title + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: hugoAward + - booleanValue: true + name: eq + name: where - description: testDistanceFunctions pipeline: - Collection: books @@ -861,6 +1786,55 @@ tests: - cosineDistance: 0.02560880430538015 dotProductDistance: 0.13 euclideanDistance: 0.806225774829855 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + cosineDistance: + functionValue: + args: + - arrayValue: + values: + - doubleValue: 0.1 + - doubleValue: 0.1 + - arrayValue: + values: + - doubleValue: 0.5 + - doubleValue: 0.8 + name: cosine_distance + dotProductDistance: + functionValue: + args: + - arrayValue: + values: + - doubleValue: 0.1 + - doubleValue: 0.1 + - arrayValue: + values: + - doubleValue: 0.5 + - doubleValue: 0.8 + name: dot_product + euclideanDistance: + functionValue: + args: + - arrayValue: + values: + - doubleValue: 0.1 + - doubleValue: 0.1 + - arrayValue: + values: + - doubleValue: 0.5 + - doubleValue: 0.8 + name: euclidean_distance + name: select + - args: + - integerValue: '1' + name: limit - description: testNestedFields pipeline: - Collection: books @@ -876,6 +1850,50 @@ tests: awards.hugo: true - title: Dune awards.hugo: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.hugo + - booleanValue: true + name: eq + name: where + - args: + - mapValue: + fields: + __name__: + fieldReferenceValue: __name__ + awards.hugo: + fieldReferenceValue: awards.hugo + title: + fieldReferenceValue: title + name: select + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.hugo + - booleanValue: true + name: eq + name: where + - args: + - mapValue: + fields: + awards.hugo: + fieldReferenceValue: awards.hugo + title: + fieldReferenceValue: title + name: select - description: testPipelineInTransactions pipeline: - Collection: books @@ -912,11 
+1930,38 @@ tests: - adventure hugo: true nebula: false + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's Guide to the Galaxy + name: eq + name: where + - args: + - fieldReferenceValue: awards + - stringValue: full_replace + name: replace - description: testSampleLimit pipeline: - Collection: books - Sample: 3 assert_count: 3 # Results will vary due to randomness + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '3' + - stringValue: documents + name: sample - description: testSamplePercentage pipeline: - Collection: books @@ -925,6 +1970,16 @@ tests: - 0.6 - percent assert_count: 6 # Results will vary due to randomness + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - doubleValue: 0.6 + - stringValue: percent + name: sample - description: testUnion pipeline: - Collection: books @@ -932,6 +1987,19 @@ tests: - Pipeline: - Collection: books assert_count: 20 # Results will be duplicated + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - pipelineValue: + stages: + - args: + - referenceValue: /books + name: collection + name: union - description: testUnnest pipeline: - Collection: books @@ -944,3 +2012,20 @@ tests: - tags: comedy - tags: space - tags: adventure + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's Guide to the Galaxy + name: eq + name: where + - args: + - fieldReferenceValue: tags + - fieldReferenceValue: tags + name: unnest diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 772698095..c51fd996b 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -21,6 +21,8 @@ from typing import Any from contextlib import nullcontext +from google.protobuf.json_format import MessageToDict + # from google.cloud.firestore_v1.pipeline_stages import * from google.cloud.firestore_v1 import pipeline_stages from google.cloud.firestore_v1 import pipeline_expressions @@ -127,35 +129,18 @@ def parse_expressions(client, yaml_element: Any): ) def test_e2e_scenario(test_dict): client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_TEST_DB) - error_regex = test_dict.get("error", None) + error_regex = test_dict.get("assert_error", None) + expected_proto = test_dict.get("assert_proto", None) + pipeline = parse_pipeline(client, test_dict["pipeline"]) + # check if proto matches as expected + if expected_proto: + got_proto = MessageToDict(pipeline._to_pb()._pb) + assert yaml.dump(expected_proto) == yaml.dump(got_proto) + # check if server responds as expected with pytest.raises(GoogleAPIError) if error_regex else nullcontext() as ctx: - pipeline = parse_pipeline(client, test_dict["pipeline"]) - print(pipeline._to_pb()) pipeline.execute() # check for error message if expected if error_regex: found_error = str(ctx.value) match = re.search(error_regex, found_error) - assert match, f"error '{found_error}' does not match '{error_regex}'" - - # before_ast = ast.parse(test_dict["before"]) - # got_ast = before_ast - # for transformer_info in test_dict["transformers"]: - # # transformer can be passed as a string, or a dict with name and args - # if 
isinstance(transformer_info, str):
-    #         transformer_class = globals()[transformer_info]
-    #         transformer_args = {}
-    #     else:
-    #         transformer_class = globals()[transformer_info["name"]]
-    #         transformer_args = transformer_info.get("args", {})
-    #     transformer = transformer_class(**transformer_args)
-    #     got_ast = transformer.visit(got_ast)
-    # if got_ast is None:
-    #     final_str = ""
-    # else:
-    #     final_str = black.format_str(ast.unparse(got_ast), mode=black.FileMode())
-    # if test_dict.get("after") is None:
-    #     expected_str = ""
-    # else:
-    #     expected_str = black.format_str(test_dict["after"], mode=black.FileMode())
-    # assert final_str == expected_str, f"Expected:\n{expected_str}\nGot:\n{final_str}"
+        assert match, f"error '{found_error}' does not match '{error_regex}'"
\ No newline at end of file

From 05c4f23c2d49854b779ccdc41fd4675d5d0f10f2 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Mon, 31 Mar 2025 17:35:54 -0700
Subject: [PATCH 079/131] added tests for FieldFilter._from_filter_pb

---
 .../firestore_v1/pipeline_expressions.py     |  17 +-
 tests/unit/v1/test_pipeline_expressions.py   | 270 ++++++++++++++++++
 2 files changed, 280 insertions(+), 7 deletions(-)
 create mode 100644 tests/unit/v1/test_pipeline_expressions.py

diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index b80ea3258..c48196ae9 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -366,29 +366,29 @@ def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte":
         """
         return Lte(self, self._cast_to_expr_or_convert_to_constant(other))
 
-    def in_any(self, *others: Expr | CONSTANT_TYPE) -> "In":
+    def in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "In":
         """Creates an expression that checks if this expression is equal to any of the
         provided values or expressions.
 
         Example:
             >>> # Check if the 'category' field is either "Electronics" or value of field 'primaryType'
-            >>> Field.of("category").in_any("Electronics", Field.of("primaryType"))
+            >>> Field.of("category").in_any(["Electronics", Field.of("primaryType")])
 
         Args:
-            *others: The values or expressions to check against.
+            array: The values or expressions to check against.
 
         Returns:
             A new `Expr` representing the 'IN' comparison.
         """
-        return In(self, [self._cast_to_expr_or_convert_to_constant(o) for o in others])
+        return In(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array])
 
-    def not_in_any(self, *others: Expr | CONSTANT_TYPE) -> "Not":
+    def not_in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "Not":
        """Creates an expression that checks if this expression is not equal to any of the
        provided values or expressions.
 
        Example:
            >>> # Check if the 'status' field is neither "pending" nor "cancelled"
-           >>> Field.of("status").not_in_any("pending", "cancelled")
+           >>> Field.of("status").not_in_any(["pending", "cancelled"])
 
        Args:
-           *others: The values or expressions to check against.
+           array: The values or expressions to check against.
 
        Returns:
            A new `Expr` representing the 'NOT IN' comparison.
        """
-       return Not(self.in_any(*others))
+       return Not(self.in_any(array))
 
     def array_concat(self, array: List[Expr | CONSTANT_TYPE]) -> "ArrayConcat":
         """Creates an expression that concatenates an array expression with another array.
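As a quick illustration of the list-based signatures above (an editorial sketch, not part of the patch; the field names are invented for the example, while `Field.of`, `in_any`, and `not_in_any` are the methods defined in pipeline_expressions.py):

    from google.cloud.firestore_v1.pipeline_expressions import Field

    # Both helpers now take a single list argument instead of varargs; plain
    # values are converted to Constant, Expr values are passed through as-is.
    in_filter = Field.of("category").in_any(["Electronics", Field.of("primaryType")])
    not_in_filter = Field.of("status").not_in_any(["pending", "cancelled"])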
@@ -1070,6 +1070,9 @@ class ListOfExprs(Expr):
     def __init__(self, exprs: List[Expr]):
         self.exprs: list[Expr] = exprs
 
+    def __repr__(self):
+        return f"{self.__class__.__name__}({', '.join([repr(e) for e in self.exprs])})"
+
     def _to_pb(self):
         return Value(array_value={"values": [e._to_pb() for e in self.exprs]})
 
diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py
new file mode 100644
index 000000000..8d46e6c53
--- /dev/null
+++ b/tests/unit/v1/test_pipeline_expressions.py
@@ -0,0 +1,270 @@
+# Copyright 2025 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import mock
+
+from google.cloud.firestore_v1 import _helpers
+from google.cloud.firestore_v1.types import document as document_pb
+from google.cloud.firestore_v1.types import query as query_pb
+from google.cloud.firestore_v1.pipeline_expressions import FilterCondition
+from google.cloud.firestore_v1 import pipeline_expressions as expr
+
+
+@pytest.fixture
+def mock_client():
+    client = mock.Mock(spec=["_database_string", "collection"])
+    client._database_string = "projects/p/databases/d"
+    return client
+
+
+class TestFilterCondition:
+
+    def test__from_query_filter_pb_composite_filter_or(self, mock_client):
+        """
+        test composite OR filters
+
+        should create an OR statement made up of ANDs that check for the existence of the relevant fields
+        """
+        filter1_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field1"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL,
+            value=_helpers.encode_value("val1"),
+        )
+        filter2_pb = query_pb.StructuredQuery.UnaryFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field2"),
+            op=query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL,
+        )
+
+        composite_pb = query_pb.StructuredQuery.CompositeFilter(
+            op=query_pb.StructuredQuery.CompositeFilter.Operator.OR,
+            filters=[
+                query_pb.StructuredQuery.Filter(field_filter=filter1_pb),
+                query_pb.StructuredQuery.Filter(unary_filter=filter2_pb),
+            ],
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(composite_filter=composite_pb)
+
+        result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+        # should include existence checks
+        expected_cond1 = expr.And(expr.Exists(expr.Field.of("field1")), expr.Eq(expr.Field.of("field1"), expr.Constant("val1")))
+        expected_cond2 = expr.And(expr.Exists(expr.Field.of("field2")), expr.Eq(expr.Field.of("field2"), expr.Constant(None)))
+        expected = expr.Or(expected_cond1, expected_cond2)
+
+        assert repr(result) == repr(expected)
+
+    def test__from_query_filter_pb_composite_filter_and(self, mock_client):
+        """
+        test composite AND filters
+
+        should create an AND statement made up of ANDs that check for the existence of the relevant fields
+        """
+        filter1_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field1"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN,
+            value=_helpers.encode_value(100),
+        )
+        filter2_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field2"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN,
+            value=_helpers.encode_value(200),
+        )
+
+        composite_pb = query_pb.StructuredQuery.CompositeFilter(
+            op=query_pb.StructuredQuery.CompositeFilter.Operator.AND,
+            filters=[
+                query_pb.StructuredQuery.Filter(field_filter=filter1_pb),
+                query_pb.StructuredQuery.Filter(field_filter=filter2_pb),
+            ],
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(composite_filter=composite_pb)
+
+        result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+        # should include existence checks
+        expected_cond1 = expr.And(expr.Exists(expr.Field.of("field1")), expr.Gt(expr.Field.of("field1"), expr.Constant(100)))
+        expected_cond2 = expr.And(expr.Exists(expr.Field.of("field2")), expr.Lt(expr.Field.of("field2"), expr.Constant(200)))
+        expected = expr.And(expected_cond1, expected_cond2)
+        assert repr(result) == repr(expected)
+
+    def test__from_query_filter_pb_composite_filter_nested(self, mock_client):
+        """
+        test composite filter with complex nested checks
+        """
+        # OR (field1 == "val1", AND(field2 > 10, field3 IS NOT NULL))
+        filter1_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field1"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL,
+            value=_helpers.encode_value("val1"),
+        )
+        filter2_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field2"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN,
+            value=_helpers.encode_value(10),
+        )
+        filter3_pb = query_pb.StructuredQuery.UnaryFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field3"),
+            op=query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL,
+        )
+        inner_and_pb = query_pb.StructuredQuery.CompositeFilter(
+            op=query_pb.StructuredQuery.CompositeFilter.Operator.AND,
+            filters=[
+                query_pb.StructuredQuery.Filter(field_filter=filter2_pb),
+                query_pb.StructuredQuery.Filter(unary_filter=filter3_pb),
+            ],
+        )
+        outer_or_pb = query_pb.StructuredQuery.CompositeFilter(
+            op=query_pb.StructuredQuery.CompositeFilter.Operator.OR,
+            filters=[
+                query_pb.StructuredQuery.Filter(field_filter=filter1_pb),
+                query_pb.StructuredQuery.Filter(composite_filter=inner_and_pb),
+            ],
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(composite_filter=outer_or_pb)
+
+        result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+        expected_cond1 = expr.And(expr.Exists(expr.Field.of("field1")), expr.Eq(expr.Field.of("field1"), expr.Constant("val1")))
+        expected_cond2 = expr.And(expr.Exists(expr.Field.of("field2")), expr.Gt(expr.Field.of("field2"), expr.Constant(10)))
+        expected_cond3 = expr.And(expr.Exists(expr.Field.of("field3")), expr.Not(expr.Eq(expr.Field.of("field3"), expr.Constant(None))))
+        expected_inner_and = expr.And(expected_cond2, expected_cond3)
+        expected_outer_or = expr.Or(expected_cond1, expected_inner_and)
+
+        assert repr(result) == repr(expected_outer_or)
+
+
+    def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client):
+        """
+        check composite filter with unsupported operator type
+        """
+        filter1_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path="field1"),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL,
+            value=_helpers.encode_value("val1"),
+        )
+        composite_pb = query_pb.StructuredQuery.CompositeFilter(
+            op=query_pb.StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED,
+            filters=[query_pb.StructuredQuery.Filter(field_filter=filter1_pb)],
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(composite_filter=composite_pb)
+
+        with pytest.raises(TypeError, match="Unexpected CompositeFilter operator type"):
+            FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+    @pytest.mark.parametrize(
+        "op_enum, expected_expr_func",
+        [
+            (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, expr.IsNaN),
+            (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, lambda f: expr.Not(f.is_nan())),
+            (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, lambda f: f.eq(None)),
+            (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, lambda f: expr.Not(f.eq(None))),
+        ],
+    )
+    def test__from_query_filter_pb_unary_filter(self, mock_client, op_enum, expected_expr_func):
+        """
+        test supported unary filters
+        """
+        field_path = "unary_field"
+        filter_pb = query_pb.StructuredQuery.UnaryFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path=field_path),
+            op=op_enum,
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb)
+
+        result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+        field_expr_inst = expr.Field.of(field_path)
+        expected_condition = expected_expr_func(field_expr_inst)
+        # should include existence checks
+        expected = expr.And(expr.Exists(field_expr_inst), expected_condition)
+
+        assert repr(result) == repr(expected)
+
+    def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client):
+        """
+        check unary filter with unsupported operator type
+        """
+        field_path = "unary_field"
+        filter_pb = query_pb.StructuredQuery.UnaryFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path=field_path),
+            op=query_pb.StructuredQuery.UnaryFilter.Operator.OPERATOR_UNSPECIFIED,  # Unknown op
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb)
+
+        with pytest.raises(TypeError, match="Unexpected UnaryFilter operator type"):
+            FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+
+    @pytest.mark.parametrize(
+        "op_enum, value, expected_expr_func",
+        [
+            (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, expr.Lt),
+            (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL, 10, expr.Lte),
+            (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, expr.Gt),
+            (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL, 10, expr.Gte),
+            (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, expr.Eq),
+            (query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, expr.Neq),
+            (query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS, 10, expr.ArrayContains),
+            (query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY, [10, 20], expr.ArrayContainsAny),
+            (query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], expr.In),
+            (query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, [10, 20], lambda f, v: expr.Not(f.in_any(v))),
+        ],
+    )
+    def test__from_query_filter_pb_field_filter(self, mock_client, op_enum, value, expected_expr_func):
+        """
+        test supported field filters
+        """
+        field_path = "test_field"
+        value_pb = _helpers.encode_value(value)
+        filter_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path=field_path),
+            op=op_enum,
+            value=value_pb,
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb)
+
+        result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+        field_expr = expr.Field.of(field_path)
+        # convert values into constants
+        value = [expr.Constant(e) for e in value] if isinstance(value, list) else expr.Constant(value)
+        expected_condition = expected_expr_func(field_expr, value)
+        # should include existence checks
+        expected = expr.And(expr.Exists(field_expr), expected_condition)
+
+        assert repr(result) == repr(expected)
+
+    def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client):
+        """
+        check field filter with unsupported operator type
+        """
+        field_path = "test_field"
+        value_pb = _helpers.encode_value(10)
+        filter_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path=field_path),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.OPERATOR_UNSPECIFIED,  # Unknown op
+            value=value_pb,
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb)
+
+        with pytest.raises(TypeError, match="Unexpected FieldFilter operator type"):
+            FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+    def test__from_query_filter_pb_unknown_filter_type(self, mock_client):
+        """
+        test with unsupported filter type
+        """
+        # Test with an unexpected protobuf type
+        with pytest.raises(TypeError, match="Unexpected filter type"):
+            FilterCondition._from_query_filter_pb(document_pb.Value(), mock_client)

From 7512a1a775c1096d88c321b6e048fa4821a07c01 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Mon, 31 Mar 2025 17:39:30 -0700
Subject: [PATCH 080/131] fixed repr

---
 google/cloud/firestore_v1/pipeline_expressions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index c48196ae9..d7e8fe4a4 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -1071,7 +1071,7 @@ def __init__(self, exprs: List[Expr]):
         self.exprs: list[Expr] = exprs
 
     def __repr__(self):
-        return f"{self.__class__.__name__}({', '.join([repr(e) for e in self.exprs])})"
+        return f"{self.__class__.__name__}({self.exprs})"
 
     def _to_pb(self):
         return Value(array_value={"values": [e._to_pb() for e in self.exprs]})

From 901c0e12185729daa53c67c48c52bb8d1cd506b1 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Mon, 31 Mar 2025 18:42:16 -0700
Subject: [PATCH 081/131] fixed typing issues

---
 google/cloud/firestore_v1/pipeline_expressions.py | 8 ++++----
 google/cloud/firestore_v1/pipeline_stages.py      | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index d7e8fe4a4..7acd764b3 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -102,7 +102,7 @@ def percentage(value:float):
         Args:
             value: percentage of documents to return
         """
-        return SampleOptions(value, mode=SampleOptions.Mode.PERCENTAGE)
+        return SampleOptions(value, mode=SampleOptions.Mode.PERCENT)
 
 class Expr(ABC):
     """Represents an expression that can be evaluated to a value within the
@@ -817,7 +817,7 @@ def map_get(self, key: str) -> "MapGet":
         Returns:
             A new `Expr` representing the value associated with the given key in the map.
""" - return MapGet(self, key) + return MapGet(self, Constant.of(key)) def cosine_distance(self, other: Expr | list[float] | Vector) -> "CosineDistance": """Calculates the cosine distance between two vectors. @@ -1126,8 +1126,8 @@ def __init__(self, left: Expr, right: Expr): class MapGet(Function): """Represents accessing a value within a map by key.""" - def __init__(self, map_: Expr, key: str): - super().__init__("map_get", [map_, Constant(key)]) + def __init__(self, map_: Expr, key: Constant[str]): + super().__init__("map_get", [map_, key]) class Mod(Function): diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 01df06c2a..14b3359be 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -316,7 +316,7 @@ class Unnest(Stage): """Produces a document for each element in an array field.""" def __init__(self, field: Selectable | str, alias: Field | str | None=None, options: UnnestOptions|None=None): super().__init__() - self.field: Field = Field(field) if isinstance(field, str) else field + self.field: Selectable = Field(field) if isinstance(field, str) else field if alias is None: self.alias = self.field elif isinstance(alias, str): From 6a5f72e028647059c049131b5a6280a7768461b5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 1 Apr 2025 15:12:41 -0700 Subject: [PATCH 082/131] fixed client fixture --- tests/system/test_pipeline_acceptance.py | 25 ++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index c51fd996b..42fcd412e 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -36,14 +36,18 @@ test_dir_name = os.path.dirname(__file__) - -def loader(): - # load test cases +def yaml_loader(field="tests"): + """ + loads test cases or data from yaml file + """ with open(f"{test_dir_name}/pipeline_e2e.yaml") as f: test_cases = yaml.safe_load(f) - # load data - data = test_cases["data"] + return test_cases[field] + +@pytest.fixture +def client(): client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_TEST_DB) + data = yaml_loader("data") try: # setup data batch = client.batch() @@ -54,9 +58,8 @@ def loader(): batch.set(document_ref, document_data) batch.commit() - # run tests - for test in test_cases["tests"]: - yield test + yield client + finally: # clear data for collection_name, documents in data.items(): @@ -123,12 +126,10 @@ def parse_expressions(client, yaml_element: Any): else: return yaml_element - @pytest.mark.parametrize( - "test_dict", loader(), ids=lambda x: f"{x.get('description', '')}" + "test_dict", yaml_loader(), ids=lambda x: f"{x.get('description', '')}" ) -def test_e2e_scenario(test_dict): - client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_TEST_DB) +def test_e2e_scenario(test_dict, client): error_regex = test_dict.get("assert_error", None) expected_proto = test_dict.get("assert_proto", None) pipeline = parse_pipeline(client, test_dict["pipeline"]) From 072669f2d6528a59f8238358cf5c6f5780dfe7f5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 1 Apr 2025 15:18:46 -0700 Subject: [PATCH 083/131] fixed mapget in e2e yaml --- tests/system/pipeline_e2e.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index a12fd1759..7321b7c77 100644 --- a/tests/system/pipeline_e2e.yaml +++ 
b/tests/system/pipeline_e2e.yaml @@ -1725,7 +1725,7 @@ tests: - ExprWithAlias: - MapGet: - Field: awards - - hugo + - Constant: hugo - "hugoAward" - Field: title - Where: From 05a749826156953165c92bd2a10d5d062d560fe6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 1 Apr 2025 16:51:27 -0700 Subject: [PATCH 084/131] fixed docstring --- google/cloud/firestore_v1/async_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 2dfbaedda..aeba2af23 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -33,13 +33,13 @@ class AsyncPipeline(_BasePipeline): defined pipeline stages using an asynchronous `AsyncClient`. Usage Example: - >>> from google.cloud.firestore_v1.pipeline_expressions import Field, gt + >>> from google.cloud.firestore_v1.pipeline_expressions import Field >>> >>> async def run_pipeline(): ... client = AsyncClient(...) ... pipeline = client.collection("books") ... .pipeline() - ... .where(gt(Field.of("published"), 1980)) + ... .where(Field.of("published").gt(1980)) ... .select("title", "author") ... async for result in pipeline.execute_async(): ... print(result) From 9f8d4c895e021830dc1699a53b86ae8967c631b6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 1 Apr 2025 16:57:41 -0700 Subject: [PATCH 085/131] fixed append bug --- google/cloud/firestore_v1/async_pipeline.py | 3 +++ google/cloud/firestore_v1/base_pipeline.py | 3 ++- google/cloud/firestore_v1/pipeline.py | 3 +++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index aeba2af23..31c63eb81 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -57,6 +57,9 @@ def __init__(self, client:AsyncClient, *stages: stages.Stage): super().__init__(*stages) self._client = client + def _append(self, new_stage): + return self.__class__(self._client, *self.stages, new_stage) + async def execute_async(self) -> AsyncIterable["ExecutePipelineResponse"]: database_name = f"projects/{self._client.project}/databases/{self._client._database}" request = ExecutePipelineRequest( diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index c1666a240..6582f2789 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -58,11 +58,12 @@ def __repr__(self): def _to_pb(self) -> StructuredPipeline_pb: return StructuredPipeline_pb(pipeline={"stages":[s._to_pb() for s in self.stages]}) + def _append(self, new_stage): """ Create a new Pipeline object with a new stage appended """ - return self.__class__((*self.stages, new_stage)) + return self.__class__(*self.stages, new_stage) def add_fields(self, *fields: Selectable) -> Self: """ diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index fc0ee428d..ca2ced8b8 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -54,6 +54,9 @@ def __init__(self, client:Client, *stages: stages.Stage): super().__init__(*stages) self._client = client + def _append(self, new_stage): + return self.__class__(self._client, *self.stages, new_stage) + def execute(self) -> Iterable["ExecutePipelineResponse"]: database_name = f"projects/{self._client.project}/databases/{self._client._database}" request = 
ExecutePipelineRequest( From 863a09d52cd8d3474b1a8b588a07d177316c3b97 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 1 Apr 2025 16:58:02 -0700 Subject: [PATCH 086/131] compare result data in tests --- tests/system/test_pipeline_acceptance.py | 5 ++++- tests/system/test_system.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 42fcd412e..41ea0cf2a 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -132,6 +132,7 @@ def parse_expressions(client, yaml_element: Any): def test_e2e_scenario(test_dict, client): error_regex = test_dict.get("assert_error", None) expected_proto = test_dict.get("assert_proto", None) + expected_results = test_dict.get("assert_results", None) pipeline = parse_pipeline(client, test_dict["pipeline"]) # check if proto matches as expected if expected_proto: @@ -139,7 +140,9 @@ def test_e2e_scenario(test_dict, client): assert yaml.dump(expected_proto) == yaml.dump(got_proto) # check if server responds as expected with pytest.raises(GoogleAPIError) if error_regex else nullcontext() as ctx: - pipeline.execute() + got_results = [snapshot.to_dict() for snapshot in pipeline.execute()] + if expected_results: + assert got_results == expected_results # check for error message if expected if error_regex: found_error = str(ctx.value) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 48af63855..22a3663fd 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -92,7 +92,7 @@ def verify_pipeline(query): query_exception = None query_results = None try: - query_results = query.get() + query_results = [s.to_dict() for s in query.get()] except Exception as e: query_exception = e pipeline = query.pipeline() @@ -102,7 +102,7 @@ def verify_pipeline(query): pipeline.execute() else: # ensure results match query - pipeline_results = pipeline.execute() + pipeline_results = [s.to_dict() for s in pipeline.execute()] assert query_results == pipeline_results @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) From e9f11c620a40bb49d97cf8f6b8ea72a916e7cdc4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 1 Apr 2025 16:58:31 -0700 Subject: [PATCH 087/131] fixed incorrect test case ordering --- tests/system/pipeline_e2e.yaml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 7321b7c77..7ce7ba98a 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -365,24 +365,24 @@ tests: assert_results: - title: "The Hitchhiker's Guide to the Galaxy" author: "Douglas Adams" - - title: "Pride and Prejudice" - author: "Jane Austen" - - title: "The Handmaid's Tale" - author: "Margaret Atwood" - - title: "Crime and Punishment" - author: "Fyodor Dostoevsky" - title: "The Great Gatsby" author: "F. Scott Fitzgerald" - title: "Dune" author: "Frank Herbert" - - title: "To Kill a Mockingbird" - author: "Harper Lee" + - title: "Crime and Punishment" + author: "Fyodor Dostoevsky" - title: "One Hundred Years of Solitude" author: "Gabriel García Márquez" - title: "1984" author: "George Orwell" + - title: "To Kill a Mockingbird" + author: "Harper Lee" - title: "The Lord of the Rings" author: "J.R.R. 
Tolkien" + - title: "Pride and Prejudice" + author: "Jane Austen" + - title: "The Handmaid's Tale" + author: "Margaret Atwood" assert_proto: pipeline: stages: From d2c33f208db9ddd005753fd3b592cf9a2b3e68e6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 1 Apr 2025 17:08:19 -0700 Subject: [PATCH 088/131] yield document snapshots --- google/cloud/firestore_v1/async_pipeline.py | 6 ++++-- google/cloud/firestore_v1/base_pipeline.py | 14 ++++++++++++++ google/cloud/firestore_v1/pipeline.py | 6 +++--- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 31c63eb81..1fe0d16a3 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -60,11 +60,13 @@ def __init__(self, client:AsyncClient, *stages: stages.Stage): def _append(self, new_stage): return self.__class__(self._client, *self.stages, new_stage) - async def execute_async(self) -> AsyncIterable["ExecutePipelineResponse"]: + async def execute_async(self) -> AsyncIterable["DocumentSnapshot"]: database_name = f"projects/{self._client.project}/databases/{self._client._database}" request = ExecutePipelineRequest( database=database_name, structured_pipeline=self._to_pb(), read_time=datetime.datetime.now(), ) - return await self._client._firestore_api.execute_pipeline(request) + async for response in await self._client._firestore_api.execute_pipeline(request): + for snapshot in self._parse_response(response, self._client): + yield snapshot diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 6582f2789..99c238dfe 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -19,6 +19,7 @@ from google.cloud.firestore_v1.types.pipeline import StructuredPipeline as StructuredPipeline_pb from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure +from google.cloud.firestore_v1 import _helpers, document from google.cloud.firestore_v1.pipeline_expressions import ( Accumulator, Expr, @@ -65,6 +66,19 @@ def _append(self, new_stage): """ return self.__class__(*self.stages, new_stage) + @staticmethod + def _parse_response(response_pb, client): + for doc in response_pb.results: + data = _helpers.decode_dict(doc.fields, client) + yield document.DocumentSnapshot( + None, + data, + exists=True, + read_time=response_pb._pb.execution_time, + create_time=doc.create_time, + update_time=doc.update_time, + ) + def add_fields(self, *fields: Selectable) -> Self: """ Adds new fields to outputs from previous stages. 
diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py
index ca2ced8b8..d38b2086e 100644
--- a/google/cloud/firestore_v1/pipeline.py
+++ b/google/cloud/firestore_v1/pipeline.py
@@ -57,11 +57,11 @@ def __init__(self, client:Client, *stages: stages.Stage):
     def _append(self, new_stage):
         return self.__class__(self._client, *self.stages, new_stage)
 
-    def execute(self) -> Iterable["ExecutePipelineResponse"]:
+    def execute(self) -> Iterable["DocumentSnapshot"]:
         database_name = f"projects/{self._client.project}/databases/{self._client._database}"
         request = ExecutePipelineRequest(
             database=database_name,
             structured_pipeline=self._to_pb(),
         )
-        results = self._client._firestore_api.execute_pipeline(request)
-        return results
\ No newline at end of file
+        for response in self._client._firestore_api.execute_pipeline(request):
+            yield from self._parse_response(response, self._client)
\ No newline at end of file

From db1ecaa62bde0f18403937208a06b403ff5c7610 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 1 Apr 2025 17:16:06 -0700
Subject: [PATCH 089/131] added async test

---
 tests/system/test_pipeline_acceptance.py | 32 ++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py
index 41ea0cf2a..24ecd11d6 100644
--- a/tests/system/test_pipeline_acceptance.py
+++ b/tests/system/test_pipeline_acceptance.py
@@ -29,7 +29,7 @@
 from google.cloud.firestore_v1.pipeline import Pipeline
 
 from google.api_core.exceptions import GoogleAPIError
-from google.cloud.firestore import Client
+from google.cloud.firestore import Client, AsyncClient
 
 FIRESTORE_TEST_DB = os.environ.get("SYSTEM_TESTS_DATABASE", "system-tests-named-db")
 FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT")
@@ -68,6 +68,9 @@ def client():
             document_ref = collection_ref.document(document_id)
             document_ref.delete()
 
+@pytest.fixture
+def async_client(client):
+    yield AsyncClient(project=client.project, database=client._database)
 
 def _apply_yaml_args(cls, client, yaml_args):
     if isinstance(yaml_args, dict):
@@ -144,6 +147,31 @@ def test_e2e_scenario(test_dict, client):
         if expected_results:
             assert got_results == expected_results
     # check for error message if expected
+    if error_regex:
+        found_error = str(ctx.value)
+        match = re.search(error_regex, found_error)
+        assert match, f"error '{found_error}' does not match '{error_regex}'"
+
+
+@pytest.mark.parametrize(
+    "test_dict", yaml_loader(), ids=lambda x: f"{x.get('description', '')}"
+)
+@pytest.mark.asyncio
+async def test_e2e_scenario_async(test_dict, async_client):
+    error_regex = test_dict.get("assert_error", None)
+    expected_proto = test_dict.get("assert_proto", None)
+    expected_results = test_dict.get("assert_results", None)
+    pipeline = parse_pipeline(async_client, test_dict["pipeline"])
+    # check if proto matches as expected
+    if expected_proto:
+        got_proto = MessageToDict(pipeline._to_pb()._pb)
+        assert yaml.dump(expected_proto) == yaml.dump(got_proto)
+    # check if server responds as expected
+    with pytest.raises(GoogleAPIError) if error_regex else nullcontext() as ctx:
+        got_results = [snapshot.to_dict() async for snapshot in 
pipeline.execute_async()] + if expected_results: + assert got_results == expected_results + # check for error message if expected if error_regex: found_error = str(ctx.value) match = re.search(error_regex, found_error) From b6d74dafb90bda91fe600aab400c7f92b7d19b2e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 1 Apr 2025 17:26:46 -0700 Subject: [PATCH 090/131] cleaning up e2e yaml --- tests/system/pipeline_e2e.yaml | 48 +++++++--------------------------- 1 file changed, 9 insertions(+), 39 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 7ce7ba98a..7dc5df301 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -437,24 +437,24 @@ tests: assert_results: - author: Douglas Adams author_title: Douglas Adams_The Hitchhiker's Guide to the Galaxy - - author: Jane Austen - author_title: Jane Austen_Pride and Prejudice - - author: Margaret Atwood - author_title: Margaret Atwood_The Handmaid's Tale - - author: Fyodor Dostoevsky - author_title: Fyodor Dostoevsky_Crime and Punishment - author: F. Scott Fitzgerald author_title: F. Scott Fitzgerald_The Great Gatsby - author: Frank Herbert author_title: Frank Herbert_Dune - - author: Harper Lee - author_title: Harper Lee_To Kill a Mockingbird + - author: Fyodor Dostoevsky + author_title: Fyodor Dostoevsky_Crime and Punishment - author: Gabriel García Márquez author_title: Gabriel García Márquez_One Hundred Years of Solitude - author: George Orwell author_title: George Orwell_1984 + - author: Harper Lee + author_title: Harper Lee_To Kill a Mockingbird - author: J.R.R. Tolkien author_title: J.R.R. Tolkien_The Lord of the Rings + - author: Jane Austen + author_title: Jane Austen_Pride and Prejudice + - author: Margaret Atwood + author_title: Margaret Atwood_The Handmaid's Tale assert_proto: pipeline: stages: @@ -599,7 +599,7 @@ tests: - title - author assert_results: - - title: 1984 + - title: "1984" author: George Orwell - title: To Kill a Mockingbird author: Harper Lee @@ -642,36 +642,6 @@ tests: - Constant: comedy assert_results: - title: The Hitchhiker's Guide to the Galaxy - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: '%Guide%' - name: like - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - author: Douglas Adams - genre: Science Fiction - published: 1979 - rating: 4.2 - tags: - - comedy - - space - - adventure - awards: - hugo: true - nebula: false assert_proto: pipeline: stages: From 70efea765154fb98f84e3e93126e17b25f1cd8ef Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 2 Apr 2025 17:13:34 -0700 Subject: [PATCH 091/131] fixed scope in tests --- tests/system/test_pipeline_acceptance.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 24ecd11d6..6b58874dd 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -44,7 +44,17 @@ def yaml_loader(field="tests"): test_cases = yaml.safe_load(f) return test_cases[field] -@pytest.fixture +@pytest.fixture(scope="session") +def event_loop(): + import asyncio + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + yield loop + loop.close() + +@pytest.fixture(scope="module") def client(): client = 
Client(project=FIRESTORE_PROJECT, database=FIRESTORE_TEST_DB) data = yaml_loader("data") @@ -57,9 +67,7 @@ def client(): document_ref = collection_ref.document(document_id) batch.set(document_ref, document_data) batch.commit() - yield client - finally: # clear data for collection_name, documents in data.items(): @@ -68,7 +76,7 @@ def client(): document_ref = collection_ref.document(document_id) document_ref.delete() -@pytest.fixture +@pytest.fixture(scope="module") def async_client(client): yield AsyncClient(project=client.project, database=client._database) From b36afa8409a9e3bf84a4ca2604006d3dc8ed92b9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 2 Apr 2025 17:20:38 -0700 Subject: [PATCH 092/131] added _client to base pipeline --- google/cloud/firestore_v1/async_pipeline.py | 6 +----- google/cloud/firestore_v1/base_pipeline.py | 9 +++++++-- google/cloud/firestore_v1/pipeline.py | 6 +----- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 1fe0d16a3..8b745d2ba 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -54,11 +54,7 @@ def __init__(self, client:AsyncClient, *stages: stages.Stage): client: The asynchronous `AsyncClient` instance to use for execution. *stages: Initial stages for the pipeline. """ - super().__init__(*stages) - self._client = client - - def _append(self, new_stage): - return self.__class__(self._client, *self.stages, new_stage) + super().__init__(client, *stages) async def execute_async(self) -> AsyncIterable["DocumentSnapshot"]: database_name = f"projects/{self._client.project}/databases/{self._client._database}" diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 99c238dfe..2a2cbad25 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -16,6 +16,7 @@ from typing import Optional, Sequence from typing_extensions import Self from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1.base_client import BaseClient from google.cloud.firestore_v1.types.pipeline import StructuredPipeline as StructuredPipeline_pb from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure @@ -38,13 +39,17 @@ class _BasePipeline: This class is not intended to be instantiated directly. Use `client.collection.("...").pipeline()` to create pipeline instances. """ - def __init__(self, *stages: stages.Stage): + def __init__(self, client: BaseClient, *stages: stages.Stage): """ Initializes a new pipeline with the given stages. + Pipeline classes should not be instantiated directly. + Args: + client: The client associated with the pipeline *stages: Initial stages for the pipeline. 
""" + self._client = client self.stages = tuple(stages) def __repr__(self): @@ -64,7 +69,7 @@ def _append(self, new_stage): """ Create a new Pipeline object with a new stage appended """ - return self.__class__(*self.stages, new_stage) + return self.__class__(self._client, *self.stages, new_stage) @staticmethod def _parse_response(response_pb, client): diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index d38b2086e..f17a68ed7 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -51,11 +51,7 @@ def __init__(self, client:Client, *stages: stages.Stage): client: The `Client` instance to use for execution. *stages: Initial stages for the pipeline. """ - super().__init__(*stages) - self._client = client - - def _append(self, new_stage): - return self.__class__(self._client, *self.stages, new_stage) + super().__init__(client, *stages) def execute(self) -> Iterable["DocumentSnapshot"]: database_name = f"projects/{self._client.project}/databases/{self._client._database}" From 7a26f7eb8855eddfd031a03ecefd2cb0919cec16 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 2 Apr 2025 17:36:34 -0700 Subject: [PATCH 093/131] fixing test yaml --- tests/system/pipeline_e2e.yaml | 62 +++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 7dc5df301..38c9b5001 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -638,10 +638,18 @@ tests: - Collection: books - Where: - ArrayContains: - - Constant: tags + - Field: tags - Constant: comedy assert_results: - title: The Hitchhiker's Guide to the Galaxy + author: Douglas Adams + awards: + hugo: true + nebula: false + genre: Science Fiction + published: 1979 + rating: 4.2 + tags: ["comedy", "space", "adventure"] assert_proto: pipeline: stages: @@ -651,7 +659,7 @@ tests: - args: - functionValue: args: - - stringValue: tags + - fieldReferenceValue: tags - stringValue: comedy name: array_contains name: where @@ -846,7 +854,7 @@ tests: - Where: - StartsWith: - Field: title - - Constant: the + - Constant: The - Select: - title - Sort: @@ -868,7 +876,7 @@ tests: - functionValue: args: - fieldReferenceValue: title - - stringValue: the + - stringValue: The name: starts_with name: where - args: @@ -1089,23 +1097,30 @@ tests: - description: testStringFunctions - CharLength pipeline: - Collection: books + - Where: + - Eq: + - Field: author + - Constant: "Douglas Adams" - Select: - ExprWithAlias: - CharLength: - Field: title - "title_length" - - Where: - - Eq: - - Field: author - - Constant: Douglas Adams assert_results: - - title_length: 30 + - title_length: 36 assert_proto: pipeline: stages: - args: - referenceValue: /books name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where - args: - mapValue: fields: @@ -1115,16 +1130,13 @@ tests: - fieldReferenceValue: title name: char_length name: select - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: eq - name: where - description: testStringFunctions - ByteLength pipeline: - Collection: books + - Where: + - Eq: + - Field: author + - Constant: Douglas Adams - Select: - ExprWithAlias: - ByteLength: @@ -1132,10 +1144,6 @@ tests: - Field: title - Constant: _银河系漫游指南 - "title_byte_length" - - Where: - - Eq: - - Field: author - - Constant: Douglas Adams 
assert_results: - title_byte_length: 42 assert_proto: @@ -1144,6 +1152,13 @@ tests: - args: - referenceValue: /books name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where - args: - mapValue: fields: @@ -1157,13 +1172,6 @@ tests: name: str_concat name: byte_length name: select - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: eq - name: where - description: testToLowercase pipeline: - Collection: books From fee6c90653e5a78bc21d8973b22751043dffa14c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 3 Apr 2025 16:21:17 -0700 Subject: [PATCH 094/131] renamed exceute_async to execute --- google/cloud/firestore_v1/async_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 8b745d2ba..07fe3ddbf 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -56,7 +56,7 @@ def __init__(self, client:AsyncClient, *stages: stages.Stage): """ super().__init__(client, *stages) - async def execute_async(self) -> AsyncIterable["DocumentSnapshot"]: + async def execute(self) -> AsyncIterable["DocumentSnapshot"]: database_name = f"projects/{self._client.project}/databases/{self._client._database}" request = ExecutePipelineRequest( database=database_name, From da18289bf3c4225846bc7ad2cbed913897ff907d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 3 Apr 2025 16:22:23 -0700 Subject: [PATCH 095/131] broke up pipeline tests into separate functions --- tests/system/test_pipeline_acceptance.py | 106 ++++++++++++++++------- 1 file changed, 76 insertions(+), 30 deletions(-) diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 6b58874dd..644cb6dc6 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -138,49 +138,95 @@ def parse_expressions(client, yaml_element: Any): return yaml_element @pytest.mark.parametrize( - "test_dict", yaml_loader(), ids=lambda x: f"{x.get('description', '')}" + "test_dict", + [t for t in yaml_loader() if "assert_proto" in t], + ids=lambda x: f"{x.get('description', '')}" ) -def test_e2e_scenario(test_dict, client): - error_regex = test_dict.get("assert_error", None) +def test_pipeline_parse_proto(test_dict, client): + """ + Finds assert_proto statements in yaml, and compares generated proto against expected value + """ expected_proto = test_dict.get("assert_proto", None) - expected_results = test_dict.get("assert_results", None) pipeline = parse_pipeline(client, test_dict["pipeline"]) # check if proto matches as expected if expected_proto: got_proto = MessageToDict(pipeline._to_pb()._pb) assert yaml.dump(expected_proto) == yaml.dump(got_proto) + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_error" in t], + ids=lambda x: f"{x.get('description', '')}" +) +def test_pipeline_expected_errors(test_dict, client): + """ + Finds assert_error statements in yaml, and ensures the pipeline raises the expected error + """ + error_regex = test_dict["assert_error"] + pipeline = parse_pipeline(client, test_dict["pipeline"]) # check if server responds as expected - with pytest.raises(GoogleAPIError) if error_regex else nullcontext() as ctx: - got_results = [snapshot.to_dict() for snapshot in pipeline.execute()] - if expected_results: - assert got_results 
== expected_results - # check for error message if expected - if error_regex: - found_error = str(ctx.value) - match = re.search(error_regex, found_error) - assert match, f"error '{found_error}' does not match '{error_regex}'" + with pytest.raises(GoogleAPIError) as err: + [_ for _ in pipeline.execute()] + found_error = str(err.value) + match = re.search(error_regex, found_error) + assert match, f"error '{found_error}' does not match '{error_regex}'" @pytest.mark.parametrize( - "test_dict", yaml_loader(), ids=lambda x: f"{x.get('description', '')}" + "test_dict", + [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], + ids=lambda x: f"{x.get('description', '')}" +) +def test_pipeline_results(test_dict, client): + """ + Ensure pipeline returns expected results + """ + expected_results = test_dict.get("assert_results", None) + expected_count = test_dict.get("assert_count", None) + pipeline = parse_pipeline(client, test_dict["pipeline"]) + # check if server responds as expected + got_results = [snapshot.to_dict() for snapshot in pipeline.execute()] + if expected_results: + assert got_results == expected_results + if expected_count is not None: + assert len(got_results) == expected_count + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_error" in t], + ids=lambda x: f"{x.get('description', '')}" ) @pytest.mark.asyncio -async def test_e2e_scenario_async(test_dict, async_client): - error_regex = test_dict.get("assert_error", None) - expected_proto = test_dict.get("assert_proto", None) +async def test_pipeline_expected_errors_async(test_dict, async_client): + """ + Finds assert_error statements in yaml, and ensures the pipeline raises the expected error + """ + error_regex = test_dict["assert_error"] + pipeline = parse_pipeline(async_client, test_dict["pipeline"]) + # check if server responds as expected + with pytest.raises(GoogleAPIError) as err: + [_ async for _ in pipeline.execute()] + found_error = str(err.value) + match = re.search(error_regex, found_error) + assert match, f"error '{found_error}' does not match '{error_regex}'" + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], + ids=lambda x: f"{x.get('description', '')}" +) +@pytest.mark.asyncio +async def test_pipeline_results_async(test_dict, async_client): + """ + Ensure pipeline returns expected results + """ expected_results = test_dict.get("assert_results", None) + expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(async_client, test_dict["pipeline"]) - # check if proto matches as expected - if expected_proto: - got_proto = MessageToDict(pipeline._to_pb()._pb) - assert yaml.dump(expected_proto) == yaml.dump(got_proto) # check if server responds as expected - with pytest.raises(GoogleAPIError) if error_regex else nullcontext() as ctx: - got_results = [snapshot.to_dict() async for snapshot in pipeline.execute_async()] - if expected_results: - assert got_results == expected_results - # check for error message if expected - if error_regex: - found_error = str(ctx.value) - match = re.search(error_regex, found_error) - assert match, f"error '{found_error}' does not match '{error_regex}'" \ No newline at end of file + got_results = [snapshot.to_dict() async for snapshot in pipeline.execute()] + if expected_results: + assert got_results == expected_results + if expected_count is not None: + assert len(got_results) == expected_count + From fc80062af3ff9c13dfcbac8e397a041cc84c5e2a Mon Sep 
17 00:00:00 2001 From: Daniel Sanche Date: Tue, 29 Apr 2025 17:15:27 -0700 Subject: [PATCH 096/131] improved faulty test yaml --- tests/system/pipeline_e2e.yaml | 295 +++++++++++++++++++-------------- 1 file changed, 174 insertions(+), 121 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 38c9b5001..38a1928c5 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -269,13 +269,17 @@ tests: - Gt: - Field: avg_rating - Constant: 4.3 + - Sort: + - Ordering: + - Field: avg_rating + - ASCENDING assert_results: - - avg_rating: 4.7 - genre: Fantasy - - avg_rating: 4.5 - genre: Romance - avg_rating: 4.4 genre: Science Fiction + - avg_rating: 4.5 + genre: Romance + - avg_rating: 4.7 + genre: Fantasy assert_proto: pipeline: stages: @@ -309,6 +313,14 @@ tests: - doubleValue: 4.3 name: gt name: where + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: avg_rating + name: sort - description: testMinMax pipeline: - Collection: books @@ -555,10 +567,14 @@ tests: - Constant: Dystopian - Select: - title + - Sort: + - Ordering: + - Field: title + - ASCENDING assert_results: + - title: "1984" - title: Pride and Prejudice - title: The Handmaid's Tale - - title: 1984 assert_proto: pipeline: stages: @@ -586,6 +602,14 @@ tests: title: fieldReferenceValue: title name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort - description: testPipelineWithOffsetAndLimit pipeline: - Collection: books @@ -673,9 +697,13 @@ tests: - Constant: classic - Select: - title + - Sort: + - Ordering: + - Field: title + - ASCENDING assert_results: - - title: The Hitchhiker's Guide to the Galaxy - title: Pride and Prejudice + - title: The Hitchhiker's Guide to the Galaxy assert_proto: pipeline: stages: @@ -698,6 +726,14 @@ tests: title: fieldReferenceValue: title name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort - description: testArrayContainsAll pipeline: - Collection: books @@ -818,6 +854,10 @@ tests: - description: testStrConcat pipeline: - Collection: books + - Sort: + - Ordering: + - Field: author + - ASCENDING - Select: - ExprWithAlias: - StrConcat: @@ -834,6 +874,14 @@ tests: - args: - referenceValue: /books name: collection + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort - args: - mapValue: fields: @@ -949,11 +997,19 @@ tests: - Gt: - Field: titleLength - Constant: 20 + - Sort: + - Ordering: + - Field: title + - ASCENDING assert_results: - - titleLength: 32 - title: The Hitchhiker's Guide to the Galaxy - - titleLength: 27 + - titleLength: 29 title: One Hundred Years of Solitude + - titleLength: 36 + title: The Hitchhiker's Guide to the Galaxy + - titleLength: 21 + title: The Lord of the Rings + - titleLength: 21 + title: To Kill a Mockingbird assert_proto: pipeline: stages: @@ -978,6 +1034,14 @@ tests: - integerValue: '20' name: gt name: where + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort - description: testStringFunctions - Reverse pipeline: - Collection: books @@ -1145,7 +1209,7 @@ tests: - Constant: _银河系漫游指南 - "title_byte_length" assert_results: - - title_byte_length: 42 + - title_byte_length: 58 assert_proto: pipeline: stages: @@ -1293,18 +1357,14 @@ tests: 
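The recurring fix in this yaml is to append an explicit Sort stage so the asserted results no longer depend on server return order. A rough Python equivalent of the sorted pipelines above (a sketch; `client` and the `books` seed data are assumed from the test fixtures):

from google.cloud.firestore_v1.pipeline_expressions import Field

# Same shape as the yaml: Collection -> Select -> Sort, so the
# asserted result order is deterministic.
pipeline = (
    client.collection("books")
    .pipeline()
    .select("title")
    .sort(Field.of("title").ascending())
)
results = [snapshot.to_dict() for snapshot in pipeline.execute()]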
assert_results: - title: The Hitchhiker's Guide to the Galaxy - description: testRegexContains + # Find titles that contain either "the" or "of" (case-insensitive) pipeline: - Collection: books - Where: - RegexContains: - Field: title - Constant: "(?i)(the|of)" - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - - title: One Hundred Years of Solitude - - title: The Lord of the Rings - - title: To Kill a Mockingbird - - title: The Great Gatsby + assert_count: 5 assert_proto: pipeline: stages: @@ -1315,9 +1375,18 @@ tests: - functionValue: args: - fieldReferenceValue: title - - stringValue: .*(?i)(the|of).* - name: regex_match + - stringValue: "(?i)(the|of)" + name: regex_contains name: where + - description: testRegexMatches + # Find titles that contain either "the" or "of" (case-insensitive) + pipeline: + - Collection: books + - Where: + - RegexMatch: + - Field: title + - Constant: ".*(?i)(the|of).*" + assert_count: 5 assert_proto: pipeline: stages: @@ -1328,25 +1397,16 @@ tests: - functionValue: args: - fieldReferenceValue: title - - stringValue: (?i)(the|of) - name: regex_contains + - stringValue: "(?i)(the|of)" + name: regex_matches name: where - - description: testRegexMatches - pipeline: - - Collection: books - - Where: - - RegexMatch: - - Field: title - - Constant: ".*(?i)(the|of).*" - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - - title: One Hundred Years of Solitude - - title: The Lord of the Rings - - title: To Kill a Mockingbird - - title: The Great Gatsby - description: testArithmeticOperations pipeline: - Collection: books + - Where: + - Eq: + - Field: title + - Constant: To Kill a Mockingbird - Select: - ExprWithAlias: - Add: @@ -1368,18 +1428,42 @@ tests: - Field: rating - Constant: 2 - "ratingDividedByTwo" - - Limit: 1 + - ExprWithAlias: + - Multiply: + - Field: rating + - Constant: 20 + - "ratingTimes20" + - ExprWithAlias: + - Add: + - Field: rating + - Constant: 3 + - "ratingPlus3" + - ExprWithAlias: + - Mod: + - Field: rating + - Constant: 2 + - "ratingMod2" assert_results: - ratingPlusOne: 5.2 - yearsSince1900: 79 + yearsSince1900: 60 ratingTimesTen: 42.0 ratingDividedByTwo: 2.1 + ratingTimes20: 84 + ratingPlus3: 7.2 + ratingMod2: 0.20000000000000018 assert_proto: pipeline: stages: - args: - referenceValue: /books name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: To Kill a Mockingbird + name: eq + name: where - args: - mapValue: fields: @@ -1407,10 +1491,25 @@ tests: - fieldReferenceValue: published - integerValue: '1900' name: subtract + ratingTimes20: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '20' + name: multiply + ratingPlus3: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '3' + name: add + ratingMod2: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: mod name: select - - args: - - integerValue: '1' - name: limit - description: testComparisonOperators pipeline: - Collection: books @@ -1657,48 +1756,13 @@ tests: - doubleValue: 4.5 name: logical_maximum name: select - - description: testLogicalMinMax - min - pipeline: - - Collection: books - - Select: - - ExprWithAlias: - - LogicalMin: - - Field: rating - - Constant: 4.5 - - "min_rating" - - ExprWithAlias: - - LogicalMin: - - Field: published - - Constant: 1900 - - "min_published" - assert_results: - - min_rating: 4.2 - min_published: 1900 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - 
args: - - mapValue: - fields: - min_published: - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1900' - name: logical_minimum - min_rating: - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.5 - name: logical_minimum - name: select - description: testMapGet pipeline: - Collection: books + - Sort: + - Ordering: + - Field: published + - DESCENDING - Select: - ExprWithAlias: - MapGet: @@ -1721,6 +1785,14 @@ tests: - args: - referenceValue: /books name: collection + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: published + name: sort - args: - mapValue: fields: @@ -1820,6 +1892,10 @@ tests: - Eq: - Field: awards.hugo - Constant: true + - Sort: + - Ordering: + - Field: title + - DESCENDING - Select: - title - Field: awards.hugo @@ -1844,26 +1920,11 @@ tests: - args: - mapValue: fields: - __name__: - fieldReferenceValue: __name__ - awards.hugo: - fieldReferenceValue: awards.hugo - title: + direction: + stringValue: descending + expression: fieldReferenceValue: title - name: select - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: awards.hugo - - booleanValue: true - name: eq - name: where + name: sort - args: - mapValue: fields: @@ -1872,22 +1933,6 @@ tests: title: fieldReferenceValue: title name: select - - description: testPipelineInTransactions - pipeline: - - Collection: books - - Where: - - Eq: - - Field: awards.hugo - - Constant: true - - Select: - - title - - Field: awards.hugo - - Field: "__name__" - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - awards.hugo: true - - title: Dune - awards.hugo: true - description: testReplace pipeline: - Collection: books @@ -1947,7 +1992,6 @@ tests: - SampleOptions: - 0.6 - percent - assert_count: 6 # Results will vary due to randomness assert_proto: pipeline: stages: @@ -1982,14 +2026,17 @@ tests: pipeline: - Collection: books - Where: - - Eq: - - Field: title - - Constant: The Hitchhiker's Guide to the Galaxy - - Unnest: tags + - Eq: + - Field: title + - Constant: The Hitchhiker's Guide to the Galaxy + - Unnest: + - tags + - tags_alias + - Select: tags_alias assert_results: - - tags: comedy - - tags: space - - tags: adventure + - tags_alias: comedy + - tags_alias: space + - tags_alias: adventure assert_proto: pipeline: stages: @@ -2005,5 +2052,11 @@ tests: name: where - args: - fieldReferenceValue: tags - - fieldReferenceValue: tags + - fieldReferenceValue: tags_alias name: unnest + - args: + - mapValue: + fields: + tags_alias: + fieldReferenceValue: tags_alias + name: select From c049e2107e87c6a8c5a0a693136efde3992c8832 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 2 May 2025 14:09:33 -0700 Subject: [PATCH 097/131] remvoved unready stages and expressions --- google/cloud/firestore_v1/base_pipeline.py | 48 -- .../firestore_v1/pipeline_expressions.py | 212 --------- google/cloud/firestore_v1/pipeline_stages.py | 16 - tests/system/pipeline_e2e.yaml | 424 +----------------- 4 files changed, 1 insertion(+), 699 deletions(-) diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 2a2cbad25..56a8b2d59 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -290,54 +290,6 @@ def sort(self, *orders: stages.Ordering) -> Self: """ return self._append(stages.Sort(*orders)) - def replace( - self, - field: 
Selectable, - mode: stages.Replace.Mode = stages.Replace.Mode.FULL_REPLACE, - ) -> Self: - """ - Replaces the entire document content with the value of a specified field, - typically a map. - - This stage allows you to emit a map value as the new document structure. - Each key of the map becomes a field in the output document, containing the - corresponding value. - - Example: - Input document: - ```json - { - "name": "John Doe Jr.", - "parents": { - "father": "John Doe Sr.", - "mother": "Jane Doe" - } - } - ``` - - >>> from google.cloud.firestore_v1.pipeline_expressions import Field - >>> pipeline = client.collection("people").pipeline() - >>> # Emit the 'parents' map as the document - >>> pipeline = pipeline.replace(Field.of("parents")) - - Output document: - ```json - { - "father": "John Doe Sr.", - "mother": "Jane Doe" - } - ``` - - Args: - field: The `Selectable` field containing the map whose content will - replace the document. - mode: The replacement mode - - Returns: - A new Pipeline object with this stage appended to the stage list - """ - return self._append(stages.Replace(field, mode)) - def sample(self, limit_or_options: int | SampleOptions) -> Self: """ Performs a pseudo-random sampling of the documents from the previous stage. diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 7acd764b3..c26b2a4cb 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -398,21 +398,6 @@ def not_in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "Not": """ return Not(self.in_any(array)) - def array_concat(self, array: List[Expr | CONSTANT_TYPE]) -> "ArrayConcat": - """Creates an expression that concatenates an array expression with another array. - - Example: - >>> # Combine the 'tags' array with a new array and an array field - >>> Field.of("tags").array_concat(["newTag1", "newTag2", Field.of("otherTag")]) - - Args: - array: The list of constants or expressions to concat with. - - Returns: - A new `Expr` representing the concatenated array. - """ - return ArrayConcat(self, [self._cast_to_expr_or_convert_to_constant(o) for o in array]) - def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": """Creates an expression that checks if an array contains a specific element or value. @@ -717,92 +702,6 @@ def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": """ return StrConcat(*[self._cast_to_expr_or_convert_to_constant(el) for el in elements]) - def to_lower(self) -> "ToLower": - """Creates an expression that converts a string to lowercase. - - Example: - >>> # Convert the 'name' field to lowercase - >>> Field.of("name").to_lower() - - Returns: - A new `Expr` representing the lowercase string. - """ - return ToLower(self) - - def to_upper(self) -> "ToUpper": - """Creates an expression that converts a string to uppercase. - - Example: - >>> # Convert the 'title' field to uppercase - >>> Field.of("title").to_upper() - - Returns: - A new `Expr` representing the uppercase string. - """ - return ToUpper(self) - - def trim(self) -> "Trim": - """Creates an expression that removes leading and trailing whitespace from a string. - - Example: - >>> # Trim whitespace from the 'userInput' field - >>> Field.of("userInput").trim() - - Returns: - A new `Expr` representing the trimmed string. - """ - return Trim(self) - - def reverse(self) -> "Reverse": - """Creates an expression that reverses a string. 
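For the `sample` stage that survives this cleanup, usage looks roughly like the following (a sketch; assumes a `client` and a seeded `books` collection):

from google.cloud.firestore_v1.pipeline_expressions import SampleOptions

# Keep roughly 60% of the input documents, chosen pseudo-randomly...
sampled = client.collection("books").pipeline().sample(SampleOptions.percentage(0.6))
# ...or cap the sample at a fixed number of documents.
sampled = client.collection("books").pipeline().sample(SampleOptions.doc_limit(10))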
- - Example: - >>> # Reverse the 'userInput' field - >>> Field.of("userInput").reverse() - - Returns: - A new `Expr` representing the reversed string. - """ - return Reverse(self) - - def replace_first(self, find: Expr | str, replace: Expr | str) -> "ReplaceFirst": - """Creates an expression that replaces the first occurrence of a substring within a string with - another substring. - - Example: - >>> # Replace the first occurrence of "hello" with "hi" in the 'message' field - >>> Field.of("message").replace_first("hello", "hi") - >>> # Replace the first occurrence of the value in 'findField' with the value in 'replaceField' in the 'message' field - >>> Field.of("message").replace_first(Field.of("findField"), Field.of("replaceField")) - - Args: - find: The substring (string or expression) to search for. - replace: The substring (string or expression) to replace the first occurrence of 'find' with. - - Returns: - A new `Expr` representing the string with the first occurrence replaced. - """ - return ReplaceFirst(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) - - def replace_all(self, find: Expr | str, replace: Expr | str) -> "ReplaceAll": - """Creates an expression that replaces all occurrences of a substring within a string with another - substring. - - Example: - >>> # Replace all occurrences of "hello" with "hi" in the 'message' field - >>> Field.of("message").replace_all("hello", "hi") - >>> # Replace all occurrences of the value in 'findField' with the value in 'replaceField' in the 'message' field - >>> Field.of("message").replace_all(Field.of("findField"), Field.of("replaceField")) - - Args: - find: The substring (string or expression) to search for. - replace: The substring (string or expression) to replace all occurrences of 'find' with. - - Returns: - A new `Expr` representing the string with all occurrences replaced. - """ - return ReplaceAll(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) - def map_get(self, key: str) -> "MapGet": """Accesses a value from a map (object) field using the provided key. @@ -819,57 +718,6 @@ def map_get(self, key: str) -> "MapGet": """ return MapGet(self, Constant.of(key)) - def cosine_distance(self, other: Expr | list[float] | Vector) -> "CosineDistance": - """Calculates the cosine distance between two vectors. - - Example: - >>> # Calculate the cosine distance between the 'userVector' field and the 'itemVector' field - >>> Field.of("userVector").cosine_distance(Field.of("itemVector")) - >>> # Calculate the Cosine distance between the 'location' field and a target location - >>> Field.of("location").cosine_distance([37.7749, -122.4194]) - - Args: - other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. - - Returns: - A new `Expr` representing the cosine distance between the two vectors. - """ - return CosineDistance(self, self._cast_to_expr_or_convert_to_constant(other)) - - def euclidean_distance(self, other: Expr | list[float] | Vector) -> "EuclideanDistance": - """Calculates the Euclidean distance between two vectors. 
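The retained `map_get` helper reads a single key out of a map field; used inside a select stage it looks roughly like this (a sketch; the `users` collection and `address` map field are illustrative only):

from google.cloud.firestore_v1.pipeline_expressions import Field

# Project the nested 'address.city' value out of a map field, under an alias.
city = Field.of("address").map_get("city").as_("city")
pipeline = client.collection("users").pipeline().select(city)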
- - Example: - >>> # Calculate the Euclidean distance between the 'location' field and a target location - >>> Field.of("location").euclidean_distance([37.7749, -122.4194]) - >>> # Calculate the Euclidean distance between two vector fields: 'pointA' and 'pointB' - >>> Field.of("pointA").euclidean_distance(Field.of("pointB")) - - Args: - other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. - - Returns: - A new `Expr` representing the Euclidean distance between the two vectors. - """ - return EuclideanDistance(self, self._cast_to_expr_or_convert_to_constant(other)) - - def dot_product(self, other: Expr | list[float] | Vector) -> "DotProduct": - """Calculates the dot product between two vectors. - - Example: - >>> # Calculate the dot product between a feature vector and a target vector - >>> Field.of("features").dot_product([0.5, 0.8, 0.2]) - >>> # Calculate the dot product between two document vectors: 'docVector1' and 'docVector2' - >>> Field.of("docVector1").dot_product(Field.of("docVector2")) - - Args: - other: The other vector (represented as an Expr, list of floats, or Vector) to calculate dot product with. - - Returns: - A new `Expr` representing the dot product between the two vectors. - """ - return DotProduct(self, self._cast_to_expr_or_convert_to_constant(other)) - def vector_length(self) -> "VectorLength": """Creates an expression that calculates the length (dimension) of a Firestore Vector. @@ -1100,18 +948,6 @@ def __init__(self, left: Expr, right: Expr): super().__init__("divide", [left, right]) -class DotProduct(Function): - """Represents the vector dot product function.""" - def __init__(self, vector1: Expr, vector2: Expr): - super().__init__("dot_product", [vector1, vector2]) - - -class EuclideanDistance(Function): - """Represents the vector Euclidean distance function.""" - def __init__(self, vector1: Expr, vector2: Expr): - super().__init__("euclidean_distance", [vector1, vector2]) - - class LogicalMax(Function): """Represents the logical maximum function based on Firestore type ordering.""" def __init__(self, left: Expr, right: Expr): @@ -1148,24 +984,6 @@ def __init__(self, value: Expr): super().__init__("parent", [value]) -class ReplaceAll(Function): - """Represents replacing all occurrences of a substring.""" - def __init__(self, value: Expr, pattern: Expr, replacement: Expr): - super().__init__("replace_all", [value, pattern, replacement]) - - -class ReplaceFirst(Function): - """Represents replacing the first occurrence of a substring.""" - def __init__(self, value: Expr, pattern: Expr, replacement: Expr): - super().__init__("replace_first", [value, pattern, replacement]) - - -class Reverse(Function): - """Represents reversing a string.""" - def __init__(self, expr: Expr): - super().__init__("reverse", [expr]) - - class StrConcat(Function): """Represents concatenating multiple strings.""" def __init__(self, *exprs: Expr): @@ -1208,24 +1026,6 @@ def __init__(self, input: Expr): super().__init__("timestamp_to_unix_seconds", [input]) -class ToLower(Function): - """Represents converting a string to lowercase.""" - def __init__(self, value: Expr): - super().__init__("to_lower", [value]) - - -class ToUpper(Function): - """Represents converting a string to uppercase.""" - def __init__(self, value: Expr): - super().__init__("to_upper", [value]) - - -class Trim(Function): - """Represents trimming whitespace from a string.""" - def __init__(self, expr: Expr): - super().__init__("trim", [expr]) - - class UnixMicrosToTimestamp(Function): 
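The timestamp conversion functions also survive this cleanup; converting a timestamp field to an epoch value could look like this (a sketch; the `created_at` field name is illustrative):

from google.cloud.firestore_v1.pipeline_expressions import Field, TimestampToUnixSeconds

# Emit 'created_at' as seconds since the Unix epoch, under an alias.
epoch = TimestampToUnixSeconds(Field.of("created_at")).as_("created_epoch")
pipeline = client.collection("books").pipeline().select(epoch)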
"""Represents converting microseconds since epoch to a timestamp.""" def __init__(self, input: Expr): @@ -1256,12 +1056,6 @@ def __init__(self, left: Expr, right: Expr): super().__init__("add", [left, right]) -class ArrayConcat(Function): - """Represents concatenating multiple arrays.""" - def __init__(self, array: Expr, rest: List[Expr]): - super().__init__("array_concat", [array] + rest) - - class ArrayElement(Function): """Represents accessing an element within an array""" def __init__(self): @@ -1310,12 +1104,6 @@ def __init__(self, value: Expr): super().__init__("collection_id", [value]) -class CosineDistance(Function): - """Represents the vector cosine distance function.""" - def __init__(self, vector1: Expr, vector2: Expr): - super().__init__("cosine_distance", [vector1, vector2]) - - class Accumulator(Function): """A base class for aggregation functions that operate across multiple inputs.""" diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 14b3359be..459aa13b6 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -249,22 +249,6 @@ def _pb_args(self) -> list[Value]: return [f._to_pb() for f in self.fields] -class Replace(Stage): - """Replaces the document content with the value of a specified field.""" - class Mode(Enum): - FULL_REPLACE = "full_replace" - MERGE_PREFER_NEXT = "merge_prefer_nest" - MERGE_PREFER_PARENT = "merge_prefer_parent" - - def __init__(self, field: Selectable | str, mode: Mode | str = Mode.FULL_REPLACE): - super().__init__() - self.field = Field(field) if isinstance(field, str) else field - self.mode = self.Mode[mode] if isinstance(mode, str) else mode - - def _pb_args(self): - return [self.field._to_pb(), Value(string_value=self.mode.value)] - - class Sample(Stage): """Performs pseudo-random sampling of documents.""" def __init__(self, limit_or_options: int | SampleOptions): diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 38a1928c5..dc0a9ab9c 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -214,43 +214,6 @@ tests: accumulators: [] groups: [genre] assert_error: ".* requires at least one accumulator" - - description: testDistinct - pipeline: - - Collection: books - - Where: - - Lt: - - Field: published - - Constant: 1900 - - Distinct: - - ExprWithAlias: - - ToLower: - - Field: genre - - "lower_genre" - assert_results: - - lower_genre: romance - - lower_genre: psychological thriller - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1900' - name: lt - name: where - - args: - - mapValue: - fields: - lower_genre: - functionValue: - args: - - fieldReferenceValue: genre - name: to_lower - name: distinct - description: testGroupBysAndAggregate pipeline: - Collection: books @@ -813,44 +776,6 @@ tests: - integerValue: '3' name: eq name: where - - description: testArrayConcat - pipeline: - - Collection: books - - Select: - - ExprWithAlias: - - ArrayConcat: - - Field: tags - - - Constant: newTag1 - - Constant: newTag2 - - "modifiedTags" - - Limit: 1 - assert_results: - - modifiedTags: - - comedy - - space - - adventure - - newTag1 - - newTag2 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - modifiedTags: - functionValue: - args: - - fieldReferenceValue: tags - - 
stringValue: newTag1 - - stringValue: newTag2 - name: array_concat - name: select - - args: - - integerValue: '1' - name: limit - description: testStrConcat pipeline: - Collection: books @@ -1042,122 +967,6 @@ tests: expression: fieldReferenceValue: title name: sort - - description: testStringFunctions - Reverse - pipeline: - - Collection: books - - Select: - - ExprWithAlias: - - Reverse: - - Field: title - - "reversed_title" - - Where: - - Eq: - - Field: author - - Constant: Douglas Adams - assert_results: - - reversed_title: yxalaG ot ediug s'reknhiHcH ehT - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - reversed_title: - functionValue: - args: - - fieldReferenceValue: title - name: reverse - name: select - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: eq - name: where - - description: testStringFunctions - ReplaceFirst - pipeline: - - Collection: books - - Select: - - ExprWithAlias: - - ReplaceFirst: - - Field: title - - Constant: The - - Constant: A - - "replaced_title" - - Where: - - Eq: - - Field: author - - Constant: Douglas Adams - assert_results: - - replaced_title: A Hitchhiker's Guide to the Galaxy - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - replaced_title: - functionValue: - args: - - fieldReferenceValue: title - - stringValue: The - - stringValue: A - name: replace_first - name: select - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: eq - name: where - - description: testStringFunctions - ReplaceAll - pipeline: - - Collection: books - - Select: - - ExprWithAlias: - - ReplaceAll: - - Field: title - - Constant: " " - - Constant: "_" - - "replaced_title" - - Where: - - Eq: - - Field: author - - Constant: Douglas Adams - assert_results: - - replaced_title: The_Hitchhiker's_Guide_to_the_Galaxy - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - replaced_title: - functionValue: - args: - - fieldReferenceValue: title - - stringValue: ' ' - - stringValue: _ - name: replace_all - name: select - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: eq - name: where - description: testStringFunctions - CharLength pipeline: - Collection: books @@ -1236,115 +1045,6 @@ tests: name: str_concat name: byte_length name: select - - description: testToLowercase - pipeline: - - Collection: books - - Select: - - ExprWithAlias: - - ToLower: - - Field: title - - "lowercaseTitle" - - Limit: 1 - assert_results: - - lowercaseTitle: the hitchhiker's guide to the galaxy - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - lowercaseTitle: - functionValue: - args: - - fieldReferenceValue: title - name: to_lower - name: select - - args: - - integerValue: '1' - name: limit - - description: testToUppercase - pipeline: - - Collection: books - - Select: - - ExprWithAlias: - - ToUpper: - - Field: author - - "uppercaseAuthor" - - Limit: 1 - assert_results: - - uppercaseAuthor: DOUGLAS ADAMS - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - uppercaseAuthor: - functionValue: - args: - - fieldReferenceValue: author - name: to_upper - name: select - - 
args: - - integerValue: '1' - name: limit - - description: testTrim - pipeline: - - Collection: books - - AddFields: - - ExprWithAlias: - - StrConcat: - - Constant: " " - - Field: title - - Constant: " " - - "spacedTitle" - - Select: - - ExprWithAlias: - - Trim: - - Field: spacedTitle - - "trimmedTitle" - - spacedTitle - - Limit: 1 - assert_results: - - trimmedTitle: The Hitchhiker's Guide to the Galaxy - spacedTitle: " The Hitchhiker's Guide to the Galaxy " - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - spacedTitle: - functionValue: - args: - - stringValue: ' ' - - fieldReferenceValue: title - - stringValue: ' ' - name: str_concat - name: add_fields - - args: - - mapValue: - fields: - spacedTitle: - fieldReferenceValue: spacedTitle - trimmedTitle: - functionValue: - args: - - fieldReferenceValue: spacedTitle - name: trim - name: select - - args: - - integerValue: '1' - name: limit - description: testLike pipeline: - Collection: books @@ -1656,11 +1356,6 @@ tests: - IsNaN: - Field: rating - Select: - - ExprWithAlias: - - Eq: - - Field: rating - - Constant: null - - "ratingIsNull" - ExprWithAlias: - Not: - IsNaN: @@ -1668,8 +1363,7 @@ tests: - "ratingIsNotNaN" - Limit: 1 assert_results: - - ratingIsNull: false - ratingIsNotNaN: true + - ratingIsNotNaN: true assert_proto: pipeline: stages: @@ -1696,12 +1390,6 @@ tests: - fieldReferenceValue: rating name: is_nan name: not - ratingIsNull: - functionValue: - args: - - fieldReferenceValue: rating - - nullValue: null - name: eq name: select - args: - integerValue: '1' @@ -1812,79 +1500,6 @@ tests: - booleanValue: true name: eq name: where - - description: testDistanceFunctions - pipeline: - - Collection: books - - Select: - - ExprWithAlias: - - CosineDistance: - - Constant: [[0.1, 0.1]] - - Constant: [[0.5, 0.8]] - - "cosineDistance" - - ExprWithAlias: - - DotProduct: - - Constant: [[0.1, 0.1]] - - Constant: [[0.5, 0.8]] - - "dotProductDistance" - - ExprWithAlias: - - EuclideanDistance: - - Constant: [[0.1, 0.1]] - - Constant: [[0.5, 0.8]] - - "euclideanDistance" - - Limit: 1 - assert_results: - - cosineDistance: 0.02560880430538015 - dotProductDistance: 0.13 - euclideanDistance: 0.806225774829855 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - cosineDistance: - functionValue: - args: - - arrayValue: - values: - - doubleValue: 0.1 - - doubleValue: 0.1 - - arrayValue: - values: - - doubleValue: 0.5 - - doubleValue: 0.8 - name: cosine_distance - dotProductDistance: - functionValue: - args: - - arrayValue: - values: - - doubleValue: 0.1 - - doubleValue: 0.1 - - arrayValue: - values: - - doubleValue: 0.5 - - doubleValue: 0.8 - name: dot_product - euclideanDistance: - functionValue: - args: - - arrayValue: - values: - - doubleValue: 0.1 - - doubleValue: 0.1 - - arrayValue: - values: - - doubleValue: 0.5 - - doubleValue: 0.8 - name: euclidean_distance - name: select - - args: - - integerValue: '1' - name: limit - description: testNestedFields pipeline: - Collection: books @@ -1933,43 +1548,6 @@ tests: title: fieldReferenceValue: title name: select - - description: testReplace - pipeline: - - Collection: books - - Where: - - Eq: - - Field: title - - Constant: "The Hitchhiker's Guide to the Galaxy" - - Replace: awards - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - author: Douglas Adams - genre: Science Fiction - published: 1979 - rating: 4.2 - tags: - - comedy - - space - 
- adventure - hugo: true - nebula: false - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: The Hitchhiker's Guide to the Galaxy - name: eq - name: where - - args: - - fieldReferenceValue: awards - - stringValue: full_replace - name: replace - description: testSampleLimit pipeline: - Collection: books From 41b91d4d4f4f44bf20753742fa5e6ccf5352e645 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 2 May 2025 14:13:19 -0700 Subject: [PATCH 098/131] fixed regex_match test --- tests/system/pipeline_e2e.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index dc0a9ab9c..dc262f4a9 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -1097,8 +1097,8 @@ tests: - functionValue: args: - fieldReferenceValue: title - - stringValue: "(?i)(the|of)" - name: regex_matches + - stringValue: ".*(?i)(the|of).*" + name: regex_match name: where - description: testArithmeticOperations pipeline: From 7f692290b7921298e39a084e27ad96cb219e61d3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 2 May 2025 14:16:51 -0700 Subject: [PATCH 099/131] ran blacken --- google/cloud/firestore_v1/async_client.py | 2 +- google/cloud/firestore_v1/async_pipeline.py | 11 +- google/cloud/firestore_v1/base_pipeline.py | 18 +- google/cloud/firestore_v1/base_query.py | 21 +- google/cloud/firestore_v1/client.py | 2 +- google/cloud/firestore_v1/pipeline.py | 9 +- .../firestore_v1/pipeline_expressions.py | 183 +++++++++++++++--- google/cloud/firestore_v1/pipeline_stages.py | 114 +++++++++-- .../services/firestore/transports/rest.py | 3 +- tests/system/test_pipeline_acceptance.py | 2 + tests/unit/v1/test_base_query.py | 1 + 11 files changed, 294 insertions(+), 72 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 9fb6883bb..10aa02c69 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -415,4 +415,4 @@ def transaction(self, **kwargs) -> AsyncTransaction: return AsyncTransaction(self, **kwargs) def pipeline(self, *stages) -> AsyncPipeline: - return AsyncPipeline(self, *stages) \ No newline at end of file + return AsyncPipeline(self, *stages) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 07fe3ddbf..0e3453c94 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -46,7 +46,8 @@ class AsyncPipeline(_BasePipeline): Use `client.collection("...").pipeline()` to create instances of this class. """ - def __init__(self, client:AsyncClient, *stages: stages.Stage): + + def __init__(self, client: AsyncClient, *stages: stages.Stage): """ Initializes an asynchronous Pipeline. 
@@ -57,12 +58,16 @@ def __init__(self, client:AsyncClient, *stages: stages.Stage): super().__init__(client, *stages) async def execute(self) -> AsyncIterable["DocumentSnapshot"]: - database_name = f"projects/{self._client.project}/databases/{self._client._database}" + database_name = ( + f"projects/{self._client.project}/databases/{self._client._database}" + ) request = ExecutePipelineRequest( database=database_name, structured_pipeline=self._to_pb(), read_time=datetime.datetime.now(), ) - async for response in await self._client._firestore_api.execute_pipeline(request): + async for response in await self._client._firestore_api.execute_pipeline( + request + ): for snapshot in self._parse_response(response, self._client): yield snapshot diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 56a8b2d59..75f7c3e8d 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -17,7 +17,9 @@ from typing_extensions import Self from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_client import BaseClient -from google.cloud.firestore_v1.types.pipeline import StructuredPipeline as StructuredPipeline_pb +from google.cloud.firestore_v1.types.pipeline import ( + StructuredPipeline as StructuredPipeline_pb, +) from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1 import _helpers, document @@ -39,6 +41,7 @@ class _BasePipeline: This class is not intended to be instantiated directly. Use `client.collection.("...").pipeline()` to create pipeline instances. """ + def __init__(self, client: BaseClient, *stages: stages.Stage): """ Initializes a new pipeline with the given stages. @@ -62,8 +65,9 @@ def __repr__(self): return f"Pipeline(\n {stages_str}\n)" def _to_pb(self) -> StructuredPipeline_pb: - return StructuredPipeline_pb(pipeline={"stages":[s._to_pb() for s in self.stages]}) - + return StructuredPipeline_pb( + pipeline={"stages": [s._to_pb() for s in self.stages]} + ) def _append(self, new_stage): """ @@ -260,7 +264,9 @@ def find_nearest( Returns: A new Pipeline object with this stage appended to the stage list """ - return self._append(stages.FindNearest(field, vector, distance_measure, options)) + return self._append( + stages.FindNearest(field, vector, distance_measure, options) + ) def sort(self, *orders: stages.Ordering) -> Self: """ @@ -409,8 +415,8 @@ def generic_stage(self, name: str, *params: Expr) -> Self: """ Adds a generic, named stage to the pipeline with specified parameters. - This method provides a flexible way to extend the pipeline's functionality - by adding custom stages. Each generic stage is defined by a unique `name` + This method provides a flexible way to extend the pipeline's functionality + by adding custom stages. Each generic stage is defined by a unique `name` and a set of `params` that control its behavior. 
Example: diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 100b137ef..d8fb96dee 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1114,16 +1114,15 @@ def pipeline(self): # Filters for filter_ in self._field_filters: - ppl = ppl.where(pipeline_expressions.FilterCondition._from_query_filter_pb(filter_, self._client)) + ppl = ppl.where( + pipeline_expressions.FilterCondition._from_query_filter_pb( + filter_, self._client + ) + ) # Projections if self._projection and self._projection.fields: - ppl = ppl.select( - *[ - field.field_path - for field in self._projection.fields - ] - ) + ppl = ppl.select(*[field.field_path for field in self._projection.fields]) # Orders orders = self._normalize_orders() @@ -1134,7 +1133,9 @@ def pipeline(self): field = pipeline_expressions.Field.of(order.field.field_path) exists.append(field.exists()) direction = ( - "ascending" if order.direction == StructuredQuery.Direction.ASCENDING else "descending" + "ascending" + if order.direction == StructuredQuery.Direction.ASCENDING + else "descending" ) orderings.append(pipeline_expressions.Ordering(field, direction)) @@ -1149,7 +1150,9 @@ def pipeline(self): # Cursors, Limit and Offset if self._start_at or self._end_at or self._limit_to_last: - raise NotImplementedError("Query to Pipeline conversion: cursors and limitToLast is not supported yet.") + raise NotImplementedError( + "Query to Pipeline conversion: cursors and limitToLast is not supported yet." + ) else: # Limit & Offset without cursors if self._offset: ppl = ppl.offset(self._offset) diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index aa82c59b6..ed1a2543f 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -407,4 +407,4 @@ def transaction(self, **kwargs) -> Transaction: return Transaction(self, **kwargs) def pipeline(self, *stages) -> Pipeline: - return Pipeline(self, *stages) \ No newline at end of file + return Pipeline(self, *stages) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index f17a68ed7..67fa11fba 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -43,7 +43,8 @@ class Pipeline(_BasePipeline): Use `client.collection("...").pipeline()` to create instances of this class. """ - def __init__(self, client:Client, *stages: stages.Stage): + + def __init__(self, client: Client, *stages: stages.Stage): """ Initializes a Pipeline. @@ -54,10 +55,12 @@ def __init__(self, client:Client, *stages: stages.Stage): super().__init__(client, *stages) def execute(self) -> Iterable["DocumentSnapshot"]: - database_name = f"projects/{self._client.project}/databases/{self._client._database}" + database_name = ( + f"projects/{self._client.project}/databases/{self._client._database}" + ) request = ExecutePipelineRequest( database=database_name, structured_pipeline=self._to_pb(), ) for response in self._client._firestore_api.execute_pipeline(request): - yield from self._parse_response(response, self) \ No newline at end of file + yield from self._parse_response(response, self) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index c26b2a4cb..1eda32713 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -13,7 +13,19 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, Iterable, List, Mapping, Union, Generic, TypeVar, List, Dict, Tuple, Sequence +from typing import ( + Any, + Iterable, + List, + Mapping, + Union, + Generic, + TypeVar, + List, + Dict, + Tuple, + Sequence, +) from abc import ABC from abc import abstractmethod from enum import Enum @@ -27,7 +39,20 @@ from google.cloud.firestore_v1._helpers import encode_value from google.cloud.firestore_v1._helpers import decode_value -CONSTANT_TYPE = TypeVar('CONSTANT_TYPE', str, int, float, bool, datetime.datetime, bytes, GeoPoint, Vector, list, Dict[str, Any], None) +CONSTANT_TYPE = TypeVar( + "CONSTANT_TYPE", + str, + int, + float, + bool, + datetime.datetime, + bytes, + GeoPoint, + Vector, + list, + Dict[str, Any], + None, +) class Ordering: @@ -37,7 +62,7 @@ class Direction(Enum): ASCENDING = "ascending" DESCENDING = "descending" - def __init__(self, expr, order_dir: Direction | str=Direction.ASCENDING): + def __init__(self, expr, order_dir: Direction | str = Direction.ASCENDING): """ Initializes an Ordering instance @@ -48,7 +73,11 @@ def __init__(self, expr, order_dir: Direction | str=Direction.ASCENDING): Defaults to ascending """ self.expr = expr if isinstance(expr, Expr) else Field.of(expr) - self.order_dir = Ordering.Direction[order_dir.upper()] if isinstance(order_dir, str) else order_dir + self.order_dir = ( + Ordering.Direction[order_dir.upper()] + if isinstance(order_dir, str) + else order_dir + ) def __repr__(self): if self.order_dir is Ordering.Direction.ASCENDING: @@ -59,21 +88,23 @@ def __repr__(self): def _to_pb(self) -> Value: return Value( - map_value={"fields": - { + map_value={ + "fields": { "direction": Value(string_value=self.order_dir.value), - "expression": self.expr._to_pb() + "expression": self.expr._to_pb(), } } ) + class SampleOptions: """Options for the 'sample' pipeline stage.""" + class Mode(Enum): DOCUMENTS = "documents" PERCENT = "percent" - def __init__(self, value: int | float, mode:Mode | str): + def __init__(self, value: int | float, mode: Mode | str): self.value = value self.mode = SampleOptions.Mode[mode.upper()] if isinstance(mode, str) else mode @@ -85,7 +116,7 @@ def __repr__(self): return f"SampleOptions.{mode_str}({self.value})" @staticmethod - def doc_limit(value:int): + def doc_limit(value: int): """ Sample a set number of documents @@ -95,7 +126,7 @@ def doc_limit(value:int): return SampleOptions(value, mode=SampleOptions.Mode.DOCUMENTS) @staticmethod - def percentage(value:float): + def percentage(value: float): """ Sample a percentage of documents @@ -104,6 +135,7 @@ def percentage(value:float): """ return SampleOptions(value, mode=SampleOptions.Mode.PERCENT) + class Expr(ABC): """Represents an expression that can be evaluated to a value within the execution of a pipeline. @@ -415,7 +447,9 @@ def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": """ return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) - def array_contains_all(self, elements: List[Expr | CONSTANT_TYPE]) -> "ArrayContainsAll": + def array_contains_all( + self, elements: List[Expr | CONSTANT_TYPE] + ) -> "ArrayContainsAll": """Creates an expression that checks if an array contains all the specified elements. Example: @@ -430,9 +464,13 @@ def array_contains_all(self, elements: List[Expr | CONSTANT_TYPE]) -> "ArrayCont Returns: A new `Expr` representing the 'array_contains_all' comparison. 
""" - return ArrayContainsAll(self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements]) + return ArrayContainsAll( + self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + ) - def array_contains_any(self, elements: List[Expr | CONSTANT_TYPE]) -> "ArrayContainsAny": + def array_contains_any( + self, elements: List[Expr | CONSTANT_TYPE] + ) -> "ArrayContainsAny": """Creates an expression that checks if an array contains any of the specified elements. Example: @@ -448,7 +486,9 @@ def array_contains_any(self, elements: List[Expr | CONSTANT_TYPE]) -> "ArrayCont Returns: A new `Expr` representing the 'array_contains_any' comparison. """ - return ArrayContainsAny(self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements]) + return ArrayContainsAny( + self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + ) def array_length(self) -> "ArrayLength": """Creates an expression that calculates the length of an array. @@ -700,7 +740,9 @@ def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": Returns: A new `Expr` representing the concatenated string. """ - return StrConcat(*[self._cast_to_expr_or_convert_to_constant(el) for el in elements]) + return StrConcat( + *[self._cast_to_expr_or_convert_to_constant(el) for el in elements] + ) def map_get(self, key: str) -> "MapGet": """Accesses a value from a map (object) field using the provided key. @@ -831,7 +873,11 @@ def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd Returns: A new `Expr` representing the resulting timestamp. """ - return TimestampAdd(self, self._cast_to_expr_or_convert_to_constant(unit), self._cast_to_expr_or_convert_to_constant(amount)) + return TimestampAdd( + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ) def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub": """Creates an expression that subtracts a specified amount of time from this timestamp expression. @@ -850,7 +896,11 @@ def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub Returns: A new `Expr` representing the resulting timestamp. """ - return TimestampSub(self, self._cast_to_expr_or_convert_to_constant(unit), self._cast_to_expr_or_convert_to_constant(amount)) + return TimestampSub( + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ) def ascending(self) -> Ordering: """Creates an `Ordering` that sorts documents in ascending order based on this expression. 
@@ -897,13 +947,15 @@ def as_(self, alias: str) -> "ExprWithAlias": """ return ExprWithAlias(self, alias) + class Constant(Expr, Generic[CONSTANT_TYPE]): """Represents a constant literal value in an expression.""" + def __init__(self, value: CONSTANT_TYPE): self.value: CONSTANT_TYPE = value @staticmethod - def of(value:CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: + def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: """Creates a constant expression from a Python value.""" return Constant(value) @@ -913,8 +965,10 @@ def __repr__(self): def _to_pb(self) -> Value: return encode_value(self.value) + class ListOfExprs(Expr): """Represents a list of expressions, typically used as an argument to functions like 'in' or array functions.""" + def __init__(self, exprs: List[Expr]): self.exprs: list[Expr] = exprs @@ -938,168 +992,197 @@ def __repr__(self): def _to_pb(self): return Value( function_value={ - "name": self.name, "args": [p._to_pb() for p in self.params] + "name": self.name, + "args": [p._to_pb() for p in self.params], } ) + class Divide(Function): """Represents the division function.""" + def __init__(self, left: Expr, right: Expr): super().__init__("divide", [left, right]) class LogicalMax(Function): """Represents the logical maximum function based on Firestore type ordering.""" + def __init__(self, left: Expr, right: Expr): super().__init__("logical_maximum", [left, right]) class LogicalMin(Function): """Represents the logical minimum function based on Firestore type ordering.""" + def __init__(self, left: Expr, right: Expr): super().__init__("logical_minimum", [left, right]) class MapGet(Function): """Represents accessing a value within a map by key.""" + def __init__(self, map_: Expr, key: Constant[str]): super().__init__("map_get", [map_, key]) class Mod(Function): """Represents the modulo function.""" + def __init__(self, left: Expr, right: Expr): super().__init__("mod", [left, right]) class Multiply(Function): """Represents the multiplication function.""" + def __init__(self, left: Expr, right: Expr): super().__init__("multiply", [left, right]) class Parent(Function): """Represents getting the parent document reference.""" + def __init__(self, value: Expr): super().__init__("parent", [value]) class StrConcat(Function): """Represents concatenating multiple strings.""" + def __init__(self, *exprs: Expr): super().__init__("str_concat", exprs) class Subtract(Function): """Represents the subtraction function.""" + def __init__(self, left: Expr, right: Expr): super().__init__("subtract", [left, right]) class TimestampAdd(Function): """Represents adding a duration to a timestamp.""" + def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): super().__init__("timestamp_add", [timestamp, unit, amount]) class TimestampSub(Function): """Represents subtracting a duration from a timestamp.""" + def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): super().__init__("timestamp_sub", [timestamp, unit, amount]) class TimestampToUnixMicros(Function): """Represents converting a timestamp to microseconds since epoch.""" + def __init__(self, input: Expr): super().__init__("timestamp_to_unix_micros", [input]) class TimestampToUnixMillis(Function): """Represents converting a timestamp to milliseconds since epoch.""" + def __init__(self, input: Expr): super().__init__("timestamp_to_unix_millis", [input]) class TimestampToUnixSeconds(Function): """Represents converting a timestamp to seconds since epoch.""" + def __init__(self, input: Expr): super().__init__("timestamp_to_unix_seconds", 
[input]) class UnixMicrosToTimestamp(Function): """Represents converting microseconds since epoch to a timestamp.""" + def __init__(self, input: Expr): super().__init__("unix_micros_to_timestamp", [input]) class UnixMillisToTimestamp(Function): """Represents converting milliseconds since epoch to a timestamp.""" + def __init__(self, input: Expr): super().__init__("unix_millis_to_timestamp", [input]) class UnixSecondsToTimestamp(Function): """Represents converting seconds since epoch to a timestamp.""" + def __init__(self, input: Expr): super().__init__("unix_seconds_to_timestamp", [input]) class VectorLength(Function): """Represents getting the length (dimension) of a vector.""" + def __init__(self, array: Expr): super().__init__("vector_length", [array]) class Add(Function): """Represents the addition function.""" + def __init__(self, left: Expr, right: Expr): super().__init__("add", [left, right]) class ArrayElement(Function): """Represents accessing an element within an array""" + def __init__(self): super().__init__("array_element", []) class ArrayFilter(Function): """Represents filtering elements from an array based on a condition.""" + def __init__(self, array: Expr, filter: "FilterCondition"): super().__init__("array_filter", [array, filter]) class ArrayLength(Function): """Represents getting the length of an array.""" + def __init__(self, array: Expr): super().__init__("array_length", [array]) class ArrayReverse(Function): """Represents reversing the elements of an array.""" + def __init__(self, array: Expr): super().__init__("array_reverse", [array]) class ArrayTransform(Function): """Represents applying a transformation function to each element of an array.""" + def __init__(self, array: Expr, transform: Function): super().__init__("array_transform", [array, transform]) class ByteLength(Function): """Represents getting the byte length of a string (UTF-8).""" + def __init__(self, expr: Expr): super().__init__("byte_length", [expr]) class CharLength(Function): """Represents getting the character length of a string.""" + def __init__(self, expr: Expr): super().__init__("char_length", [expr]) class CollectionId(Function): """Represents getting the collection ID from a document reference.""" + def __init__(self, value: Expr): super().__init__("collection_id", [value]) @@ -1110,37 +1193,43 @@ class Accumulator(Function): class Max(Accumulator): """Represents the maximum aggregation function.""" - def __init__(self, value: Expr, distinct: bool=False): + + def __init__(self, value: Expr, distinct: bool = False): super().__init__("maximum", [value]) class Min(Accumulator): """Represents the minimum aggregation function.""" - def __init__(self, value: Expr, distinct: bool=False): + + def __init__(self, value: Expr, distinct: bool = False): super().__init__("minimum", [value]) class Sum(Accumulator): """Represents the sum aggregation function.""" - def __init__(self, value: Expr, distinct: bool=False): + + def __init__(self, value: Expr, distinct: bool = False): super().__init__("sum", [value]) class Avg(Accumulator): """Represents the average aggregation function.""" - def __init__(self, value: Expr, distinct: bool=False): + + def __init__(self, value: Expr, distinct: bool = False): super().__init__("avg", [value]) class Count(Accumulator): """Represents the count aggregation function.""" + def __init__(self, value: Expr | None = None): super().__init__("count", [value] if value else []) class CountIf(Function): """Represents counting inputs where a condition is true (likely used 
internally or planned).""" - def __init__(self, value: Expr, distinct: bool=False): + + def __init__(self, value: Expr, distinct: bool = False): super().__init__("countif", [value] if value else []) @@ -1152,9 +1241,12 @@ def _to_map(self): raise NotImplementedError -T = TypeVar('T', bound=Expr) +T = TypeVar("T", bound=Expr) + + class ExprWithAlias(Selectable, Generic[T]): """Wraps an expression with an alias.""" + def __init__(self, expr: T, alias: str): self.expr = expr self.alias = alias @@ -1166,13 +1258,12 @@ def __repr__(self): return f"{self.expr}.as_('{self.alias}')" def _to_pb(self): - return Value( - map_value={"fields": {self.alias: self.expr._to_pb()}} - ) + return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) class Field(Selectable): """Represents a reference to a field within a document.""" + DOCUMENT_ID = "__name__" def __init__(self, path: str): @@ -1213,13 +1304,18 @@ class FilterCondition(Function): @staticmethod def _from_query_filter_pb(filter_pb, client): if isinstance(filter_pb, Query_pb.CompositeFilter): - sub_filters = [FilterCondition._from_query_filter_pb(f, client) for f in filter_pb.filters] + sub_filters = [ + FilterCondition._from_query_filter_pb(f, client) + for f in filter_pb.filters + ] if filter_pb.op == Query_pb.CompositeFilter.Operator.OR: return Or(*sub_filters) elif filter_pb.op == Query_pb.CompositeFilter.Operator.AND: return And(*sub_filters) else: - raise TypeError(f"Unexpected CompositeFilter operator type: {filter_pb.op}") + raise TypeError( + f"Unexpected CompositeFilter operator type: {filter_pb.op}" + ) elif isinstance(filter_pb, Query_pb.UnaryFilter): field = Field.of(filter_pb.field.field_path) if filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NAN: @@ -1259,7 +1355,11 @@ def _from_query_filter_pb(filter_pb, client): raise TypeError(f"Unexpected FieldFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.Filter): # unwrap oneof - f = filter_pb.composite_filter or filter_pb.field_filter or filter_pb.unary_filter + f = ( + filter_pb.composite_filter + or filter_pb.field_filter + or filter_pb.unary_filter + ) return FilterCondition._from_query_filter_pb(f, client) else: raise TypeError(f"Unexpected filter type: {type(filter_pb)}") @@ -1279,48 +1379,56 @@ def __init__(self, array: Expr, element: Expr): class ArrayContainsAll(FilterCondition): """Represents checking if an array contains all specified elements.""" + def __init__(self, array: Expr, elements: List[Expr]): super().__init__("array_contains_all", [array, ListOfExprs(elements)]) class ArrayContainsAny(FilterCondition): """Represents checking if an array contains any of the specified elements.""" + def __init__(self, array: Expr, elements: List[Expr]): super().__init__("array_contains_any", [array, ListOfExprs(elements)]) class EndsWith(FilterCondition): """Represents checking if a string ends with a specific postfix.""" + def __init__(self, expr: Expr, postfix: Expr): super().__init__("ends_with", [expr, postfix]) class Eq(FilterCondition): """Represents the equality comparison.""" + def __init__(self, left: Expr, right: Expr): super().__init__("eq", [left, right if right else Constant(None)]) class Exists(FilterCondition): """Represents checking if a field exists.""" + def __init__(self, expr: Expr): super().__init__("exists", [expr]) class Gt(FilterCondition): """Represents the greater than comparison.""" + def __init__(self, left: Expr, right: Expr): super().__init__("gt", [left, right if right else Constant(None)]) class Gte(FilterCondition): 
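These FilterCondition classes are also what `_from_query_filter_pb` builds when translating classic query filters; composing them by hand looks roughly like this (a sketch; the field names are illustrative):

from google.cloud.firestore_v1.pipeline_expressions import And, Constant, Eq, Field, Gte

# rating >= 4.0 AND genre == "Fantasy"
cond = And(
    Gte(Field.of("rating"), Constant.of(4.0)),
    Eq(Field.of("genre"), Constant.of("Fantasy")),
)
pipeline = client.collection("books").pipeline().where(cond)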
"""Represents the greater than or equal to comparison.""" + def __init__(self, left: Expr, right: Expr): super().__init__("gte", [left, right if right else Constant(None)]) class If(FilterCondition): """Represents a conditional expression (if-then-else).""" + def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Expr): super().__init__( "if", [condition, true_expr, false_expr if false_expr else Constant(None)] @@ -1329,77 +1437,90 @@ def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Ex class In(FilterCondition): """Represents checking if an expression's value is within a list of values.""" + def __init__(self, left: Expr, others: List[Expr]): super().__init__("in", [left, ListOfExprs(others)]) class IsNaN(FilterCondition): """Represents checking if a numeric value is NaN.""" + def __init__(self, value: Expr): super().__init__("is_nan", [value]) class Like(FilterCondition): """Represents a case-sensitive wildcard string comparison.""" + def __init__(self, expr: Expr, pattern: Expr): super().__init__("like", [expr, pattern]) class Lt(FilterCondition): """Represents the less than comparison.""" + def __init__(self, left: Expr, right: Expr): super().__init__("lt", [left, right if right else Constant(None)]) class Lte(FilterCondition): """Represents the less than or equal to comparison.""" + def __init__(self, left: Expr, right: Expr): super().__init__("lte", [left, right if right else Constant(None)]) class Neq(FilterCondition): """Represents the inequality comparison.""" + def __init__(self, left: Expr, right: Expr): super().__init__("neq", [left, right if right else Constant(None)]) class Not(FilterCondition): """Represents the logical NOT of a filter condition.""" + def __init__(self, condition: Expr): super().__init__("not", [condition]) class Or(FilterCondition): """Represents the logical OR of multiple filter conditions.""" + def __init__(self, *conditions: "FilterCondition"): super().__init__("or", conditions) class RegexContains(FilterCondition): """Represents checking if a string contains a substring matching a regex.""" + def __init__(self, expr: Expr, regex: Expr): super().__init__("regex_contains", [expr, regex]) class RegexMatch(FilterCondition): """Represents checking if a string fully matches a regex.""" + def __init__(self, expr: Expr, regex: Expr): super().__init__("regex_match", [expr, regex]) class StartsWith(FilterCondition): """Represents checking if a string starts with a specific prefix.""" + def __init__(self, expr: Expr, prefix: Expr): super().__init__("starts_with", [expr, prefix]) class StrContains(FilterCondition): """Represents checking if a string contains a specific substring.""" + def __init__(self, expr: Expr, substring: Expr): super().__init__("str_contains", [expr, substring]) class Xor(FilterCondition): """Represents the logical XOR of multiple filter conditions.""" + def __init__(self, conditions: List["FilterCondition"]): super().__init__("xor", conditions) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 459aa13b6..547d22996 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -30,7 +30,7 @@ FilterCondition, Selectable, SampleOptions, - Ordering + Ordering, ) if TYPE_CHECKING: @@ -45,6 +45,7 @@ class FindNearestOptions: distance_field (Optional[Field]): An optional field to store the calculated distance in the output documents. 
""" + def __init__( self, limit: Optional[int] = None, @@ -61,6 +62,7 @@ class UnnestOptions: index_field (str): The name of the field to add to each output document, storing the original 0-based index of the element within the array. """ + def __init__(self, index_field: str): self.index_field = index_field @@ -72,11 +74,14 @@ class Stage: transforming) within a Firestore pipeline. Subclasses define the specific arguments and behavior for each operation. """ + def __init__(self, custom_name: Optional[str] = None): self.name = custom_name or type(self).__name__.lower() def _to_pb(self) -> Pipeline_pb.Stage: - return Pipeline_pb.Stage(name=self.name, args=self._pb_args(), options=self._pb_options()) + return Pipeline_pb.Stage( + name=self.name, args=self._pb_args(), options=self._pb_options() + ) def _pb_args(self) -> list[Value]: """Return Ordered list of arguments the given stage expects""" @@ -93,15 +98,24 @@ def __repr__(self): class AddFields(Stage): """Adds new fields to outputs from previous stages.""" + def __init__(self, *fields: Selectable): super().__init__("add_fields") self.fields = list(fields) def _pb_args(self): - return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}})] + return [ + Value( + map_value={ + "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]} + } + ) + ] + class Aggregate(Stage): """Performs aggregation operations, optionally grouped.""" + def __init__( self, *extra_accumulators: ExprWithAlias[Accumulator], @@ -109,17 +123,32 @@ def __init__( groups: Sequence[str | Selectable] = (), ): super().__init__() - self.groups: list[Selectable] = [Field(f) if isinstance(f, str) else f for f in groups] - self.accumulators: list[ExprWithAlias[Accumulator]] = [*accumulators, *extra_accumulators] + self.groups: list[Selectable] = [ + Field(f) if isinstance(f, str) else f for f in groups + ] + self.accumulators: list[ExprWithAlias[Accumulator]] = [ + *accumulators, + *extra_accumulators, + ] def _pb_args(self): return [ - Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.accumulators]}}), - Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.groups]}}) + Value( + map_value={ + "fields": { + m[0]: m[1] for m in [f._to_map() for f in self.accumulators] + } + } + ), + Value( + map_value={ + "fields": {m[0]: m[1] for m in [f._to_map() for f in self.groups]} + } + ), ] def __repr__(self): - accumulator_str = ', '.join(repr(v) for v in self.accumulators) + accumulator_str = ", ".join(repr(v) for v in self.accumulators) group_str = "" if self.groups: if self.accumulators: @@ -130,6 +159,7 @@ def __repr__(self): class Collection(Stage): """Specifies a collection as the initial data source.""" + def __init__(self, path: str): super().__init__() if not path.startswith("/"): @@ -139,8 +169,10 @@ def __init__(self, path: str): def _pb_args(self): return [Value(reference_value=self.path)] + class CollectionGroup(Stage): """Specifies a collection group as the initial data source.""" + def __init__(self, collection_id: str): super().__init__("collection_group") self.collection_id = collection_id @@ -151,21 +183,33 @@ def _pb_args(self): class Database(Stage): """Specifies the default database as the initial data source.""" + def __init__(self): super().__init__() + class Distinct(Stage): """Returns documents with distinct combinations of specified field values.""" + def __init__(self, *fields: str | Selectable): super().__init__() - self.fields: list[Selectable] = [Field(f) if isinstance(f, 
str) else f for f in fields] + self.fields: list[Selectable] = [ + Field(f) if isinstance(f, str) else f for f in fields + ] def _pb_args(self) -> list[Value]: - return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}})] + return [ + Value( + map_value={ + "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]} + } + ) + ] class Documents(Stage): """Specifies specific documents as the initial data source.""" + def __init__(self, *paths: str): super().__init__() self.paths = paths @@ -176,11 +220,16 @@ def of(*documents: "DocumentReference") -> "Documents": return Documents(*doc_paths) def _pb_args(self): - return [Value(list_value={"values": [Value(string_value=path) for path in self.paths]})] + return [ + Value( + list_value={"values": [Value(string_value=path) for path in self.paths]} + ) + ] class FindNearest(Stage): """Performs vector distance (similarity) search.""" + def __init__( self, field: str | Expr, @@ -209,11 +258,15 @@ def _pb_options(self) -> dict[str, Value]: options["distance_field"] = self.options.distance_field._to_pb() return options + class GenericStage(Stage): """Represents a generic, named stage with parameters.""" + def __init__(self, name: str, *params: Expr | Value): super().__init__(name) - self.params: list[Value] = [p._to_pb() if isinstance(p, Expr) else p for p in params] + self.params: list[Value] = [ + p._to_pb() if isinstance(p, Expr) else p for p in params + ] def _pb_args(self): return self.params @@ -221,6 +274,7 @@ def _pb_args(self): class Limit(Stage): """Limits the maximum number of documents returned.""" + def __init__(self, limit: int): super().__init__() self.limit = limit @@ -231,6 +285,7 @@ def _pb_args(self): class Offset(Stage): """Skips a specified number of documents.""" + def __init__(self, offset: int): super().__init__() self.offset = offset @@ -241,6 +296,7 @@ def _pb_args(self): class RemoveFields(Stage): """Removes specified fields from outputs.""" + def __init__(self, *fields: str | Field): super().__init__("remove_fields") self.fields = [Field(f) if isinstance(f, str) else f for f in fields] @@ -251,6 +307,7 @@ def _pb_args(self) -> list[Value]: class Sample(Stage): """Performs pseudo-random sampling of documents.""" + def __init__(self, limit_or_options: int | SampleOptions): super().__init__() if isinstance(limit_or_options, int): @@ -261,23 +318,39 @@ def __init__(self, limit_or_options: int | SampleOptions): def _pb_args(self): if self.options.mode == SampleOptions.Mode.DOCUMENTS: - return [Value(integer_value=self.options.value), Value(string_value="documents")] + return [ + Value(integer_value=self.options.value), + Value(string_value="documents"), + ] else: - return [Value(double_value=self.options.value), Value(string_value="percent")] + return [ + Value(double_value=self.options.value), + Value(string_value="percent"), + ] class Select(Stage): """Selects or creates a set of fields.""" + def __init__(self, *selections: str | Selectable): super().__init__() self.projections = [Field(s) if isinstance(s, str) else s for s in selections] def _pb_args(self) -> list[Value]: - return [Value(map_value={"fields": {m[0]: m[1] for m in [f._to_map() for f in self.projections]}})] + return [ + Value( + map_value={ + "fields": { + m[0]: m[1] for m in [f._to_map() for f in self.projections] + } + } + ) + ] class Sort(Stage): """Sorts documents based on specified criteria.""" + def __init__(self, *orders: "Ordering"): super().__init__() self.orders = list(orders) @@ -288,6 +361,7 @@ def 
_pb_args(self): class Union(Stage): """Performs a union of documents from two pipelines.""" + def __init__(self, other: Pipeline): super().__init__() self.other = other @@ -298,7 +372,13 @@ def _pb_args(self): class Unnest(Stage): """Produces a document for each element in an array field.""" - def __init__(self, field: Selectable | str, alias: Field | str | None=None, options: UnnestOptions|None=None): + + def __init__( + self, + field: Selectable | str, + alias: Field | str | None = None, + options: UnnestOptions | None = None, + ): super().__init__() self.field: Selectable = Field(field) if isinstance(field, str) else field if alias is None: @@ -321,10 +401,10 @@ def _pb_options(self): class Where(Stage): """Filters documents based on a specified condition.""" + def __init__(self, condition: FilterCondition): super().__init__() self.condition = condition def _pb_args(self): return [self.condition._to_pb()] - diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest.py b/google/cloud/firestore_v1/services/firestore/transports/rest.py index 6364c4e3f..6a4ae16b4 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest.py @@ -1145,9 +1145,9 @@ def __call__( query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" + class _ExecutePipeline(FirestoreRestStub): def __hash__(self): - # Send the request headers = dict(metadata) headers["Content-Type"] = "application/json" @@ -1230,6 +1230,7 @@ def __call__( query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" + class _ExecutePipeline(FirestoreRestStub): def __hash__(self): return hash("ExecutePipeline") diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 644cb6dc6..7266764ce 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -44,9 +44,11 @@ def yaml_loader(field="tests"): test_cases = yaml.safe_load(f) return test_cases[field] + @pytest.fixture(scope="session") def event_loop(): import asyncio + try: loop = asyncio.get_running_loop() except RuntimeError: diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 517d176f7..d6762d846 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -1951,6 +1951,7 @@ def test__collection_group_query_response_to_snapshot_response(): assert snapshot.create_time == response_pb._pb.document.create_time assert snapshot.update_time == response_pb._pb.document.update_time + def test__query_pipeline_descendants(): from google.cloud.firestore_v1 import pipeline_stages From 0b4c2944df9b50f51a61e70acbd9e8d0eab7bf7f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 2 May 2025 14:53:45 -0700 Subject: [PATCH 100/131] turn Stage into an ABC --- google/cloud/firestore_v1/pipeline_stages.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 547d22996..686aaf2a0 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -14,6 +14,8 @@ from __future__ import annotations from typing import Any, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING +from abc import ABC +from abc import abstractmethod from enum import Enum from enum import auto @@ -67,7 +69,7 @@ def
__init__(self, index_field: str): self.index_field = index_field -class Stage: +class Stage(ABC): """Base class for all pipeline stages. Each stage represents a specific operation (e.g., filtering, sorting, @@ -83,9 +85,10 @@ def _to_pb(self) -> Pipeline_pb.Stage: name=self.name, args=self._pb_args(), options=self._pb_options() ) + @abstractmethod def _pb_args(self) -> list[Value]: """Return Ordered list of arguments the given stage expects""" - return [] + raise NotImplementedError def _pb_options(self) -> dict[str, Value]: """Return optional named arguments that certain functions may support.""" From 255b6980e4386b8a7f94a4de5d5e25d65d4b58a5 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 2 May 2025 16:45:16 -0700 Subject: [PATCH 101/131] removed extra stages and expressions --- google/cloud/firestore_v1/base_pipeline.py | 359 +---- .../firestore_v1/pipeline_expressions.py | 992 +------------ google/cloud/firestore_v1/pipeline_stages.py | 261 +--- tests/system/pipeline_e2e.yaml | 1238 ++--------------- tests/system/test_pipeline_acceptance.py | 38 - 5 files changed, 99 insertions(+), 2789 deletions(-) diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 75f7c3e8d..117dceb95 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -13,24 +13,16 @@ # limitations under the License. from __future__ import annotations -from typing import Optional, Sequence from typing_extensions import Self from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_client import BaseClient from google.cloud.firestore_v1.types.pipeline import ( StructuredPipeline as StructuredPipeline_pb, ) -from google.cloud.firestore_v1.vector import Vector -from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1 import _helpers, document from google.cloud.firestore_v1.pipeline_expressions import ( - Accumulator, - Expr, - ExprWithAlias, - Field, FilterCondition, Selectable, - SampleOptions, ) @@ -88,71 +80,16 @@ def _parse_response(response_pb, client): update_time=doc.update_time, ) - def add_fields(self, *fields: Selectable) -> Self: - """ - Adds new fields to outputs from previous stages. - - This stage allows you to compute values on-the-fly based on existing data - from previous stages or constants. You can use this to create new fields - or overwrite existing ones (if there is name overlap). - - The added fields are defined using `Selectable` expressions, which can be: - - `Field`: References an existing document field. - - `Function`: Performs a calculation using functions like `add`, - `multiply` with assigned aliases using `Expr.as_()`. - - Example: - >>> from google.cloud.firestore_v1.pipeline_expressions import Field, add - >>> pipeline = client.collection("books").pipeline() - >>> pipeline = pipeline.add_fields( - ... Field.of("rating").as_("bookRating"), # Rename 'rating' to 'bookRating' - ... add(5, Field.of("quantity")).as_("totalCost") # Calculate 'totalCost' - ... ) - - Args: - *fields: The fields to add to the documents, specified as `Selectable` - expressions. - - Returns: - A new Pipeline object with this stage appended to the stage list - """ - return self._append(stages.AddFields(*fields)) - - def remove_fields(self, *fields: Field | str) -> Self: - """ - Removes fields from outputs of previous stages. 
- - Example: - >>> from google.cloud.firestore_v1.pipeline_expressions import Field - >>> pipeline = client.collection("books").pipeline() - >>> # Remove by name - >>> pipeline = pipeline.remove_fields("rating", "cost") - >>> # Remove by Field object - >>> pipeline = pipeline.remove_fields(Field.of("rating"), Field.of("cost")) - - - Args: - *fields: The fields to remove, specified as field names (str) or - `Field` objects. - - Returns: - A new Pipeline object with this stage appended to the stage list - """ - return self._append(stages.RemoveFields(*fields)) - def select(self, *selections: str | Selectable) -> Self: """ Selects or creates a set of fields from the outputs of previous stages. - The selected fields are defined using `Selectable` expressions or field names: - `Field`: References an existing document field. - `Function`: Represents the result of a function with an assigned alias name using `Expr.as_()`. - `str`: The name of an existing field. - If no selections are provided, the output of this stage is empty. Use `add_fields()` instead if only additions are desired. - Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field, to_upper >>> pipeline = client.collection("books").pipeline() @@ -163,11 +100,9 @@ def select(self, *selections: str | Selectable) -> Self: ... Field.of("name"), ... Field.of("address").to_upper().as_("upperAddress"), ... ) - Args: *selections: The fields to include in the output documents, specified as field names (str) or `Selectable` expressions. - Returns: A new Pipeline object with this stage appended to the stage list """ @@ -177,14 +112,12 @@ def where(self, condition: FilterCondition) -> Self: """ Filters the documents from previous stages to only include those matching the specified `FilterCondition`. - This stage allows you to apply conditions to the data, similar to a "WHERE" clause in SQL. You can filter documents based on their field values, using implementations of `FilterCondition`, typically including but not limited to: - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc. - logical operators: `And`, `Or`, `Not`, etc. - advanced functions: `regex_matches`, `array_contains`, etc. - Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field, And, >>> pipeline = client.collection("books").pipeline() @@ -202,83 +135,22 @@ def where(self, condition: FilterCondition) -> Self: ... Field.of("genre").eq("Science Fiction") ... ) ... ) - - Args: condition: The `FilterCondition` to apply. - Returns: A new Pipeline object with this stage appended to the stage list """ return self._append(stages.Where(condition)) - def find_nearest( - self, - field: str | Expr, - vector: Sequence[float] | "Vector", - distance_measure: "DistanceMeasure", - options: Optional[stages.FindNearestOptions] = None, - ) -> Self: - """ - Performs vector distance (similarity) search with given parameters on the - stage inputs. - - This stage adds a "nearest neighbor search" capability to your pipelines. - Given a field or expression that evaluates to a vector and a target vector, - this stage will identify and return the inputs whose vector is closest to - the target vector, using the specified distance measure and options. 
- - Example: - >>> from google.cloud.firestore_v1.base_vector_query import DistanceMeasure - >>> from google.cloud.firestore_v1.pipeline_stages import FindNearestOptions - >>> from google.cloud.firestore_v1.pipeline_expressions import Field - >>> - >>> target_vector = [0.1, 0.2, 0.3] - >>> pipeline = client.collection("books").pipeline() - >>> # Find using field name - >>> pipeline = pipeline.find_nearest( - ... "topicVectors", - ... target_vector, - ... DistanceMeasure.COSINE, - ... options=FindNearestOptions(limit=10, distance_field="distance") - ... ) - >>> # Find using Field expression - >>> pipeline = pipeline.find_nearest( - ... Field.of("topicVectors"), - ... target_vector, - ... DistanceMeasure.COSINE, - ... options=FindNearestOptions(limit=10, distance_field="distance") - ... ) - - Args: - field: The name of the field (str) or an expression (`Expr`) that - evaluates to the vector data. This field should store vector values. - vector: The target vector (sequence of floats or `Vector` object) to - compare against. - distance_measure: The distance measure (`DistanceMeasure`) to use - (e.g., `DistanceMeasure.COSINE`, `DistanceMeasure.EUCLIDEAN`). - limit: The maximum number of nearest neighbors to return. - options: Configuration options (`FindNearestOptions`) for the search, - such as limit and output distance field name. - - Returns: - A new Pipeline object with this stage appended to the stage list - """ - return self._append( - stages.FindNearest(field, vector, distance_measure, options) - ) - def sort(self, *orders: stages.Ordering) -> Self: """ Sorts the documents from previous stages based on one or more `Ordering` criteria. - This stage allows you to order the results of your pipeline. You can specify multiple `Ordering` instances to sort by multiple fields or expressions in ascending or descending order. If documents have the same value for a sorting criterion, the next specified ordering will be used. If all orderings result in equal comparison, the documents are considered equal and the relative order is unspecified. - Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field >>> pipeline = client.collection("books").pipeline() @@ -287,161 +159,19 @@ def sort(self, *orders: stages.Ordering) -> Self: ... Field.of("rating").descending(), ... Field.of("title").ascending() ... ) - Args: *orders: One or more `Ordering` instances specifying the sorting criteria. - Returns: A new Pipeline object with this stage appended to the stage list """ return self._append(stages.Sort(*orders)) - def sample(self, limit_or_options: int | SampleOptions) -> Self: - """ - Performs a pseudo-random sampling of the documents from the previous stage. - - This stage filters documents pseudo-randomly. - - If an `int` limit is provided, it specifies the maximum number of documents - to emit. If fewer documents are available, all are passed through. - - If `SampleOptions` are provided, they specify how sampling is performed - (e.g., by document count or percentage). - - Example: - >>> from google.cloud.firestore_v1.pipeline_expressions import SampleOptions - >>> pipeline = client.collection("books").pipeline() - >>> # Sample 10 books, if available. - >>> pipeline = pipeline.sample(10) - >>> pipeline = pipeline.sample(SampleOptions.doc_limit(10)) - >>> # Sample 50% of books. - >>> pipeline = pipeline.sample(SampleOptions.percentage(0.5)) - - - Args: - limit_or_options: Either an integer specifying the maximum number of - documents to sample, or a `SampleOptions` object. 
- - Returns: - A new Pipeline object with this stage appended to the stage list - """ - return self._append(stages.Sample(limit_or_options)) - - def union(self, other: Self) -> Self: - """ - Performs a union of all documents from this pipeline and another pipeline, - including duplicates. - - This stage passes through documents from the previous stage of this pipeline, - and also passes through documents from the previous stage of the `other` - pipeline provided. The order of documents emitted from this stage is undefined. - - Example: - >>> books_pipeline = client.collection("books").pipeline() - >>> magazines_pipeline = client.collection("magazines").pipeline() - >>> # Emit documents from both collections - >>> combined_pipeline = books_pipeline.union(magazines_pipeline) - - Args: - other: The other `Pipeline` whose results will be unioned with this one. - - Returns: - A new Pipeline object with this stage appended to the stage list - """ - return self._append(stages.Union(other)) - - def unnest( - self, - field: str | Selectable, - alias: str | Field | None = None, - options: Optional[stages.UnnestOptions] = None, - ) -> Self: - """ - Produces a document for each element in an array field from the previous stage document. - - For each previous stage document, this stage will emit zero or more augmented documents. The - input array found in the previous stage document field specified by the `fieldName` parameter, - will emit an augmented document for each input array element. The input array element will - augment the previous stage document by setting the `alias` field with the array element value. - If `alias` is unset, the data in `field` will be overwritten. - - Example: - Input document: - ```json - { "title": "The Hitchhiker's Guide", "tags": [ "comedy", "sci-fi" ], ... } - ``` - - >>> from google.cloud.firestore_v1.pipeline_stages import UnnestOptions - >>> pipeline = client.collection("books").pipeline() - >>> # Emit a document for each tag - >>> pipeline = pipeline.unnest("tags", alias="tag") - - Output documents (without options): - ```json - { "title": "The Hitchhiker's Guide", "tag": "comedy", ... } - { "title": "The Hitchhiker's Guide", "tag": "sci-fi", ... } - ``` - - Optionally, `UnnestOptions` can specify a field to store the original index - of the element within the array - - Example: - Input document: - ```json - { "title": "The Hitchhiker's Guide", "tags": [ "comedy", "sci-fi" ], ... } - ``` - - >>> from google.cloud.firestore_v1.pipeline_stages import UnnestOptions - >>> pipeline = client.collection("books").pipeline() - >>> # Emit a document for each tag, including the index - >>> pipeline = pipeline.unnest("tags", options=UnnestOptions(index_field="tagIndex")) - - Output documents (with index_field="tagIndex"): - ```json - { "title": "The Hitchhiker's Guide", "tags": "comedy", "tagIndex": 0, ... } - { "title": "The Hitchhiker's Guide", "tags": "sci-fi", "tagIndex": 1, ... } - ``` - - Args: - field: The name of the field containing the array to unnest. - alias The alias field is used as the field name for each element within the output array. - If unset, or if `alias` matches the `field`, the output data will overwrite the original field. - options: Optional `UnnestOptions` to configure additional behavior, like adding an index field. 
- - Returns: - A new Pipeline object with this stage appended to the stage list - """ - return self._append(stages.Unnest(field, alias, options)) - - def generic_stage(self, name: str, *params: Expr) -> Self: - """ - Adds a generic, named stage to the pipeline with specified parameters. - - This method provides a flexible way to extend the pipeline's functionality - by adding custom stages. Each generic stage is defined by a unique `name` - and a set of `params` that control its behavior. - - Example: - >>> # Assume we don't have a built-in "where" stage - >>> pipeline = client.collection("books").pipeline() - >>> pipeline = pipeline.generic_stage("where", [Field.of("published").lt(900)]) - >>> pipeline = pipeline.select("title", "author") - - Args: - name: The name of the generic stage. - *params: A sequence of `Expr` objects representing the parameters for the stage. - - Returns: - A new Pipeline object with this stage appended to the stage list - """ - return self._append(stages.GenericStage(name, *params)) - def offset(self, offset: int) -> Self: """ Skips the first `offset` number of documents from the results of previous stages. - This stage is useful for implementing pagination, allowing you to retrieve results in chunks. It is typically used in conjunction with `limit()` to control the size of each page. - Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field >>> pipeline = client.collection("books").pipeline() @@ -449,10 +179,8 @@ def offset(self, offset: int) -> Self: >>> pipeline = pipeline.sort(Field.of("published").descending()) >>> pipeline = pipeline.offset(20) # Skip the first 20 results >>> pipeline = pipeline.limit(20) # Take the next 20 results - Args: offset: The non-negative number of documents to skip. - Returns: A new Pipeline object with this stage appended to the stage list """ @@ -461,104 +189,19 @@ def offset(self, offset: int) -> Self: def limit(self, limit: int) -> Self: """ Limits the maximum number of documents returned by previous stages to `limit`. - This stage is useful for controlling the size of the result set, often used for: - **Pagination:** In combination with `offset()` to retrieve specific pages. - **Top-N queries:** To get a limited number of results after sorting. - **Performance:** To prevent excessive data transfer. - Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field >>> pipeline = client.collection("books").pipeline() >>> # Limit the results to the top 10 highest-rated books >>> pipeline = pipeline.sort(Field.of("rating").descending()) >>> pipeline = pipeline.limit(10) - Args: limit: The non-negative maximum number of documents to return. - - Returns: - A new Pipeline object with this stage appended to the stage list - """ - return self._append(stages.Limit(limit)) - - def aggregate( - self, - *accumulators: ExprWithAlias[Accumulator], - groups: Sequence[str | Selectable] = (), - ) -> Self: - """ - Performs aggregation operations on the documents from previous stages, - optionally grouped by specified fields or expressions. - - This stage allows you to calculate aggregate values (like sum, average, count, - min, max) over a set of documents. - - - **Accumulators:** Define the aggregation calculations using `Accumulator` - expressions (e.g., `sum()`, `avg()`, `count()`, `min()`, `max()`) combined - with `as_()` to name the result field. - - **Groups:** Optionally specify fields (by name or `Selectable`) to group - the documents by. Aggregations are then performed within each distinct group. 
- If no groups are provided, the aggregation is performed over the entire input. - - Example: - >>> from google.cloud.firestore_v1.pipeline_expressions import Field, avg, count_all - >>> pipeline = client.collection("books").pipeline() - >>> # Calculate the average rating and total count for all books - >>> pipeline = pipeline.aggregate( - ... avg(Field.of("rating")).as_("averageRating"), - ... count_all().as_("totalBooks") - ... ) - >>> # Calculate the average rating for each genre - >>> pipeline = pipeline.aggregate( - ... avg(Field.of("rating")).as_("avg_rating"), - ... groups=["genre"] # Group by the 'genre' field - ... ) - >>> # Calculate the count for each author, grouping by Field object - >>> pipeline = pipeline.aggregate( - ... count_all().as_("bookCount"), - ... groups=[Field.of("author")] - ... ) - - - Args: - *accumulators: One or more `ExprWithAlias[Accumulator]` expressions defining - the aggregations to perform and their output names. - groups: An optional sequence of field names (str) or `Selectable` - expressions to group by before aggregating. - - Returns: - A new Pipeline object with this stage appended to the stage list - """ - return self._append(stages.Aggregate(*accumulators, groups=groups)) - - def distinct(self, *fields: str | Selectable) -> Self: - """ - Returns documents with distinct combinations of values for the specified - fields or expressions. - - This stage filters the results from previous stages to include only one - document for each unique combination of values in the specified `fields`. - The output documents contain only the fields specified in the `distinct` call. - - Example: - >>> from google.cloud.firestore_v1.pipeline_expressions import Field, to_upper - >>> pipeline = client.collection("books").pipeline() - >>> # Get a list of unique genres (output has only 'genre' field) - >>> pipeline = pipeline.distinct("genre") - >>> # Get unique combinations of author (uppercase) and genre - >>> pipeline = pipeline.distinct( - ... Field.of("author").to_upper().as_("authorUpper"), - ... Field.of("genre") - ... ) - - - Args: - *fields: Field names (str) or `Selectable` expressions to consider when - determining distinct value combinations. The output will only - contain these fields/expressions. - Returns: A new Pipeline object with this stage appended to the stage list """ - return self._append(stages.Distinct(*fields)) + return self._append(stages.Limit(limit)) \ No newline at end of file diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 1eda32713..fd5cb618e 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -15,23 +15,16 @@ from __future__ import annotations from typing import ( Any, - Iterable, - List, - Mapping, - Union, Generic, TypeVar, List, Dict, - Tuple, Sequence, ) from abc import ABC from abc import abstractmethod from enum import Enum -from enum import auto import datetime -from dataclasses import dataclass from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.types.query import StructuredQuery as Query_pb from google.cloud.firestore_v1.vector import Vector @@ -65,7 +58,6 @@ class Direction(Enum): def __init__(self, expr, order_dir: Direction | str = Direction.ASCENDING): """ Initializes an Ordering instance - Args: expr (Expr | str): The expression or field path string to sort by. If a string is provided, it's treated as a field path. 
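For example (an illustrative sketch; the field names are assumptions for
illustration): ``Ordering(Field.of("rating"), Ordering.Direction.DESCENDING)``
sorts by rating from highest to lowest, while ``Ordering("title")`` treats the
bare string as a field path and sorts ascending by default.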
@@ -97,45 +89,6 @@ def _to_pb(self) -> Value: ) -class SampleOptions: - """Options for the 'sample' pipeline stage.""" - - class Mode(Enum): - DOCUMENTS = "documents" - PERCENT = "percent" - - def __init__(self, value: int | float, mode: Mode | str): - self.value = value - self.mode = SampleOptions.Mode[mode.upper()] if isinstance(mode, str) else mode - - def __repr__(self): - if self.mode == SampleOptions.Mode.DOCUMENTS: - mode_str = "doc_limit" - else: - mode_str = "percentage" - return f"SampleOptions.{mode_str}({self.value})" - - @staticmethod - def doc_limit(value: int): - """ - Sample a set number of documents - - Args: - value: number of documents to sample - """ - return SampleOptions(value, mode=SampleOptions.Mode.DOCUMENTS) - - @staticmethod - def percentage(value: float): - """ - Sample a percentage of documents - - Args: - value: percentage of documents to return - """ - return SampleOptions(value, mode=SampleOptions.Mode.PERCENT) - - class Expr(ABC): """Represents an expression that can be evaluated to a value within the execution of a pipeline. @@ -163,146 +116,17 @@ def _to_pb(self) -> Value: def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": return o if isinstance(o, Expr) else Constant(o) - def add(self, other: Expr | float) -> "Add": - """Creates an expression that adds this expression to another expression or constant. - - Example: - >>> # Add the value of the 'quantity' field and the 'reserve' field. - >>> Field.of("quantity").add(Field.of("reserve")) - >>> # Add 5 to the value of the 'age' field - >>> Field.of("age").add(5) - - Args: - other: The expression or constant value to add to this expression. - - Returns: - A new `Expr` representing the addition operation. - """ - return Add(self, self._cast_to_expr_or_convert_to_constant(other)) - - def subtract(self, other: Expr | float) -> "Subtract": - """Creates an expression that subtracts another expression or constant from this expression. - - Example: - >>> # Subtract the 'discount' field from the 'price' field - >>> Field.of("price").subtract(Field.of("discount")) - >>> # Subtract 20 from the value of the 'total' field - >>> Field.of("total").subtract(20) - - Args: - other: The expression or constant value to subtract from this expression. - - Returns: - A new `Expr` representing the subtraction operation. - """ - return Subtract(self, self._cast_to_expr_or_convert_to_constant(other)) - - def multiply(self, other: Expr | float) -> "Multiply": - """Creates an expression that multiplies this expression by another expression or constant. - - Example: - >>> # Multiply the 'quantity' field by the 'price' field - >>> Field.of("quantity").multiply(Field.of("price")) - >>> # Multiply the 'value' field by 2 - >>> Field.of("value").multiply(2) - - Args: - other: The expression or constant value to multiply by. - - Returns: - A new `Expr` representing the multiplication operation. - """ - return Multiply(self, self._cast_to_expr_or_convert_to_constant(other)) - - def divide(self, other: Expr | float) -> "Divide": - """Creates an expression that divides this expression by another expression or constant. - - Example: - >>> # Divide the 'total' field by the 'count' field - >>> Field.of("total").divide(Field.of("count")) - >>> # Divide the 'value' field by 10 - >>> Field.of("value").divide(10) - - Args: - other: The expression or constant value to divide by. - - Returns: - A new `Expr` representing the division operation. 
- """ - return Divide(self, self._cast_to_expr_or_convert_to_constant(other)) - - def mod(self, other: Expr | float) -> "Mod": - """Creates an expression that calculates the modulo (remainder) to another expression or constant. - - Example: - >>> # Calculate the remainder of dividing the 'value' field by field 'divisor'. - >>> Field.of("value").mod(Field.of("divisor")) - >>> # Calculate the remainder of dividing the 'value' field by 5. - >>> Field.of("value").mod(5) - - Args: - other: The divisor expression or constant. - - Returns: - A new `Expr` representing the modulo operation. - """ - return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) - - def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": - """Creates an expression that returns the larger value between this expression - and another expression or constant, based on Firestore's value type ordering. - - Firestore's value type ordering is described here: - https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering - - Example: - >>> # Returns the larger value between the 'discount' field and the 'cap' field. - >>> Field.of("discount").logical_max(Field.of("cap")) - >>> # Returns the larger value between the 'value' field and 10. - >>> Field.of("value").logical_max(10) - - Args: - other: The other expression or constant value to compare with. - - Returns: - A new `Expr` representing the logical max operation. - """ - return LogicalMax(self, self._cast_to_expr_or_convert_to_constant(other)) - - def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": - """Creates an expression that returns the smaller value between this expression - and another expression or constant, based on Firestore's value type ordering. - - Firestore's value type ordering is described here: - https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering - - Example: - >>> # Returns the smaller value between the 'discount' field and the 'floor' field. - >>> Field.of("discount").logical_min(Field.of("floor")) - >>> # Returns the smaller value between the 'value' field and 10. - >>> Field.of("value").logical_min(10) - - Args: - other: The other expression or constant value to compare with. - - Returns: - A new `Expr` representing the logical min operation. - """ - return LogicalMin(self, self._cast_to_expr_or_convert_to_constant(other)) def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": """Creates an expression that checks if this expression is equal to another expression or constant value. - Example: >>> # Check if the 'age' field is equal to 21 >>> Field.of("age").eq(21) >>> # Check if the 'city' field is equal to "London" >>> Field.of("city").eq("London") - Args: other: The expression or constant value to compare for equality. - Returns: A new `Expr` representing the equality comparison. """ @@ -311,16 +135,13 @@ def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": """Creates an expression that checks if this expression is not equal to another expression or constant value. - Example: >>> # Check if the 'status' field is not equal to "completed" >>> Field.of("status").neq("completed") >>> # Check if the 'country' field is not equal to "USA" >>> Field.of("country").neq("USA") - Args: other: The expression or constant value to compare for inequality. - Returns: A new `Expr` representing the inequality comparison. 
""" @@ -329,16 +150,13 @@ def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": """Creates an expression that checks if this expression is greater than another expression or constant value. - Example: >>> # Check if the 'age' field is greater than the 'limit' field >>> Field.of("age").gt(Field.of("limit")) >>> # Check if the 'price' field is greater than 100 >>> Field.of("price").gt(100) - Args: other: The expression or constant value to compare for greater than. - Returns: A new `Expr` representing the greater than comparison. """ @@ -347,16 +165,13 @@ def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": """Creates an expression that checks if this expression is greater than or equal to another expression or constant value. - Example: >>> # Check if the 'quantity' field is greater than or equal to field 'requirement' plus 1 >>> Field.of("quantity").gte(Field.of('requirement').add(1)) >>> # Check if the 'score' field is greater than or equal to 80 >>> Field.of("score").gte(80) - Args: other: The expression or constant value to compare for greater than or equal to. - Returns: A new `Expr` representing the greater than or equal to comparison. """ @@ -365,16 +180,13 @@ def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": """Creates an expression that checks if this expression is less than another expression or constant value. - Example: >>> # Check if the 'age' field is less than 'limit' >>> Field.of("age").lt(Field.of('limit')) >>> # Check if the 'price' field is less than 50 >>> Field.of("price").lt(50) - Args: other: The expression or constant value to compare for less than. - Returns: A new `Expr` representing the less than comparison. """ @@ -383,16 +195,13 @@ def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": """Creates an expression that checks if this expression is less than or equal to another expression or constant value. - Example: >>> # Check if the 'quantity' field is less than or equal to 20 >>> Field.of("quantity").lte(Constant.of(20)) >>> # Check if the 'score' field is less than or equal to 70 >>> Field.of("score").lte(70) - Args: other: The expression or constant value to compare for less than or equal to. - Returns: A new `Expr` representing the less than or equal to comparison. """ @@ -401,14 +210,11 @@ def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": def in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "In": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. - Example: >>> # Check if the 'category' field is either "Electronics" or value of field 'primaryType' >>> Field.of("category").in_any(["Electronics", Field.of("primaryType")]) - Args: array: The values or expressions to check against. - Returns: A new `Expr` representing the 'IN' comparison. """ @@ -417,14 +223,11 @@ def in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "In": def not_in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "Not": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. - Example: >>> # Check if the 'status' field is neither "pending" nor "cancelled" >>> Field.of("status").not_in_any(["pending", "cancelled"]) - Args: *others: The values or expressions to check against. - Returns: A new `Expr` representing the 'NOT IN' comparison. 
""" @@ -432,57 +235,30 @@ def not_in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "Not": def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": """Creates an expression that checks if an array contains a specific element or value. - Example: >>> # Check if the 'sizes' array contains the value from the 'selectedSize' field >>> Field.of("sizes").array_contains(Field.of("selectedSize")) >>> # Check if the 'colors' array contains "red" >>> Field.of("colors").array_contains("red") - Args: element: The element (expression or constant) to search for in the array. - Returns: A new `Expr` representing the 'array_contains' comparison. """ return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) - def array_contains_all( - self, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAll": - """Creates an expression that checks if an array contains all the specified elements. - - Example: - >>> # Check if the 'tags' array contains both "news" and "sports" - >>> Field.of("tags").array_contains_all(["news", "sports"]) - >>> # Check if the 'tags' array contains both of the values from field 'tag1' and "tag2" - >>> Field.of("tags").array_contains_all([Field.of("tag1"), "tag2"]) - - Args: - elements: The list of elements (expressions or constants) to check for in the array. - - Returns: - A new `Expr` representing the 'array_contains_all' comparison. - """ - return ArrayContainsAll( - self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] - ) - def array_contains_any( self, elements: List[Expr | CONSTANT_TYPE] ) -> "ArrayContainsAny": """Creates an expression that checks if an array contains any of the specified elements. - Example: >>> # Check if the 'categories' array contains either values from field "cate1" or "cate2" >>> Field.of("categories").array_contains_any([Field.of("cate1"), Field.of("cate2")]) >>> # Check if the 'groups' array contains either the value from the 'userGroup' field >>> # or the value "guest" >>> Field.of("groups").array_contains_any([Field.of("userGroup"), "guest"]) - Args: elements: The list of elements (expressions or constants) to check for in the array. - Returns: A new `Expr` representing the 'array_contains_any' comparison. """ @@ -490,37 +266,11 @@ def array_contains_any( self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] ) - def array_length(self) -> "ArrayLength": - """Creates an expression that calculates the length of an array. - - Example: - >>> # Get the number of items in the 'cart' array - >>> Field.of("cart").array_length() - - Returns: - A new `Expr` representing the length of the array. - """ - return ArrayLength(self) - - def array_reverse(self) -> "ArrayReverse": - """Creates an expression that returns the reversed content of an array. - - Example: - >>> # Get the 'preferences' array in reversed order. - >>> Field.of("preferences").array_reverse() - - Returns: - A new `Expr` representing the reversed array. - """ - return ArrayReverse(self) - def is_nan(self) -> "IsNaN": """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). - Example: >>> # Check if the result of a calculation is NaN >>> Field.of("value").divide(0).is_nan() - Returns: A new `Expr` representing the 'isNaN' check. """ @@ -528,426 +278,14 @@ def is_nan(self) -> "IsNaN": def exists(self) -> "Exists": """Creates an expression that checks if a field exists in the document. 
- Example: >>> # Check if the document has a field named "phoneNumber" >>> Field.of("phoneNumber").exists() - Returns: A new `Expr` representing the 'exists' check. """ return Exists(self) - def sum(self) -> "Sum": - """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. - - Example: - >>> # Calculate the total revenue from a set of orders - >>> Field.of("orderAmount").sum().as_("totalRevenue") - - Returns: - A new `Accumulator` representing the 'sum' aggregation. - """ - return Sum(self, False) - - def avg(self) -> "Avg": - """Creates an aggregation that calculates the average (mean) of a numeric field across multiple - stage inputs. - - Example: - >>> # Calculate the average age of users - >>> Field.of("age").avg().as_("averageAge") - - Returns: - A new `Accumulator` representing the 'avg' aggregation. - """ - return Avg(self, False) - - def count(self) -> "Count": - """Creates an aggregation that counts the number of stage inputs with valid evaluations of the - expression or field. - - Example: - >>> # Count the total number of products - >>> Field.of("productId").count().as_("totalProducts") - - Returns: - A new `Accumulator` representing the 'count' aggregation. - """ - return Count(self) - - def min(self) -> "Min": - """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. - - Example: - >>> # Find the lowest price of all products - >>> Field.of("price").min().as_("lowestPrice") - - Returns: - A new `Accumulator` representing the 'min' aggregation. - """ - return Min(self, False) - - def max(self) -> "Max": - """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. - - Example: - >>> # Find the highest score in a leaderboard - >>> Field.of("score").max().as_("highestScore") - - Returns: - A new `Accumulator` representing the 'max' aggregation. - """ - return Max(self, False) - - def char_length(self) -> "CharLength": - """Creates an expression that calculates the character length of a string. - - Example: - >>> # Get the character length of the 'name' field - >>> Field.of("name").char_length() - - Returns: - A new `Expr` representing the length of the string. - """ - return CharLength(self) - - def byte_length(self) -> "ByteLength": - """Creates an expression that calculates the byte length of a string in its UTF-8 form. - - Example: - >>> # Get the byte length of the 'name' field - >>> Field.of("name").byte_length() - - Returns: - A new `Expr` representing the byte length of the string. - """ - return ByteLength(self) - - def like(self, pattern: Expr | str) -> "Like": - """Creates an expression that performs a case-sensitive string comparison. - - Example: - >>> # Check if the 'title' field contains the word "guide" (case-sensitive) - >>> Field.of("title").like("%guide%") - >>> # Check if the 'title' field matches the pattern specified in field 'pattern'. - >>> Field.of("title").like(Field.of("pattern")) - - Args: - pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. - - Returns: - A new `Expr` representing the 'like' comparison. - """ - return Like(self, self._cast_to_expr_or_convert_to_constant(pattern)) - - def regex_contains(self, regex: Expr | str) -> "RegexContains": - """Creates an expression that checks if a string contains a specified regular expression as a - substring. 
- - Example: - >>> # Check if the 'description' field contains "example" (case-insensitive) - >>> Field.of("description").regex_contains("(?i)example") - >>> # Check if the 'description' field contains the regular expression stored in field 'regex' - >>> Field.of("description").regex_contains(Field.of("regex")) - - Args: - regex: The regular expression (string or expression) to use for the search. - - Returns: - A new `Expr` representing the 'contains' comparison. - """ - return RegexContains(self, self._cast_to_expr_or_convert_to_constant(regex)) - - def regex_matches(self, regex: Expr | str) -> "RegexMatch": - """Creates an expression that checks if a string matches a specified regular expression. - - Example: - >>> # Check if the 'email' field matches a valid email pattern - >>> Field.of("email").regex_matches("[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") - >>> # Check if the 'email' field matches a regular expression stored in field 'regex' - >>> Field.of("email").regex_matches(Field.of("regex")) - - Args: - regex: The regular expression (string or expression) to use for the match. - - Returns: - A new `Expr` representing the regular expression match. - """ - return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) - - def str_contains(self, substring: Expr | str) -> "StrContains": - """Creates an expression that checks if this string expression contains a specified substring. - - Example: - >>> # Check if the 'description' field contains "example". - >>> Field.of("description").str_contains("example") - >>> # Check if the 'description' field contains the value of the 'keyword' field. - >>> Field.of("description").str_contains(Field.of("keyword")) - - Args: - substring: The substring (string or expression) to use for the search. - - Returns: - A new `Expr` representing the 'contains' comparison. - """ - return StrContains(self, self._cast_to_expr_or_convert_to_constant(substring)) - - def starts_with(self, prefix: Expr | str) -> "StartsWith": - """Creates an expression that checks if a string starts with a given prefix. - - Example: - >>> # Check if the 'name' field starts with "Mr." - >>> Field.of("name").starts_with("Mr.") - >>> # Check if the 'fullName' field starts with the value of the 'firstName' field - >>> Field.of("fullName").starts_with(Field.of("firstName")) - - Args: - prefix: The prefix (string or expression) to check for. - - Returns: - A new `Expr` representing the 'starts with' comparison. - """ - return StartsWith(self, self._cast_to_expr_or_convert_to_constant(prefix)) - - def ends_with(self, postfix: Expr | str) -> "EndsWith": - """Creates an expression that checks if a string ends with a given postfix. - - Example: - >>> # Check if the 'filename' field ends with ".txt" - >>> Field.of("filename").ends_with(".txt") - >>> # Check if the 'url' field ends with the value of the 'extension' field - >>> Field.of("url").ends_with(Field.of("extension")) - - Args: - postfix: The postfix (string or expression) to check for. - - Returns: - A new `Expr` representing the 'ends with' comparison. - """ - return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) - - def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": - """Creates an expression that concatenates string expressions, fields or constants together. 
- - Example: - >>> # Combine the 'firstName', " ", and 'lastName' fields into a single string - >>> Field.of("firstName").str_concat(" ", Field.of("lastName")) - - Args: - *elements: The expressions or constants (typically strings) to concatenate. - - Returns: - A new `Expr` representing the concatenated string. - """ - return StrConcat( - *[self._cast_to_expr_or_convert_to_constant(el) for el in elements] - ) - - def map_get(self, key: str) -> "MapGet": - """Accesses a value from a map (object) field using the provided key. - - Example: - >>> # Get the 'city' value from - >>> # the 'address' map field - >>> Field.of("address").map_get("city") - - Args: - key: The key to access in the map. - - Returns: - A new `Expr` representing the value associated with the given key in the map. - """ - return MapGet(self, Constant.of(key)) - - def vector_length(self) -> "VectorLength": - """Creates an expression that calculates the length (dimension) of a Firestore Vector. - - Example: - >>> # Get the vector length (dimension) of the field 'embedding'. - >>> Field.of("embedding").vector_length() - - Returns: - A new `Expr` representing the length of the vector. - """ - return VectorLength(self) - - def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": - """Creates an expression that converts a timestamp to the number of microseconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the microsecond. - - Example: - >>> # Convert the 'timestamp' field to microseconds since the epoch. - >>> Field.of("timestamp").timestamp_to_unix_micros() - - Returns: - A new `Expr` representing the number of microseconds since the epoch. - """ - return TimestampToUnixMicros(self) - - def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": - """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 - 00:00:00 UTC) to a timestamp. - - Example: - >>> # Convert the 'microseconds' field to a timestamp. - >>> Field.of("microseconds").unix_micros_to_timestamp() - - Returns: - A new `Expr` representing the timestamp. - """ - return UnixMicrosToTimestamp(self) - - def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": - """Creates an expression that converts a timestamp to the number of milliseconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the millisecond. - - Example: - >>> # Convert the 'timestamp' field to milliseconds since the epoch. - >>> Field.of("timestamp").timestamp_to_unix_millis() - - Returns: - A new `Expr` representing the number of milliseconds since the epoch. - """ - return TimestampToUnixMillis(self) - - def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": - """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 - 00:00:00 UTC) to a timestamp. - - Example: - >>> # Convert the 'milliseconds' field to a timestamp. - >>> Field.of("milliseconds").unix_millis_to_timestamp() - - Returns: - A new `Expr` representing the timestamp. - """ - return UnixMillisToTimestamp(self) - - def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": - """Creates an expression that converts a timestamp to the number of seconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the second. - - Example: - >>> # Convert the 'timestamp' field to seconds since the epoch. 
- >>> Field.of("timestamp").timestamp_to_unix_seconds() - - Returns: - A new `Expr` representing the number of seconds since the epoch. - """ - return TimestampToUnixSeconds(self) - - def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": - """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 - UTC) to a timestamp. - - Example: - >>> # Convert the 'seconds' field to a timestamp. - >>> Field.of("seconds").unix_seconds_to_timestamp() - - Returns: - A new `Expr` representing the timestamp. - """ - return UnixSecondsToTimestamp(self) - - def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd": - """Creates an expression that adds a specified amount of time to this timestamp expression. - - Example: - >>> # Add a duration specified by the 'unit' and 'amount' fields to the 'timestamp' field. - >>> Field.of("timestamp").timestamp_add(Field.of("unit"), Field.of("amount")) - >>> # Add 1.5 days to the 'timestamp' field. - >>> Field.of("timestamp").timestamp_add("day", 1.5) - - Args: - unit: The expression or string evaluating to the unit of time to add, must be one of - 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. - amount: The expression or float representing the amount of time to add. - - Returns: - A new `Expr` representing the resulting timestamp. - """ - return TimestampAdd( - self, - self._cast_to_expr_or_convert_to_constant(unit), - self._cast_to_expr_or_convert_to_constant(amount), - ) - - def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub": - """Creates an expression that subtracts a specified amount of time from this timestamp expression. - - Example: - >>> # Subtract a duration specified by the 'unit' and 'amount' fields from the 'timestamp' field. - >>> Field.of("timestamp").timestamp_sub(Field.of("unit"), Field.of("amount")) - >>> # Subtract 2.5 hours from the 'timestamp' field. - >>> Field.of("timestamp").timestamp_sub("hour", 2.5) - - Args: - unit: The expression or string evaluating to the unit of time to subtract, must be one of - 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. - amount: The expression or float representing the amount of time to subtract. - - Returns: - A new `Expr` representing the resulting timestamp. - """ - return TimestampSub( - self, - self._cast_to_expr_or_convert_to_constant(unit), - self._cast_to_expr_or_convert_to_constant(amount), - ) - - def ascending(self) -> Ordering: - """Creates an `Ordering` that sorts documents in ascending order based on this expression. - - Example: - >>> # Sort documents by the 'name' field in ascending order - >>> firestore.pipeline().collection("users").sort(Field.of("name").ascending()) - - Returns: - A new `Ordering` for ascending sorting. - """ - return Ordering(self, Ordering.Direction.ASCENDING) - - def descending(self) -> Ordering: - """Creates an `Ordering` that sorts documents in descending order based on this expression. - - Example: - >>> # Sort documents by the 'createdAt' field in descending order - >>> firestore.pipeline().collection("users").sort(Field.of("createdAt").descending()) - - Returns: - A new `Ordering` for descending sorting. - """ - return Ordering(self, Ordering.Direction.DESCENDING) - - def as_(self, alias: str) -> "ExprWithAlias": - """Assigns an alias to this expression. - - Aliases are useful for renaming fields in the output of a stage or for giving meaningful - names to calculated values. 
- - Example: - >>> # Calculate the total price and assign it the alias "totalPrice" and add it to the output. - >>> firestore.pipeline().collection("items").add_fields( - ... Field.of("price").multiply(Field.of("quantity")).as_("totalPrice") - ... ) - - Args: - alias: The alias to assign to this expression. - - Returns: - A new `Selectable` (typically an `ExprWithAlias`) that wraps this - expression and associates it with the provided alias. - """ - return ExprWithAlias(self, alias) - - class Constant(Expr, Generic[CONSTANT_TYPE]): """Represents a constant literal value in an expression.""" @@ -998,241 +336,6 @@ def _to_pb(self): ) -class Divide(Function): - """Represents the division function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("divide", [left, right]) - - -class LogicalMax(Function): - """Represents the logical maximum function based on Firestore type ordering.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("logical_maximum", [left, right]) - - -class LogicalMin(Function): - """Represents the logical minimum function based on Firestore type ordering.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("logical_minimum", [left, right]) - - -class MapGet(Function): - """Represents accessing a value within a map by key.""" - - def __init__(self, map_: Expr, key: Constant[str]): - super().__init__("map_get", [map_, key]) - - -class Mod(Function): - """Represents the modulo function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("mod", [left, right]) - - -class Multiply(Function): - """Represents the multiplication function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("multiply", [left, right]) - - -class Parent(Function): - """Represents getting the parent document reference.""" - - def __init__(self, value: Expr): - super().__init__("parent", [value]) - - -class StrConcat(Function): - """Represents concatenating multiple strings.""" - - def __init__(self, *exprs: Expr): - super().__init__("str_concat", exprs) - - -class Subtract(Function): - """Represents the subtraction function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("subtract", [left, right]) - - -class TimestampAdd(Function): - """Represents adding a duration to a timestamp.""" - - def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): - super().__init__("timestamp_add", [timestamp, unit, amount]) - - -class TimestampSub(Function): - """Represents subtracting a duration from a timestamp.""" - - def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): - super().__init__("timestamp_sub", [timestamp, unit, amount]) - - -class TimestampToUnixMicros(Function): - """Represents converting a timestamp to microseconds since epoch.""" - - def __init__(self, input: Expr): - super().__init__("timestamp_to_unix_micros", [input]) - - -class TimestampToUnixMillis(Function): - """Represents converting a timestamp to milliseconds since epoch.""" - - def __init__(self, input: Expr): - super().__init__("timestamp_to_unix_millis", [input]) - - -class TimestampToUnixSeconds(Function): - """Represents converting a timestamp to seconds since epoch.""" - - def __init__(self, input: Expr): - super().__init__("timestamp_to_unix_seconds", [input]) - - -class UnixMicrosToTimestamp(Function): - """Represents converting microseconds since epoch to a timestamp.""" - - def __init__(self, input: Expr): - super().__init__("unix_micros_to_timestamp", [input]) - - -class 
UnixMillisToTimestamp(Function): - """Represents converting milliseconds since epoch to a timestamp.""" - - def __init__(self, input: Expr): - super().__init__("unix_millis_to_timestamp", [input]) - - -class UnixSecondsToTimestamp(Function): - """Represents converting seconds since epoch to a timestamp.""" - - def __init__(self, input: Expr): - super().__init__("unix_seconds_to_timestamp", [input]) - - -class VectorLength(Function): - """Represents getting the length (dimension) of a vector.""" - - def __init__(self, array: Expr): - super().__init__("vector_length", [array]) - - -class Add(Function): - """Represents the addition function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("add", [left, right]) - - -class ArrayElement(Function): - """Represents accessing an element within an array""" - - def __init__(self): - super().__init__("array_element", []) - - -class ArrayFilter(Function): - """Represents filtering elements from an array based on a condition.""" - - def __init__(self, array: Expr, filter: "FilterCondition"): - super().__init__("array_filter", [array, filter]) - - -class ArrayLength(Function): - """Represents getting the length of an array.""" - - def __init__(self, array: Expr): - super().__init__("array_length", [array]) - - -class ArrayReverse(Function): - """Represents reversing the elements of an array.""" - - def __init__(self, array: Expr): - super().__init__("array_reverse", [array]) - - -class ArrayTransform(Function): - """Represents applying a transformation function to each element of an array.""" - - def __init__(self, array: Expr, transform: Function): - super().__init__("array_transform", [array, transform]) - - -class ByteLength(Function): - """Represents getting the byte length of a string (UTF-8).""" - - def __init__(self, expr: Expr): - super().__init__("byte_length", [expr]) - - -class CharLength(Function): - """Represents getting the character length of a string.""" - - def __init__(self, expr: Expr): - super().__init__("char_length", [expr]) - - -class CollectionId(Function): - """Represents getting the collection ID from a document reference.""" - - def __init__(self, value: Expr): - super().__init__("collection_id", [value]) - - -class Accumulator(Function): - """A base class for aggregation functions that operate across multiple inputs.""" - - -class Max(Accumulator): - """Represents the maximum aggregation function.""" - - def __init__(self, value: Expr, distinct: bool = False): - super().__init__("maximum", [value]) - - -class Min(Accumulator): - """Represents the minimum aggregation function.""" - - def __init__(self, value: Expr, distinct: bool = False): - super().__init__("minimum", [value]) - - -class Sum(Accumulator): - """Represents the sum aggregation function.""" - - def __init__(self, value: Expr, distinct: bool = False): - super().__init__("sum", [value]) - - -class Avg(Accumulator): - """Represents the average aggregation function.""" - - def __init__(self, value: Expr, distinct: bool = False): - super().__init__("avg", [value]) - - -class Count(Accumulator): - """Represents the count aggregation function.""" - - def __init__(self, value: Expr | None = None): - super().__init__("count", [value] if value else []) - - -class CountIf(Function): - """Represents counting inputs where a condition is true (likely used internally or planned).""" - - def __init__(self, value: Expr, distinct: bool = False): - super().__init__("countif", [value] if value else []) - - class Selectable(Expr): """Base class for expressions 
that can be selected or aliased in projection stages.""" @@ -1241,26 +344,6 @@ def _to_map(self): raise NotImplementedError -T = TypeVar("T", bound=Expr) - - -class ExprWithAlias(Selectable, Generic[T]): - """Wraps an expression with an alias.""" - - def __init__(self, expr: T, alias: str): - self.expr = expr - self.alias = alias - - def _to_map(self): - return self.alias, self.expr._to_pb() - - def __repr__(self): - return f"{self.expr}.as_('{self.alias}')" - - def _to_pb(self): - return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) - - class Field(Selectable): """Represents a reference to a field within a document.""" @@ -1268,7 +351,6 @@ class Field(Selectable): def __init__(self, path: str): """Initializes a Field reference. - Args: path: The dot-separated path to the field (e.g., "address.city"). Use Field.DOCUMENT_ID for the document ID. @@ -1278,11 +360,9 @@ def __init__(self, path: str): @staticmethod def of(path: str): """Creates a Field reference. - Args: path: The dot-separated path to the field (e.g., "address.city"). Use Field.DOCUMENT_ID for the document ID. - Returns: A new Field instance. """ @@ -1369,35 +449,18 @@ class And(FilterCondition): def __init__(self, *conditions: "FilterCondition"): super().__init__("and", conditions) - class ArrayContains(FilterCondition): def __init__(self, array: Expr, element: Expr): super().__init__( "array_contains", [array, element if element else Constant(None)] ) - -class ArrayContainsAll(FilterCondition): - """Represents checking if an array contains all specified elements.""" - - def __init__(self, array: Expr, elements: List[Expr]): - super().__init__("array_contains_all", [array, ListOfExprs(elements)]) - - class ArrayContainsAny(FilterCondition): """Represents checking if an array contains any of the specified elements.""" def __init__(self, array: Expr, elements: List[Expr]): super().__init__("array_contains_any", [array, ListOfExprs(elements)]) - -class EndsWith(FilterCondition): - """Represents checking if a string ends with a specific postfix.""" - - def __init__(self, expr: Expr, postfix: Expr): - super().__init__("ends_with", [expr, postfix]) - - class Eq(FilterCondition): """Represents the equality comparison.""" @@ -1426,15 +489,6 @@ def __init__(self, left: Expr, right: Expr): super().__init__("gte", [left, right if right else Constant(None)]) -class If(FilterCondition): - """Represents a conditional expression (if-then-else).""" - - def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Expr): - super().__init__( - "if", [condition, true_expr, false_expr if false_expr else Constant(None)] - ) - - class In(FilterCondition): """Represents checking if an expression's value is within a list of values.""" @@ -1448,14 +502,6 @@ class IsNaN(FilterCondition): def __init__(self, value: Expr): super().__init__("is_nan", [value]) - -class Like(FilterCondition): - """Represents a case-sensitive wildcard string comparison.""" - - def __init__(self, expr: Expr, pattern: Expr): - super().__init__("like", [expr, pattern]) - - class Lt(FilterCondition): """Represents the less than comparison.""" @@ -1476,7 +522,6 @@ class Neq(FilterCondition): def __init__(self, left: Expr, right: Expr): super().__init__("neq", [left, right if right else Constant(None)]) - class Not(FilterCondition): """Represents the logical NOT of a filter condition.""" @@ -1488,39 +533,4 @@ class Or(FilterCondition): """Represents the logical OR of multiple filter conditions.""" def __init__(self, *conditions: "FilterCondition"): - 
super().__init__("or", conditions) - - -class RegexContains(FilterCondition): - """Represents checking if a string contains a substring matching a regex.""" - - def __init__(self, expr: Expr, regex: Expr): - super().__init__("regex_contains", [expr, regex]) - - -class RegexMatch(FilterCondition): - """Represents checking if a string fully matches a regex.""" - - def __init__(self, expr: Expr, regex: Expr): - super().__init__("regex_match", [expr, regex]) - - -class StartsWith(FilterCondition): - """Represents checking if a string starts with a specific prefix.""" - - def __init__(self, expr: Expr, prefix: Expr): - super().__init__("starts_with", [expr, prefix]) - - -class StrContains(FilterCondition): - """Represents checking if a string contains a specific substring.""" - - def __init__(self, expr: Expr, substring: Expr): - super().__init__("str_contains", [expr, substring]) - - -class Xor(FilterCondition): - """Represents the logical XOR of multiple filter conditions.""" - - def __init__(self, conditions: List["FilterCondition"]): - super().__init__("xor", conditions) + super().__init__("or", conditions) \ No newline at end of file diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 686aaf2a0..877b1bf1e 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -13,62 +13,20 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING +from typing import Optional from abc import ABC from abc import abstractmethod -from enum import Enum -from enum import auto from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb from google.cloud.firestore_v1.types.document import Value -from google.cloud.firestore_v1.document import DocumentReference -from google.cloud.firestore_v1.vector import Vector -from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.pipeline_expressions import ( - Accumulator, Expr, - ExprWithAlias, Field, FilterCondition, Selectable, - SampleOptions, Ordering, ) -if TYPE_CHECKING: - from google.cloud.firestore_v1.pipeline import Pipeline - - -class FindNearestOptions: - """Options for configuring the `FindNearest` pipeline stage. - - Attributes: - limit (Optional[int]): The maximum number of nearest neighbors to return. - distance_field (Optional[Field]): An optional field to store the calculated - distance in the output documents. - """ - - def __init__( - self, - limit: Optional[int] = None, - distance_field: Optional[Field] = None, - ): - self.limit = limit - self.distance_field = distance_field - - -class UnnestOptions: - """Options for configuring the `Unnest` pipeline stage. - - Attributes: - index_field (str): The name of the field to add to each output document, - storing the original 0-based index of the element within the array. - """ - - def __init__(self, index_field: str): - self.index_field = index_field - - class Stage(ABC): """Base class for all pipeline stages. 
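
For orientation between these hunks: the stage stubs that survive below all follow the same contract as `Stage` — a snake_case stage name plus a list of `Value` protos produced by `_pb_args()`. A minimal, illustrative sketch of how the surviving pieces compose (not part of the patch; the `Collection("books")` constructor argument and the availability of `Field.of(...).gt(...)` at this point in the series are assumptions based on the surrounding diffs):

    from google.cloud.firestore_v1 import pipeline_stages as stages
    from google.cloud.firestore_v1.pipeline_expressions import Field

    # A data source stage plus a filter stage, mirroring the Collection and
    # Where stubs kept in this file; each stage serializes via _pb_args().
    source = stages.Collection("books")              # assumed constructor argument
    flt = stages.Where(Field.of("rating").gt(4.0))   # Where wraps a FilterCondition
    print(repr(source), repr(flt))                   # rendered by Stage.__repr__
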
@@ -99,67 +57,6 @@ def __repr__(self): return f"{self.__class__.__name__}({', '.join(items)})" -class AddFields(Stage): - """Adds new fields to outputs from previous stages.""" - - def __init__(self, *fields: Selectable): - super().__init__("add_fields") - self.fields = list(fields) - - def _pb_args(self): - return [ - Value( - map_value={ - "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]} - } - ) - ] - - -class Aggregate(Stage): - """Performs aggregation operations, optionally grouped.""" - - def __init__( - self, - *extra_accumulators: ExprWithAlias[Accumulator], - accumulators: Sequence[ExprWithAlias[Accumulator]] = (), - groups: Sequence[str | Selectable] = (), - ): - super().__init__() - self.groups: list[Selectable] = [ - Field(f) if isinstance(f, str) else f for f in groups - ] - self.accumulators: list[ExprWithAlias[Accumulator]] = [ - *accumulators, - *extra_accumulators, - ] - - def _pb_args(self): - return [ - Value( - map_value={ - "fields": { - m[0]: m[1] for m in [f._to_map() for f in self.accumulators] - } - } - ), - Value( - map_value={ - "fields": {m[0]: m[1] for m in [f._to_map() for f in self.groups]} - } - ), - ] - - def __repr__(self): - accumulator_str = ", ".join(repr(v) for v in self.accumulators) - group_str = "" - if self.groups: - if self.accumulators: - group_str = ", " - group_str += f"groups={self.groups}" - return f"{self.__class__.__name__}({accumulator_str}{group_str})" - - class Collection(Stage): """Specifies a collection as the initial data source.""" @@ -184,84 +81,6 @@ def _pb_args(self): return [Value(string_value=self.collection_id)] -class Database(Stage): - """Specifies the default database as the initial data source.""" - - def __init__(self): - super().__init__() - - -class Distinct(Stage): - """Returns documents with distinct combinations of specified field values.""" - - def __init__(self, *fields: str | Selectable): - super().__init__() - self.fields: list[Selectable] = [ - Field(f) if isinstance(f, str) else f for f in fields - ] - - def _pb_args(self) -> list[Value]: - return [ - Value( - map_value={ - "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]} - } - ) - ] - - -class Documents(Stage): - """Specifies specific documents as the initial data source.""" - - def __init__(self, *paths: str): - super().__init__() - self.paths = paths - - @staticmethod - def of(*documents: "DocumentReference") -> "Documents": - doc_paths = ["/" + doc.path for doc in documents] - return Documents(*doc_paths) - - def _pb_args(self): - return [ - Value( - list_value={"values": [Value(string_value=path) for path in self.paths]} - ) - ] - - -class FindNearest(Stage): - """Performs vector distance (similarity) search.""" - - def __init__( - self, - field: str | Expr, - vector: Sequence[float] | Vector, - distance_measure: "DistanceMeasure", - options: Optional["FindNearestOptions"] = None, - ): - super().__init__("find_nearest") - self.field: Expr = Field(field) if isinstance(field, str) else field - self.vector: Vector = vector if isinstance(vector, Vector) else Vector(vector) - self.distance_measure = distance_measure - self.options = options or FindNearestOptions() - - def _pb_args(self): - return [ - self.field._to_pb(), - Value(array_value={"values": self.vector}), - Value(string_value=self.distance_measure.value), - ] - - def _pb_options(self) -> dict[str, Value]: - options = {} - if self.options and self.options.limit is not None: - options["limit"] = Value(integer_value=self.options.limit) - if self.options and 
self.options.distance_field is not None: - options["distance_field"] = self.options.distance_field._to_pb() - return options - - class GenericStage(Stage): """Represents a generic, named stage with parameters.""" @@ -297,41 +116,6 @@ def _pb_args(self): return [Value(integer_value=self.offset)] -class RemoveFields(Stage): - """Removes specified fields from outputs.""" - - def __init__(self, *fields: str | Field): - super().__init__("remove_fields") - self.fields = [Field(f) if isinstance(f, str) else f for f in fields] - - def _pb_args(self) -> list[Value]: - return [f._to_pb() for f in self.fields] - - -class Sample(Stage): - """Performs pseudo-random sampling of documents.""" - - def __init__(self, limit_or_options: int | SampleOptions): - super().__init__() - if isinstance(limit_or_options, int): - options = SampleOptions.doc_limit(limit_or_options) - else: - options = limit_or_options - self.options: SampleOptions = options - - def _pb_args(self): - if self.options.mode == SampleOptions.Mode.DOCUMENTS: - return [ - Value(integer_value=self.options.value), - Value(string_value="documents"), - ] - else: - return [ - Value(double_value=self.options.value), - Value(string_value="percent"), - ] - - class Select(Stage): """Selects or creates a set of fields.""" @@ -361,47 +145,6 @@ def __init__(self, *orders: "Ordering"): def _pb_args(self): return [o._to_pb() for o in self.orders] - -class Union(Stage): - """Performs a union of documents from two pipelines.""" - - def __init__(self, other: Pipeline): - super().__init__() - self.other = other - - def _pb_args(self): - return [Value(pipeline_value=self.other._to_pb().pipeline)] - - -class Unnest(Stage): - """Produces a document for each element in an array field.""" - - def __init__( - self, - field: Selectable | str, - alias: Field | str | None = None, - options: UnnestOptions | None = None, - ): - super().__init__() - self.field: Selectable = Field(field) if isinstance(field, str) else field - if alias is None: - self.alias = self.field - elif isinstance(alias, str): - self.alias = Field(alias) - else: - self.alias = alias - self.options = options - - def _pb_args(self): - return [self.field._to_pb(), self.alias._to_pb()] - - def _pb_options(self): - options = {} - if self.options is not None: - options["index_field"] = Value(string_value=self.options.index_field) - return options - - class Where(Stage): """Filters documents based on a specified condition.""" @@ -410,4 +153,4 @@ def __init__(self, condition: FilterCondition): self.condition = condition def _pb_args(self): - return [self.condition._to_pb()] + return [self.condition._to_pb()] \ No newline at end of file diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index dc262f4a9..d92bbb316 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -126,207 +126,6 @@ data: hugo: true nebula: true tests: - - description: "testAggregates - count" - pipeline: - - Collection: books - - Aggregate: - - ExprWithAlias: - - Count - - "count" - assert_results: - - count: 10 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - count: - functionValue: - name: count - - mapValue: {} - name: aggregate - - description: "testAggregates - avg, count, max" - pipeline: - - Collection: books - - Where: - - Eq: - - Field: genre - - Constant: Science Fiction - - Aggregate: - - ExprWithAlias: - - Count - - "count" - - ExprWithAlias: - - Avg: - - Field: rating - - "avg_rating" - - 
ExprWithAlias: - - Max: - - Field: rating - - "max_rating" - assert_results: - - count: 2 - avg_rating: 4.4 - max_rating: 4.6 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Science Fiction - name: eq - name: where - - args: - - mapValue: - fields: - avg_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: avg - count: - functionValue: - name: count - max_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: maximum - - mapValue: {} - name: aggregate - - description: testGroupBysWithoutAccumulators - pipeline: - - Collection: books - - Where: - - Lt: - - Field: published - - Constant: 1900 - - Aggregate: - accumulators: [] - groups: [genre] - assert_error: ".* requires at least one accumulator" - - description: testGroupBysAndAggregate - pipeline: - - Collection: books - - Where: - - Lt: - - Field: published - - Constant: 1984 - - Aggregate: - accumulators: - - ExprWithAlias: - - Avg: - - Field: rating - - "avg_rating" - groups: [genre] - - Where: - - Gt: - - Field: avg_rating - - Constant: 4.3 - - Sort: - - Ordering: - - Field: avg_rating - - ASCENDING - assert_results: - - avg_rating: 4.4 - genre: Science Fiction - - avg_rating: 4.5 - genre: Romance - - avg_rating: 4.7 - genre: Fantasy - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1984' - name: lt - name: where - - args: - - mapValue: - fields: - avg_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: avg - - mapValue: - fields: - genre: - fieldReferenceValue: genre - name: aggregate - - args: - - functionValue: - args: - - fieldReferenceValue: avg_rating - - doubleValue: 4.3 - name: gt - name: where - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: avg_rating - name: sort - - description: testMinMax - pipeline: - - Collection: books - - Aggregate: - - ExprWithAlias: - - Count - - "count" - - ExprWithAlias: - - Max: - - Field: rating - - "max_rating" - - ExprWithAlias: - - Min: - - Field: published - - "min_published" - assert_results: - - count: 10 - max_rating: 4.7 - min_published: 1813 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - count: - functionValue: - name: count - max_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: maximum - min_published: - functionValue: - args: - - fieldReferenceValue: published - name: minimum - - mapValue: {} - name: aggregate - description: selectSpecificFields pipeline: - Collection: books @@ -380,98 +179,6 @@ tests: expression: fieldReferenceValue: author name: sort - - description: addAndRemoveFields - pipeline: - - Collection: books - - AddFields: - - ExprWithAlias: - - StrConcat: - - Field: author - - Constant: _ - - Field: title - - "author_title" - - ExprWithAlias: - - StrConcat: - - Field: title - - Constant: _ - - Field: author - - "title_author" - - RemoveFields: - - title_author - - tags - - awards - - rating - - title - - Field: published - - Field: genre - - Field: nestedField # Field does not exist, should be ignored - - Sort: - - Ordering: - - Field: author_title - - ASCENDING - assert_results: - - author: Douglas Adams - author_title: Douglas Adams_The Hitchhiker's Guide to the Galaxy - - author: F. 
Scott Fitzgerald - author_title: F. Scott Fitzgerald_The Great Gatsby - - author: Frank Herbert - author_title: Frank Herbert_Dune - - author: Fyodor Dostoevsky - author_title: Fyodor Dostoevsky_Crime and Punishment - - author: Gabriel García Márquez - author_title: Gabriel García Márquez_One Hundred Years of Solitude - - author: George Orwell - author_title: George Orwell_1984 - - author: Harper Lee - author_title: Harper Lee_To Kill a Mockingbird - - author: J.R.R. Tolkien - author_title: J.R.R. Tolkien_The Lord of the Rings - - author: Jane Austen - author_title: Jane Austen_Pride and Prejudice - - author: Margaret Atwood - author_title: Margaret Atwood_The Handmaid's Tale - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - author_title: - functionValue: - args: - - fieldReferenceValue: author - - stringValue: _ - - fieldReferenceValue: title - name: str_concat - title_author: - functionValue: - args: - - fieldReferenceValue: title - - stringValue: _ - - fieldReferenceValue: author - name: str_concat - name: add_fields - - args: - - fieldReferenceValue: title_author - - fieldReferenceValue: tags - - fieldReferenceValue: awards - - fieldReferenceValue: rating - - fieldReferenceValue: title - - fieldReferenceValue: published - - fieldReferenceValue: genre - - fieldReferenceValue: nestedField - name: remove_fields - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: author_title - name: sort - description: whereByMultipleConditions pipeline: - Collection: books @@ -697,18 +404,34 @@ tests: expression: fieldReferenceValue: title name: sort - - description: testArrayContainsAll + - description: testComparisonOperators pipeline: - Collection: books - Where: - - ArrayContainsAll: - - Field: tags - - - Constant: adventure - - Constant: magic + - And: + - Gt: + - Field: rating + - Constant: 4.2 + - Lte: + - Field: rating + - Constant: 4.5 + - Neq: + - Field: genre + - Constant: Science Fiction - Select: + - rating - title + - Sort: + - Ordering: + - title + - ASCENDING assert_results: - - title: The Lord of the Rings + - rating: 4.3 + title: Crime and Punishment + - rating: 4.3 + title: One Hundred Years of Solitude + - rating: 4.5 + title: Pride and Prejudice assert_proto: pipeline: stages: @@ -718,127 +441,64 @@ tests: - args: - functionValue: args: - - fieldReferenceValue: tags - - arrayValue: - values: - - stringValue: adventure - - stringValue: magic - name: array_contains_all + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.2 + name: gt + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: lte + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: neq + name: and name: where - args: - mapValue: fields: + rating: + fieldReferenceValue: rating title: fieldReferenceValue: title name: select - - description: testArrayLength - pipeline: - - Collection: books - - Select: - - ExprWithAlias: - - ArrayLength: - - Field: tags - - "tagsCount" - - Where: - - Eq: - - Field: tagsCount - - Constant: 3 - assert_results: # All documents have 3 tags - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - tagsCount: - functionValue: 
- args: - - fieldReferenceValue: tags - name: array_length - name: select - - args: - - functionValue: - args: - - fieldReferenceValue: tagsCount - - integerValue: '3' - name: eq - name: where - - description: testStrConcat - pipeline: - - Collection: books - - Sort: - - Ordering: - - Field: author - - ASCENDING - - Select: - - ExprWithAlias: - - StrConcat: - - Field: author - - Constant: " - " - - Field: title - - "bookInfo" - - Limit: 1 - assert_results: - - bookInfo: Douglas Adams - The Hitchhiker's Guide to the Galaxy - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - args: - mapValue: fields: direction: stringValue: ascending expression: - fieldReferenceValue: author + fieldReferenceValue: title name: sort - - args: - - mapValue: - fields: - bookInfo: - functionValue: - args: - - fieldReferenceValue: author - - stringValue: ' - ' - - fieldReferenceValue: title - name: str_concat - name: select - - args: - - integerValue: '1' - name: limit - - description: testStartsWith + - description: testLogicalOperators pipeline: - Collection: books - Where: - - StartsWith: - - Field: title - - Constant: The + - Or: + - And: + - Gt: + - Field: rating + - Constant: 4.5 + - Eq: + - Field: genre + - Constant: Science Fiction + - Lt: + - Field: published + - Constant: 1900 - Select: - - title + - title - Sort: - Ordering: - - Field: title - - ASCENDING + - Field: title + - ASCENDING assert_results: - - title: The Great Gatsby - - title: The Handmaid's Tale - - title: The Hitchhiker's Guide to the Galaxy - - title: The Lord of the Rings + - title: Crime and Punishment + - title: Dune + - title: Pride and Prejudice assert_proto: pipeline: stages: @@ -848,9 +508,25 @@ tests: - args: - functionValue: args: - - fieldReferenceValue: title - - stringValue: The - name: starts_with + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: gt + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: eq + name: and + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: lt + name: or name: where - args: - mapValue: @@ -866,22 +542,25 @@ tests: expression: fieldReferenceValue: title name: sort - - description: testEndsWith + - description: testNestedFields pipeline: - Collection: books - Where: - - EndsWith: - - Field: title - - Constant: y - - Select: - - title + - Eq: + - Field: awards.hugo + - Constant: true - Sort: - - Ordering: - - Field: title - - DESCENDING + - Ordering: + - Field: title + - DESCENDING + - Select: + - title + - Field: awards.hugo assert_results: - title: The Hitchhiker's Guide to the Galaxy - - title: The Great Gatsby + awards.hugo: true + - title: Dune + awards.hugo: true assert_proto: pipeline: stages: @@ -891,646 +570,9 @@ tests: - args: - functionValue: args: - - fieldReferenceValue: title - - stringValue: y - name: ends_with - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: descending - expression: - fieldReferenceValue: title - name: sort - - description: testLength - pipeline: - - Collection: books - - Select: - - ExprWithAlias: - - CharLength: - - Field: title - - "titleLength" - - title - - Where: - - Gt: - - Field: titleLength - - Constant: 20 - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - titleLength: 29 - title: One Hundred Years of Solitude - - titleLength: 
36 - title: The Hitchhiker's Guide to the Galaxy - - titleLength: 21 - title: The Lord of the Rings - - titleLength: 21 - title: To Kill a Mockingbird - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - titleLength: - functionValue: - args: - - fieldReferenceValue: title - name: char_length - name: select - - args: - - functionValue: - args: - - fieldReferenceValue: titleLength - - integerValue: '20' - name: gt - name: where - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testStringFunctions - CharLength - pipeline: - - Collection: books - - Where: - - Eq: - - Field: author - - Constant: "Douglas Adams" - - Select: - - ExprWithAlias: - - CharLength: - - Field: title - - "title_length" - assert_results: - - title_length: 36 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: eq - name: where - - args: - - mapValue: - fields: - title_length: - functionValue: - args: - - fieldReferenceValue: title - name: char_length - name: select - - description: testStringFunctions - ByteLength - pipeline: - - Collection: books - - Where: - - Eq: - - Field: author - - Constant: Douglas Adams - - Select: - - ExprWithAlias: - - ByteLength: - - StrConcat: - - Field: title - - Constant: _银河系漫游指南 - - "title_byte_length" - assert_results: - - title_byte_length: 58 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: eq - name: where - - args: - - mapValue: - fields: - title_byte_length: - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: "_\u94F6\u6CB3\u7CFB\u6F2B\u6E38\u6307\u5357" - name: str_concat - name: byte_length - name: select - - description: testLike - pipeline: - - Collection: books - - Where: - - Like: - - Field: title - - Constant: "%Guide%" - - Select: - - title - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - - description: testRegexContains - # Find titles that contain either "the" or "of" (case-insensitive) - pipeline: - - Collection: books - - Where: - - RegexContains: - - Field: title - - Constant: "(?i)(the|of)" - assert_count: 5 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: "(?i)(the|of)" - name: regex_contains - name: where - - description: testRegexMatches - # Find titles that contain either "the" or "of" (case-insensitive) - pipeline: - - Collection: books - - Where: - - RegexMatch: - - Field: title - - Constant: ".*(?i)(the|of).*" - assert_count: 5 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: ".*(?i)(the|of).*" - name: regex_match - name: where - - description: testArithmeticOperations - pipeline: - - Collection: books - - Where: - - Eq: - - Field: title - - Constant: To Kill a Mockingbird - - Select: - - ExprWithAlias: - - Add: - - Field: rating - - Constant: 1 - - "ratingPlusOne" - - ExprWithAlias: - - Subtract: - - Field: published - - Constant: 1900 - - 
"yearsSince1900" - - ExprWithAlias: - - Multiply: - - Field: rating - - Constant: 10 - - "ratingTimesTen" - - ExprWithAlias: - - Divide: - - Field: rating - - Constant: 2 - - "ratingDividedByTwo" - - ExprWithAlias: - - Multiply: - - Field: rating - - Constant: 20 - - "ratingTimes20" - - ExprWithAlias: - - Add: - - Field: rating - - Constant: 3 - - "ratingPlus3" - - ExprWithAlias: - - Mod: - - Field: rating - - Constant: 2 - - "ratingMod2" - assert_results: - - ratingPlusOne: 5.2 - yearsSince1900: 60 - ratingTimesTen: 42.0 - ratingDividedByTwo: 2.1 - ratingTimes20: 84 - ratingPlus3: 7.2 - ratingMod2: 0.20000000000000018 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: To Kill a Mockingbird - name: eq - name: where - - args: - - mapValue: - fields: - ratingDividedByTwo: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '2' - name: divide - ratingPlusOne: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '1' - name: add - ratingTimesTen: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '10' - name: multiply - yearsSince1900: - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1900' - name: subtract - ratingTimes20: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '20' - name: multiply - ratingPlus3: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '3' - name: add - ratingMod2: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '2' - name: mod - name: select - - description: testComparisonOperators - pipeline: - - Collection: books - - Where: - - And: - - Gt: - - Field: rating - - Constant: 4.2 - - Lte: - - Field: rating - - Constant: 4.5 - - Neq: - - Field: genre - - Constant: Science Fiction - - Select: - - rating - - title - - Sort: - - Ordering: - - title - - ASCENDING - assert_results: - - rating: 4.3 - title: Crime and Punishment - - rating: 4.3 - title: One Hundred Years of Solitude - - rating: 4.5 - title: Pride and Prejudice - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.2 - name: gt - - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.5 - name: lte - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Science Fiction - name: neq - name: and - name: where - - args: - - mapValue: - fields: - rating: - fieldReferenceValue: rating - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testLogicalOperators - pipeline: - - Collection: books - - Where: - - Or: - - And: - - Gt: - - Field: rating - - Constant: 4.5 - - Eq: - - Field: genre - - Constant: Science Fiction - - Lt: - - Field: published - - Constant: 1900 - - Select: - - title - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: Crime and Punishment - - title: Dune - - title: Pride and Prejudice - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.5 - name: gt - - 
functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Science Fiction - name: eq - name: and - - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1900' - name: lt - name: or - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testChecks - pipeline: - - Collection: books - - Where: - - Not: - - IsNaN: - - Field: rating - - Select: - - ExprWithAlias: - - Not: - - IsNaN: - - Field: rating - - "ratingIsNotNaN" - - Limit: 1 - assert_results: - - ratingIsNotNaN: true - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - name: is_nan - name: not - name: where - - args: - - mapValue: - fields: - ratingIsNotNaN: - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - name: is_nan - name: not - name: select - - args: - - integerValue: '1' - name: limit - - description: testLogicalMinMax - pipeline: - - Collection: books - - Where: - - Eq: - - Field: author - - Constant: Douglas Adams - - Select: - - ExprWithAlias: - - LogicalMax: - - Field: rating - - Constant: 4.5 - - "max_rating" - - ExprWithAlias: - - LogicalMax: - - Field: published - - Constant: 1900 - - "max_published" - assert_results: - - max_rating: 4.5 - max_published: 1979 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: eq - name: where - - args: - - mapValue: - fields: - max_published: - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1900' - name: logical_maximum - max_rating: - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.5 - name: logical_maximum - name: select - - description: testMapGet - pipeline: - - Collection: books - - Sort: - - Ordering: - - Field: published - - DESCENDING - - Select: - - ExprWithAlias: - - MapGet: - - Field: awards - - Constant: hugo - - "hugoAward" - - Field: title - - Where: - - Eq: - - Field: hugoAward - - Constant: true - assert_results: - - hugoAward: true - title: The Hitchhiker's Guide to the Galaxy - - hugoAward: true - title: Dune - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - direction: - stringValue: descending - expression: - fieldReferenceValue: published - name: sort - - args: - - mapValue: - fields: - hugoAward: - functionValue: - args: - - fieldReferenceValue: awards - - stringValue: hugo - name: map_get - title: - fieldReferenceValue: title - name: select - - args: - - functionValue: - args: - - fieldReferenceValue: hugoAward - - booleanValue: true - name: eq - name: where - - description: testNestedFields - pipeline: - - Collection: books - - Where: - - Eq: - - Field: awards.hugo - - Constant: true - - Sort: - - Ordering: - - Field: title - - DESCENDING - - Select: - - title - - Field: awards.hugo - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - awards.hugo: true - - title: Dune - awards.hugo: true - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: awards.hugo - - booleanValue: 
true - name: eq + - fieldReferenceValue: awards.hugo + - booleanValue: true + name: eq name: where - args: - mapValue: @@ -1547,94 +589,4 @@ tests: fieldReferenceValue: awards.hugo title: fieldReferenceValue: title - name: select - - description: testSampleLimit - pipeline: - - Collection: books - - Sample: 3 - assert_count: 3 # Results will vary due to randomness - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - integerValue: '3' - - stringValue: documents - name: sample - - description: testSamplePercentage - pipeline: - - Collection: books - - Sample: - - SampleOptions: - - 0.6 - - percent - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - doubleValue: 0.6 - - stringValue: percent - name: sample - - description: testUnion - pipeline: - - Collection: books - - Union: - - Pipeline: - - Collection: books - assert_count: 20 # Results will be duplicated - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - pipelineValue: - stages: - - args: - - referenceValue: /books - name: collection - name: union - - description: testUnnest - pipeline: - - Collection: books - - Where: - - Eq: - - Field: title - - Constant: The Hitchhiker's Guide to the Galaxy - - Unnest: - - tags - - tags_alias - - Select: tags_alias - assert_results: - - tags_alias: comedy - - tags_alias: space - - tags_alias: adventure - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: The Hitchhiker's Guide to the Galaxy - name: eq - name: where - - args: - - fieldReferenceValue: tags - - fieldReferenceValue: tags_alias - name: unnest - - args: - - mapValue: - fields: - tags_alias: - fieldReferenceValue: tags_alias - name: select + name: select \ No newline at end of file diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 7266764ce..3c58e489c 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -155,25 +155,6 @@ def test_pipeline_parse_proto(test_dict, client): got_proto = MessageToDict(pipeline._to_pb()._pb) assert yaml.dump(expected_proto) == yaml.dump(got_proto) -@pytest.mark.parametrize( - "test_dict", - [t for t in yaml_loader() if "assert_error" in t], - ids=lambda x: f"{x.get('description', '')}" -) -def test_pipeline_expected_errors(test_dict, client): - """ - Finds assert_error statements in yaml, and ensures the pipeline raises the expected error - """ - error_regex = test_dict["assert_error"] - pipeline = parse_pipeline(client, test_dict["pipeline"]) - # check if server responds as expected - with pytest.raises(GoogleAPIError) as err: - [_ for _ in pipeline.execute()] - found_error = str(err.value) - match = re.search(error_regex, found_error) - assert match, f"error '{found_error}' does not match '{error_regex}'" - - @pytest.mark.parametrize( "test_dict", [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], @@ -193,25 +174,6 @@ def test_pipeline_results(test_dict, client): if expected_count is not None: assert len(got_results) == expected_count -@pytest.mark.parametrize( - "test_dict", - [t for t in yaml_loader() if "assert_error" in t], - ids=lambda x: f"{x.get('description', '')}" -) -@pytest.mark.asyncio -async def test_pipeline_expected_errors_async(test_dict, async_client): - """ - Finds assert_error 
statements in yaml, and ensures the pipeline raises the expected error - """ - error_regex = test_dict["assert_error"] - pipeline = parse_pipeline(async_client, test_dict["pipeline"]) - # check if server responds as expected - with pytest.raises(GoogleAPIError) as err: - [_ async for _ in pipeline.execute()] - found_error = str(err.value) - match = re.search(error_regex, found_error) - assert match, f"error '{found_error}' does not match '{error_regex}'" - @pytest.mark.parametrize( "test_dict", [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], From bd9c2c47c38112a0ff0a879fb19a7c7cda9340a7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 2 May 2025 16:54:40 -0700 Subject: [PATCH 102/131] stripped down to stubs --- google/cloud/firestore_v1/base_pipeline.py | 133 +--- google/cloud/firestore_v1/base_query.py | 54 +- .../firestore_v1/pipeline_expressions.py | 449 +------------ google/cloud/firestore_v1/pipeline_stages.py | 92 +-- tests/system/pipeline_e2e.yaml | 592 ------------------ tests/system/test_pipeline_acceptance.py | 196 ------ tests/system/test_system.py | 164 ++--- tests/unit/v1/test_base_query.py | 168 ----- tests/unit/v1/test_pipeline_expressions.py | 270 -------- 9 files changed, 50 insertions(+), 2068 deletions(-) delete mode 100644 tests/system/pipeline_e2e.yaml delete mode 100644 tests/system/test_pipeline_acceptance.py delete mode 100644 tests/unit/v1/test_pipeline_expressions.py diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 117dceb95..0793e58f9 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -13,17 +13,12 @@ # limitations under the License. from __future__ import annotations -from typing_extensions import Self from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_client import BaseClient from google.cloud.firestore_v1.types.pipeline import ( StructuredPipeline as StructuredPipeline_pb, ) from google.cloud.firestore_v1 import _helpers, document -from google.cloud.firestore_v1.pipeline_expressions import ( - FilterCondition, - Selectable, -) class _BasePipeline: @@ -78,130 +73,4 @@ def _parse_response(response_pb, client): read_time=response_pb._pb.execution_time, create_time=doc.create_time, update_time=doc.update_time, - ) - - def select(self, *selections: str | Selectable) -> Self: - """ - Selects or creates a set of fields from the outputs of previous stages. - The selected fields are defined using `Selectable` expressions or field names: - - `Field`: References an existing document field. - - `Function`: Represents the result of a function with an assigned alias - name using `Expr.as_()`. - - `str`: The name of an existing field. - If no selections are provided, the output of this stage is empty. Use - `add_fields()` instead if only additions are desired. - Example: - >>> from google.cloud.firestore_v1.pipeline_expressions import Field, to_upper - >>> pipeline = client.collection("books").pipeline() - >>> # Select by name - >>> pipeline = pipeline.select("name", "address") - >>> # Select using Field and Function expressions - >>> pipeline = pipeline.select( - ... Field.of("name"), - ... Field.of("address").to_upper().as_("upperAddress"), - ... ) - Args: - *selections: The fields to include in the output documents, specified as - field names (str) or `Selectable` expressions. 
-        Returns:
-            A new Pipeline object with this stage appended to the stage list
-        """
-        return self._append(stages.Select(*selections))
-
-    def where(self, condition: FilterCondition) -> Self:
-        """
-        Filters the documents from previous stages to only include those matching
-        the specified `FilterCondition`.
-        This stage allows you to apply conditions to the data, similar to a "WHERE"
-        clause in SQL. You can filter documents based on their field values, using
-        implementations of `FilterCondition`, typically including but not limited to:
-        - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc.
-        - logical operators: `And`, `Or`, `Not`, etc.
-        - advanced functions: `regex_matches`, `array_contains`, etc.
-        Example:
-            >>> from google.cloud.firestore_v1.pipeline_expressions import Field, And
-            >>> pipeline = client.collection("books").pipeline()
-            >>> pipeline = pipeline.where(
-            ...     And(
-            ...         Field.of("rating").gt(4.0),  # Filter for ratings > 4.0
-            ...         Field.of("genre").eq("Science Fiction")  # Filter for genre
-            ...     )
-            ... )
-        Args:
-            condition: The `FilterCondition` to apply.
-        Returns:
-            A new Pipeline object with this stage appended to the stage list
-        """
-        return self._append(stages.Where(condition))
-
-    def sort(self, *orders: stages.Ordering) -> Self:
-        """
-        Sorts the documents from previous stages based on one or more `Ordering` criteria.
-        This stage allows you to order the results of your pipeline. You can specify
-        multiple `Ordering` instances to sort by multiple fields or expressions in
-        ascending or descending order. If documents have the same value for a sorting
-        criterion, the next specified ordering will be used. If all orderings result
-        in equal comparison, the documents are considered equal and the relative order
-        is unspecified.
-        Example:
-            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
-            >>> pipeline = client.collection("books").pipeline()
-            >>> # Sort books by rating descending, then title ascending
-            >>> pipeline = pipeline.sort(
-            ...     Field.of("rating").descending(),
-            ...     Field.of("title").ascending()
-            ... )
-        Args:
-            *orders: One or more `Ordering` instances specifying the sorting criteria.
-        Returns:
-            A new Pipeline object with this stage appended to the stage list
-        """
-        return self._append(stages.Sort(*orders))
-
-    def offset(self, offset: int) -> Self:
-        """
-        Skips the first `offset` number of documents from the results of previous stages.
-        This stage is useful for implementing pagination, allowing you to retrieve
-        results in chunks. It is typically used in conjunction with `limit()` to
-        control the size of each page.
-        Example:
-            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
-            >>> pipeline = client.collection("books").pipeline()
-            >>> # Retrieve the second page of 20 results (assuming sorted)
-            >>> pipeline = pipeline.sort(Field.of("published").descending())
-            >>> pipeline = pipeline.offset(20)  # Skip the first 20 results
-            >>> pipeline = pipeline.limit(20)  # Take the next 20 results
-        Args:
-            offset: The non-negative number of documents to skip.
-        Returns:
-            A new Pipeline object with this stage appended to the stage list
-        """
-        return self._append(stages.Offset(offset))
-
-    def limit(self, limit: int) -> Self:
-        """
-        Limits the maximum number of documents returned by previous stages to `limit`.
-        This stage is useful for controlling the size of the result set, often used for:
-        - **Pagination:** In combination with `offset()` to retrieve specific pages.
-        - **Top-N queries:** To get a limited number of results after sorting.
-        - **Performance:** To prevent excessive data transfer.
-        Example:
-            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
-            >>> pipeline = client.collection("books").pipeline()
-            >>> # Limit the results to the top 10 highest-rated books
-            >>> pipeline = pipeline.sort(Field.of("rating").descending())
-            >>> pipeline = pipeline.limit(10)
-        Args:
-            limit: The non-negative maximum number of documents to return.
-        Returns:
-            A new Pipeline object with this stage appended to the stage list
-        """
-        return self._append(stages.Limit(limit))
+        )
\ No newline at end of file
diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py
index d8fb96dee..cec7ce02b 100644
--- a/google/cloud/firestore_v1/base_query.py
+++ b/google/cloud/firestore_v1/base_query.py
@@ -1106,60 +1106,8 @@ def recursive(self: QueryType) -> QueryType:
         return copied
 
     def pipeline(self):
-        if self._all_descendants:
-            base_stage = pipeline_stages.CollectionGroup(self._parent.id)
-        else:
-            base_stage = pipeline_stages.Collection("/".join(self._parent._path))
-        ppl = self._client.pipeline(base_stage)
-
-        # Filters
-        for filter_ in self._field_filters:
-            ppl = ppl.where(
-                pipeline_expressions.FilterCondition._from_query_filter_pb(
-                    filter_, self._client
-                )
-            )
-
-        # Projections
-        if self._projection and self._projection.fields:
-            ppl = ppl.select(*[field.field_path for field in self._projection.fields])
-
-        # Orders
-        orders = self._normalize_orders()
-        if orders:
-            exists = []
-            orderings = []
-            for order in orders:
-                field = pipeline_expressions.Field.of(order.field.field_path)
-                exists.append(field.exists())
-                direction = (
-                    "ascending"
-                    if order.direction == StructuredQuery.Direction.ASCENDING
-                    else "descending"
-                )
-                orderings.append(pipeline_expressions.Ordering(field, direction))
-
-            # Add exists filters to match Query's implicit orderby semantics.
-            if len(exists) > 1:
-                ppl = ppl.where(pipeline_expressions.And(*exists))
-            elif len(exists) == 1:
-                ppl = ppl.where(exists[0])
-
-            # Add sort orderings
-            ppl = ppl.sort(*orderings)
-
-        # Cursors, Limit and Offset
-        if self._start_at or self._end_at or self._limit_to_last:
-            raise NotImplementedError(
-                "Query to Pipeline conversion: cursors and limitToLast are not supported yet."
- ) - else: # Limit & Offset without cursors - if self._offset: - ppl = ppl.offset(self._offset) - if self._limit: - ppl = ppl.limit(self._limit) + raise NotImplementedError - return ppl def _comparator(self, doc1, doc2) -> int: _orders = self._orders diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index fd5cb618e..4a560748f 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -17,20 +17,15 @@ Any, Generic, TypeVar, - List, Dict, - Sequence, ) from abc import ABC from abc import abstractmethod -from enum import Enum import datetime from google.cloud.firestore_v1.types.document import Value -from google.cloud.firestore_v1.types.query import StructuredQuery as Query_pb from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1._helpers import GeoPoint from google.cloud.firestore_v1._helpers import encode_value -from google.cloud.firestore_v1._helpers import decode_value CONSTANT_TYPE = TypeVar( "CONSTANT_TYPE", @@ -48,47 +43,6 @@ ) -class Ordering: - """Represents the direction for sorting results in a pipeline.""" - - class Direction(Enum): - ASCENDING = "ascending" - DESCENDING = "descending" - - def __init__(self, expr, order_dir: Direction | str = Direction.ASCENDING): - """ - Initializes an Ordering instance - Args: - expr (Expr | str): The expression or field path string to sort by. - If a string is provided, it's treated as a field path. - order_dir (Direction | str): The direction to sort in. - Defaults to ascending - """ - self.expr = expr if isinstance(expr, Expr) else Field.of(expr) - self.order_dir = ( - Ordering.Direction[order_dir.upper()] - if isinstance(order_dir, str) - else order_dir - ) - - def __repr__(self): - if self.order_dir is Ordering.Direction.ASCENDING: - order_str = ".ascending()" - else: - order_str = ".descending()" - return f"{self.expr!r}{order_str}" - - def _to_pb(self) -> Value: - return Value( - map_value={ - "fields": { - "direction": Value(string_value=self.order_dir.value), - "expression": self.expr._to_pb(), - } - } - ) - - class Expr(ABC): """Represents an expression that can be evaluated to a value within the execution of a pipeline. @@ -117,175 +71,6 @@ def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": return o if isinstance(o, Expr) else Constant(o) - def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": - """Creates an expression that checks if this expression is equal to another - expression or constant value. - Example: - >>> # Check if the 'age' field is equal to 21 - >>> Field.of("age").eq(21) - >>> # Check if the 'city' field is equal to "London" - >>> Field.of("city").eq("London") - Args: - other: The expression or constant value to compare for equality. - Returns: - A new `Expr` representing the equality comparison. - """ - return Eq(self, self._cast_to_expr_or_convert_to_constant(other)) - - def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": - """Creates an expression that checks if this expression is not equal to another - expression or constant value. - Example: - >>> # Check if the 'status' field is not equal to "completed" - >>> Field.of("status").neq("completed") - >>> # Check if the 'country' field is not equal to "USA" - >>> Field.of("country").neq("USA") - Args: - other: The expression or constant value to compare for inequality. - Returns: - A new `Expr` representing the inequality comparison. 
- """ - return Neq(self, self._cast_to_expr_or_convert_to_constant(other)) - - def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": - """Creates an expression that checks if this expression is greater than another - expression or constant value. - Example: - >>> # Check if the 'age' field is greater than the 'limit' field - >>> Field.of("age").gt(Field.of("limit")) - >>> # Check if the 'price' field is greater than 100 - >>> Field.of("price").gt(100) - Args: - other: The expression or constant value to compare for greater than. - Returns: - A new `Expr` representing the greater than comparison. - """ - return Gt(self, self._cast_to_expr_or_convert_to_constant(other)) - - def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": - """Creates an expression that checks if this expression is greater than or equal - to another expression or constant value. - Example: - >>> # Check if the 'quantity' field is greater than or equal to field 'requirement' plus 1 - >>> Field.of("quantity").gte(Field.of('requirement').add(1)) - >>> # Check if the 'score' field is greater than or equal to 80 - >>> Field.of("score").gte(80) - Args: - other: The expression or constant value to compare for greater than or equal to. - Returns: - A new `Expr` representing the greater than or equal to comparison. - """ - return Gte(self, self._cast_to_expr_or_convert_to_constant(other)) - - def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": - """Creates an expression that checks if this expression is less than another - expression or constant value. - Example: - >>> # Check if the 'age' field is less than 'limit' - >>> Field.of("age").lt(Field.of('limit')) - >>> # Check if the 'price' field is less than 50 - >>> Field.of("price").lt(50) - Args: - other: The expression or constant value to compare for less than. - Returns: - A new `Expr` representing the less than comparison. - """ - return Lt(self, self._cast_to_expr_or_convert_to_constant(other)) - - def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": - """Creates an expression that checks if this expression is less than or equal to - another expression or constant value. - Example: - >>> # Check if the 'quantity' field is less than or equal to 20 - >>> Field.of("quantity").lte(Constant.of(20)) - >>> # Check if the 'score' field is less than or equal to 70 - >>> Field.of("score").lte(70) - Args: - other: The expression or constant value to compare for less than or equal to. - Returns: - A new `Expr` representing the less than or equal to comparison. - """ - return Lte(self, self._cast_to_expr_or_convert_to_constant(other)) - - def in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "In": - """Creates an expression that checks if this expression is equal to any of the - provided values or expressions. - Example: - >>> # Check if the 'category' field is either "Electronics" or value of field 'primaryType' - >>> Field.of("category").in_any(["Electronics", Field.of("primaryType")]) - Args: - array: The values or expressions to check against. - Returns: - A new `Expr` representing the 'IN' comparison. - """ - return In(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) - - def not_in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "Not": - """Creates an expression that checks if this expression is not equal to any of the - provided values or expressions. - Example: - >>> # Check if the 'status' field is neither "pending" nor "cancelled" - >>> Field.of("status").not_in_any(["pending", "cancelled"]) - Args: - *others: The values or expressions to check against. 
- Returns: - A new `Expr` representing the 'NOT IN' comparison. - """ - return Not(self.in_any(array)) - - def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": - """Creates an expression that checks if an array contains a specific element or value. - Example: - >>> # Check if the 'sizes' array contains the value from the 'selectedSize' field - >>> Field.of("sizes").array_contains(Field.of("selectedSize")) - >>> # Check if the 'colors' array contains "red" - >>> Field.of("colors").array_contains("red") - Args: - element: The element (expression or constant) to search for in the array. - Returns: - A new `Expr` representing the 'array_contains' comparison. - """ - return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) - - def array_contains_any( - self, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAny": - """Creates an expression that checks if an array contains any of the specified elements. - Example: - >>> # Check if the 'categories' array contains either values from field "cate1" or "cate2" - >>> Field.of("categories").array_contains_any([Field.of("cate1"), Field.of("cate2")]) - >>> # Check if the 'groups' array contains either the value from the 'userGroup' field - >>> # or the value "guest" - >>> Field.of("groups").array_contains_any([Field.of("userGroup"), "guest"]) - Args: - elements: The list of elements (expressions or constants) to check for in the array. - Returns: - A new `Expr` representing the 'array_contains_any' comparison. - """ - return ArrayContainsAny( - self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] - ) - - def is_nan(self) -> "IsNaN": - """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). - Example: - >>> # Check if the result of a calculation is NaN - >>> Field.of("value").divide(0).is_nan() - Returns: - A new `Expr` representing the 'isNaN' check. - """ - return IsNaN(self) - - def exists(self) -> "Exists": - """Creates an expression that checks if a field exists in the document. - Example: - >>> # Check if the document has a field named "phoneNumber" - >>> Field.of("phoneNumber").exists() - Returns: - A new `Expr` representing the 'exists' check. 
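Every comparison helper above funnels plain Python values through `_cast_to_expr_or_convert_to_constant`, so `Expr` instances and raw constants can be mixed freely in one condition. A small composition sketch, using only methods documented above (field names are illustrative):

from google.cloud.firestore_v1.pipeline_expressions import And, Field

# Raw values ("active", 100) are wrapped in Constant automatically;
# Expr arguments such as Field.of("budget") pass through unchanged.
condition = And(
    Field.of("status").eq("active"),
    Field.of("price").lt(Field.of("budget")),
    Field.of("score").gte(100),
    Field.of("owner").exists(),
)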
- """ - return Exists(self) - class Constant(Expr, Generic[CONSTANT_TYPE]): """Represents a constant literal value in an expression.""" @@ -301,236 +86,4 @@ def __repr__(self): return f"Constant.of({self.value!r})" def _to_pb(self) -> Value: - return encode_value(self.value) - - -class ListOfExprs(Expr): - """Represents a list of expressions, typically used as an argument to functions like 'in' or array functions.""" - - def __init__(self, exprs: List[Expr]): - self.exprs: list[Expr] = exprs - - def __repr__(self): - return f"{self.__class__.__name__}({self.exprs})" - - def _to_pb(self): - return Value(array_value={"values": [e._to_pb() for e in self.exprs]}) - - -class Function(Expr): - """A base class for expressions that represent function calls.""" - - def __init__(self, name: str, params: Sequence[Expr]): - self.name = name - self.params = list(params) - - def __repr__(self): - return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" - - def _to_pb(self): - return Value( - function_value={ - "name": self.name, - "args": [p._to_pb() for p in self.params], - } - ) - - -class Selectable(Expr): - """Base class for expressions that can be selected or aliased in projection stages.""" - - @abstractmethod - def _to_map(self): - raise NotImplementedError - - -class Field(Selectable): - """Represents a reference to a field within a document.""" - - DOCUMENT_ID = "__name__" - - def __init__(self, path: str): - """Initializes a Field reference. - Args: - path: The dot-separated path to the field (e.g., "address.city"). - Use Field.DOCUMENT_ID for the document ID. - """ - self.path = path - - @staticmethod - def of(path: str): - """Creates a Field reference. - Args: - path: The dot-separated path to the field (e.g., "address.city"). - Use Field.DOCUMENT_ID for the document ID. - Returns: - A new Field instance. 
- """ - return Field(path) - - def _to_map(self): - return self.path, self._to_pb() - - def __repr__(self): - return f"Field.of({self.path!r})" - - def _to_pb(self): - return Value(field_reference_value=self.path) - - -class FilterCondition(Function): - """Filters the given data in some way.""" - - @staticmethod - def _from_query_filter_pb(filter_pb, client): - if isinstance(filter_pb, Query_pb.CompositeFilter): - sub_filters = [ - FilterCondition._from_query_filter_pb(f, client) - for f in filter_pb.filters - ] - if filter_pb.op == Query_pb.CompositeFilter.Operator.OR: - return Or(*sub_filters) - elif filter_pb.op == Query_pb.CompositeFilter.Operator.AND: - return And(*sub_filters) - else: - raise TypeError( - f"Unexpected CompositeFilter operator type: {filter_pb.op}" - ) - elif isinstance(filter_pb, Query_pb.UnaryFilter): - field = Field.of(filter_pb.field.field_path) - if filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NAN: - return And(field.exists(), field.is_nan()) - elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: - return And(field.exists(), Not(field.is_nan())) - elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: - return And(field.exists(), field.eq(None)) - elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: - return And(field.exists(), Not(field.eq(None))) - else: - raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") - elif isinstance(filter_pb, Query_pb.FieldFilter): - field = Field.of(filter_pb.field.field_path) - value = decode_value(filter_pb.value, client) - if filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN: - return And(field.exists(), field.lt(value)) - elif filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN_OR_EQUAL: - return And(field.exists(), field.lte(value)) - elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN: - return And(field.exists(), field.gt(value)) - elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN_OR_EQUAL: - return And(field.exists(), field.gte(value)) - elif filter_pb.op == Query_pb.FieldFilter.Operator.EQUAL: - return And(field.exists(), field.eq(value)) - elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_EQUAL: - return And(field.exists(), field.neq(value)) - if filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS: - return And(field.exists(), field.array_contains(value)) - elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS_ANY: - return And(field.exists(), field.array_contains_any(value)) - elif filter_pb.op == Query_pb.FieldFilter.Operator.IN: - return And(field.exists(), field.in_any(value)) - elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_IN: - return And(field.exists(), field.not_in_any(value)) - else: - raise TypeError(f"Unexpected FieldFilter operator type: {filter_pb.op}") - elif isinstance(filter_pb, Query_pb.Filter): - # unwrap oneof - f = ( - filter_pb.composite_filter - or filter_pb.field_filter - or filter_pb.unary_filter - ) - return FilterCondition._from_query_filter_pb(f, client) - else: - raise TypeError(f"Unexpected filter type: {type(filter_pb)}") - - -class And(FilterCondition): - def __init__(self, *conditions: "FilterCondition"): - super().__init__("and", conditions) - -class ArrayContains(FilterCondition): - def __init__(self, array: Expr, element: Expr): - super().__init__( - "array_contains", [array, element if element else Constant(None)] - ) - -class ArrayContainsAny(FilterCondition): - """Represents checking if an array contains any of the specified elements.""" - - def __init__(self, array: 
Expr, elements: List[Expr]): - super().__init__("array_contains_any", [array, ListOfExprs(elements)]) - -class Eq(FilterCondition): - """Represents the equality comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("eq", [left, right if right else Constant(None)]) - - -class Exists(FilterCondition): - """Represents checking if a field exists.""" - - def __init__(self, expr: Expr): - super().__init__("exists", [expr]) - - -class Gt(FilterCondition): - """Represents the greater than comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("gt", [left, right if right else Constant(None)]) - - -class Gte(FilterCondition): - """Represents the greater than or equal to comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("gte", [left, right if right else Constant(None)]) - - -class In(FilterCondition): - """Represents checking if an expression's value is within a list of values.""" - - def __init__(self, left: Expr, others: List[Expr]): - super().__init__("in", [left, ListOfExprs(others)]) - - -class IsNaN(FilterCondition): - """Represents checking if a numeric value is NaN.""" - - def __init__(self, value: Expr): - super().__init__("is_nan", [value]) - -class Lt(FilterCondition): - """Represents the less than comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("lt", [left, right if right else Constant(None)]) - - -class Lte(FilterCondition): - """Represents the less than or equal to comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("lte", [left, right if right else Constant(None)]) - - -class Neq(FilterCondition): - """Represents the inequality comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("neq", [left, right if right else Constant(None)]) - -class Not(FilterCondition): - """Represents the logical NOT of a filter condition.""" - - def __init__(self, condition: Expr): - super().__init__("not", [condition]) - - -class Or(FilterCondition): - """Represents the logical OR of multiple filter conditions.""" - - def __init__(self, *conditions: "FilterCondition"): - super().__init__("or", conditions) \ No newline at end of file + return encode_value(self.value) \ No newline at end of file diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 877b1bf1e..8e51fb1c1 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -21,10 +21,6 @@ from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.pipeline_expressions import ( Expr, - Field, - FilterCondition, - Selectable, - Ordering, ) class Stage(ABC): @@ -57,30 +53,6 @@ def __repr__(self): return f"{self.__class__.__name__}({', '.join(items)})" -class Collection(Stage): - """Specifies a collection as the initial data source.""" - - def __init__(self, path: str): - super().__init__() - if not path.startswith("/"): - path = f"/{path}" - self.path = path - - def _pb_args(self): - return [Value(reference_value=self.path)] - - -class CollectionGroup(Stage): - """Specifies a collection group as the initial data source.""" - - def __init__(self, collection_id: str): - super().__init__("collection_group") - self.collection_id = collection_id - - def _pb_args(self): - return [Value(string_value=self.collection_id)] - - class GenericStage(Stage): """Represents a generic, named stage with parameters.""" @@ -91,66 +63,4 @@ def __init__(self, 
name: str, *params: Expr | Value): ] def _pb_args(self): - return self.params - - -class Limit(Stage): - """Limits the maximum number of documents returned.""" - - def __init__(self, limit: int): - super().__init__() - self.limit = limit - - def _pb_args(self): - return [Value(integer_value=self.limit)] - - -class Offset(Stage): - """Skips a specified number of documents.""" - - def __init__(self, offset: int): - super().__init__() - self.offset = offset - - def _pb_args(self): - return [Value(integer_value=self.offset)] - - -class Select(Stage): - """Selects or creates a set of fields.""" - - def __init__(self, *selections: str | Selectable): - super().__init__() - self.projections = [Field(s) if isinstance(s, str) else s for s in selections] - - def _pb_args(self) -> list[Value]: - return [ - Value( - map_value={ - "fields": { - m[0]: m[1] for m in [f._to_map() for f in self.projections] - } - } - ) - ] - - -class Sort(Stage): - """Sorts documents based on specified criteria.""" - - def __init__(self, *orders: "Ordering"): - super().__init__() - self.orders = list(orders) - - def _pb_args(self): - return [o._to_pb() for o in self.orders] - -class Where(Stage): - """Filters documents based on a specified condition.""" - - def __init__(self, condition: FilterCondition): - super().__init__() - self.condition = condition - - def _pb_args(self): - return [self.condition._to_pb()] \ No newline at end of file + return self.params \ No newline at end of file diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml deleted file mode 100644 index d92bbb316..000000000 --- a/tests/system/pipeline_e2e.yaml +++ /dev/null @@ -1,592 +0,0 @@ -data: - books: - book1: - title: "The Hitchhiker's Guide to the Galaxy" - author: "Douglas Adams" - genre: "Science Fiction" - published: 1979 - rating: 4.2 - tags: - - comedy - - space - - adventure - awards: - hugo: true - nebula: false - book2: - title: "Pride and Prejudice" - author: "Jane Austen" - genre: "Romance" - published: 1813 - rating: 4.5 - tags: - - classic - - social commentary - - love - awards: - none: true - book3: - title: "One Hundred Years of Solitude" - author: "Gabriel García Márquez" - genre: "Magical Realism" - published: 1967 - rating: 4.3 - tags: - - family - - history - - fantasy - awards: - nobel: true - nebula: false - book4: - title: "The Lord of the Rings" - author: "J.R.R. Tolkien" - genre: "Fantasy" - published: 1954 - rating: 4.7 - tags: - - adventure - - magic - - epic - awards: - hugo: false - nebula: false - book5: - title: "The Handmaid's Tale" - author: "Margaret Atwood" - genre: "Dystopian" - published: 1985 - rating: 4.1 - tags: - - feminism - - totalitarianism - - resistance - awards: - arthur c. clarke: true - booker prize: false - book6: - title: "Crime and Punishment" - author: "Fyodor Dostoevsky" - genre: "Psychological Thriller" - published: 1866 - rating: 4.3 - tags: - - philosophy - - crime - - redemption - awards: - none: true - book7: - title: "To Kill a Mockingbird" - author: "Harper Lee" - genre: "Southern Gothic" - published: 1960 - rating: 4.2 - tags: - - racism - - injustice - - coming-of-age - awards: - pulitzer: true - book8: - title: "1984" - author: "George Orwell" - genre: "Dystopian" - published: 1949 - rating: 4.2 - tags: - - surveillance - - totalitarianism - - propaganda - awards: - prometheus: true - book9: - title: "The Great Gatsby" - author: "F. 
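Each stage class above flattens its arguments into a list of `Value` protos via `_pb_args`; the `assert_proto` blocks in the deleted pipeline_e2e.yaml surrounding this point assert exactly these shapes. A sketch for two of them, based on the definitions above:

from google.cloud.firestore_v1.types.document import Value

# Limit(10)._pb_args() -> [Value(integer_value=10)]
limit_args = [Value(integer_value=10)]

# Select("title", "author")._pb_args() -> one map_value keyed by output name:
select_args = [
    Value(
        map_value={
            "fields": {
                "title": Value(field_reference_value="title"),
                "author": Value(field_reference_value="author"),
            }
        }
    )
]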
Scott Fitzgerald" - genre: "Modernist" - published: 1925 - rating: 4.0 - tags: - - wealth - - american dream - - love - awards: - none: true - book10: - title: "Dune" - author: "Frank Herbert" - genre: "Science Fiction" - published: 1965 - rating: 4.6 - tags: - - politics - - desert - - ecology - awards: - hugo: true - nebula: true -tests: - - description: selectSpecificFields - pipeline: - - Collection: books - - Select: - - title - - author - - Sort: - - Ordering: - - Field: author - - ASCENDING - assert_results: - - title: "The Hitchhiker's Guide to the Galaxy" - author: "Douglas Adams" - - title: "The Great Gatsby" - author: "F. Scott Fitzgerald" - - title: "Dune" - author: "Frank Herbert" - - title: "Crime and Punishment" - author: "Fyodor Dostoevsky" - - title: "One Hundred Years of Solitude" - author: "Gabriel García Márquez" - - title: "1984" - author: "George Orwell" - - title: "To Kill a Mockingbird" - author: "Harper Lee" - - title: "The Lord of the Rings" - author: "J.R.R. Tolkien" - - title: "Pride and Prejudice" - author: "Jane Austen" - - title: "The Handmaid's Tale" - author: "Margaret Atwood" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - author: - fieldReferenceValue: author - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: author - name: sort - - description: whereByMultipleConditions - pipeline: - - Collection: books - - Where: - - And: - - Gt: - - Field: rating - - Constant: 4.5 - - Eq: - - Field: genre - - Constant: Science Fiction - assert_results: - - title: Dune - author: Frank Herbert - genre: Science Fiction - published: 1965 - rating: 4.6 - tags: - - politics - - desert - - ecology - awards: - hugo: true - nebula: true - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.5 - name: gt - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Science Fiction - name: eq - name: and - name: where - - description: whereByOrCondition - pipeline: - - Collection: books - - Where: - - Or: - - Eq: - - Field: genre - - Constant: Romance - - Eq: - - Field: genre - - Constant: Dystopian - - Select: - - title - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: "1984" - - title: Pride and Prejudice - - title: The Handmaid's Tale - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Romance - name: eq - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Dystopian - name: eq - name: or - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testPipelineWithOffsetAndLimit - pipeline: - - Collection: books - - Sort: - - Ordering: - - Field: author - - ASCENDING - - Offset: 5 - - Limit: 3 - - Select: - - title - - author - assert_results: - - title: "1984" - author: George Orwell - - title: To Kill a Mockingbird - author: Harper Lee - - title: The Lord of the Rings - author: J.R.R. 
Tolkien - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: author - name: sort - - args: - - integerValue: '5' - name: offset - - args: - - integerValue: '3' - name: limit - - args: - - mapValue: - fields: - author: - fieldReferenceValue: author - title: - fieldReferenceValue: title - name: select - - description: testArrayContains - pipeline: - - Collection: books - - Where: - - ArrayContains: - - Field: tags - - Constant: comedy - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - author: Douglas Adams - awards: - hugo: true - nebula: false - genre: Science Fiction - published: 1979 - rating: 4.2 - tags: ["comedy", "space", "adventure"] - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: tags - - stringValue: comedy - name: array_contains - name: where - - description: testArrayContainsAny - pipeline: - - Collection: books - - Where: - - ArrayContainsAny: - - Field: tags - - - Constant: comedy - - Constant: classic - - Select: - - title - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: Pride and Prejudice - - title: The Hitchhiker's Guide to the Galaxy - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: tags - - arrayValue: - values: - - stringValue: comedy - - stringValue: classic - name: array_contains_any - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testComparisonOperators - pipeline: - - Collection: books - - Where: - - And: - - Gt: - - Field: rating - - Constant: 4.2 - - Lte: - - Field: rating - - Constant: 4.5 - - Neq: - - Field: genre - - Constant: Science Fiction - - Select: - - rating - - title - - Sort: - - Ordering: - - title - - ASCENDING - assert_results: - - rating: 4.3 - title: Crime and Punishment - - rating: 4.3 - title: One Hundred Years of Solitude - - rating: 4.5 - title: Pride and Prejudice - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.2 - name: gt - - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.5 - name: lte - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Science Fiction - name: neq - name: and - name: where - - args: - - mapValue: - fields: - rating: - fieldReferenceValue: rating - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testLogicalOperators - pipeline: - - Collection: books - - Where: - - Or: - - And: - - Gt: - - Field: rating - - Constant: 4.5 - - Eq: - - Field: genre - - Constant: Science Fiction - - Lt: - - Field: published - - Constant: 1900 - - Select: - - title - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: Crime and Punishment - - title: Dune - - title: Pride and Prejudice - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: 
collection - - args: - - functionValue: - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.5 - name: gt - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Science Fiction - name: eq - name: and - - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1900' - name: lt - name: or - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testNestedFields - pipeline: - - Collection: books - - Where: - - Eq: - - Field: awards.hugo - - Constant: true - - Sort: - - Ordering: - - Field: title - - DESCENDING - - Select: - - title - - Field: awards.hugo - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - awards.hugo: true - - title: Dune - awards.hugo: true - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: awards.hugo - - booleanValue: true - name: eq - name: where - - args: - - mapValue: - fields: - direction: - stringValue: descending - expression: - fieldReferenceValue: title - name: sort - - args: - - mapValue: - fields: - awards.hugo: - fieldReferenceValue: awards.hugo - title: - fieldReferenceValue: title - name: select \ No newline at end of file diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py deleted file mode 100644 index 3c58e489c..000000000 --- a/tests/system/test_pipeline_acceptance.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
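Each deleted YAML case above is consumed by the harness in the deleted test module that follows: stage and expression names resolve to classes via `getattr`, and the parsed arguments are splatted in. The dispatch idea in miniature (a sketch; the real `_apply_yaml_args` below also handles dict kwargs and nested expressions):

import yaml
from google.cloud.firestore_v1 import pipeline_stages

# "Collection: books" -> pipeline_stages.Collection("books") (pre-revert API)
stage_yaml = yaml.safe_load("Collection: books")
name, args = next(iter(stage_yaml.items()))
stage = getattr(pipeline_stages, name)(args)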
- -from __future__ import annotations -import sys -import os -import pytest -import yaml -import re -from typing import Any -from contextlib import nullcontext - -from google.protobuf.json_format import MessageToDict - -# from google.cloud.firestore_v1.pipeline_stages import * -from google.cloud.firestore_v1 import pipeline_stages -from google.cloud.firestore_v1 import pipeline_expressions -from google.cloud.firestore_v1.pipeline import Pipeline -from google.api_core.exceptions import GoogleAPIError - -from google.cloud.firestore import Client, AsyncClient - -FIRESTORE_TEST_DB = os.environ.get("SYSTEM_TESTS_DATABASE", "system-tests-named-db") -FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT") - -test_dir_name = os.path.dirname(__file__) - -def yaml_loader(field="tests"): - """ - loads test cases or data from yaml file - """ - with open(f"{test_dir_name}/pipeline_e2e.yaml") as f: - test_cases = yaml.safe_load(f) - return test_cases[field] - - -@pytest.fixture(scope="session") -def event_loop(): - import asyncio - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - yield loop - loop.close() - -@pytest.fixture(scope="module") -def client(): - client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_TEST_DB) - data = yaml_loader("data") - try: - # setup data - batch = client.batch() - for collection_name, documents in data.items(): - collection_ref = client.collection(collection_name) - for document_id, document_data in documents.items(): - document_ref = collection_ref.document(document_id) - batch.set(document_ref, document_data) - batch.commit() - yield client - finally: - # clear data - for collection_name, documents in data.items(): - collection_ref = client.collection(collection_name) - for document_id in documents: - document_ref = collection_ref.document(document_id) - document_ref.delete() - -@pytest.fixture(scope="module") -def async_client(client): - yield AsyncClient(project=client.project, database=client._database) - -def _apply_yaml_args(cls, client, yaml_args): - if isinstance(yaml_args, dict): - return cls(**parse_expressions(client, yaml_args)) - elif isinstance(yaml_args, list): - # yaml has an array of arguments. 
Treat as args - return cls(*parse_expressions(client, yaml_args)) - else: - # yaml has a single argument - return cls(parse_expressions(client, yaml_args)) - -def parse_pipeline(client, pipeline: list[dict[str, Any], str]): - """ - parse a yaml list of pipeline stages into firestore.pipeline_stages.Stage classes - """ - result_list = [] - for stage in pipeline: - # stage will be either a map of the stage_name and its args, or just the stage_name itself - stage_name: str = stage if isinstance(stage, str) else list(stage.keys())[0] - stage_cls: type[pipeline_stages.Stage] = getattr(pipeline_stages, stage_name) - # breakpoint() - # find arguments if given - if isinstance(stage, dict): - stage_yaml_args = stage[stage_name] - stage_obj = _apply_yaml_args(stage_cls, client, stage_yaml_args) - else: - # yaml has no arguments - stage_obj = stage_cls() - result_list.append(stage_obj) - return client.pipeline(*result_list) - -def _is_expr_string(yaml_str): - return isinstance(yaml_str, str) and \ - yaml_str[0].isupper() and \ - hasattr(pipeline_expressions, yaml_str) - -def parse_expressions(client, yaml_element: Any): - if isinstance(yaml_element, list): - return [parse_expressions(client, v) for v in yaml_element] - elif isinstance(yaml_element, dict): - if len(yaml_element) == 1 and _is_expr_string(list(yaml_element)[0]): - # build pipeline expressions if possible - cls_str = list(yaml_element)[0] - cls = getattr(pipeline_expressions, cls_str) - yaml_args = yaml_element[cls_str] - return _apply_yaml_args(cls, client, yaml_args) - elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline": - # find Pipeline objects for Union expressions - other_ppl = yaml_element["Pipeline"] - return parse_pipeline(client, other_ppl) - else: - # otherwise, return dict - return {parse_expressions(client, k): parse_expressions(client, v) for k,v in yaml_element.items()} - elif _is_expr_string(yaml_element): - return getattr(pipeline_expressions, yaml_element)() - else: - return yaml_element - -@pytest.mark.parametrize( - "test_dict", - [t for t in yaml_loader() if "assert_proto" in t], - ids=lambda x: f"{x.get('description', '')}" -) -def test_pipeline_parse_proto(test_dict, client): - """ - Finds assert_proto statements in yaml, and compares generated proto against expected value - """ - expected_proto = test_dict.get("assert_proto", None) - pipeline = parse_pipeline(client, test_dict["pipeline"]) - # check if proto matches as expected - if expected_proto: - got_proto = MessageToDict(pipeline._to_pb()._pb) - assert yaml.dump(expected_proto) == yaml.dump(got_proto) - -@pytest.mark.parametrize( - "test_dict", - [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], - ids=lambda x: f"{x.get('description', '')}" -) -def test_pipeline_results(test_dict, client): - """ - Ensure pipeline returns expected results - """ - expected_results = test_dict.get("assert_results", None) - expected_count = test_dict.get("assert_count", None) - pipeline = parse_pipeline(client, test_dict["pipeline"]) - # check if server responds as expected - got_results = [snapshot.to_dict() for snapshot in pipeline.execute()] - if expected_results: - assert got_results == expected_results - if expected_count is not None: - assert len(got_results) == expected_count - -@pytest.mark.parametrize( - "test_dict", - [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], - ids=lambda x: f"{x.get('description', '')}" -) -@pytest.mark.asyncio -async def test_pipeline_results_async(test_dict, async_client): - 
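One detail worth keeping in mind from `test_pipeline_parse_proto` above: the expected and actual protos are compared as `yaml.dump(...)` strings rather than as raw dicts, so any mismatch surfaces in pytest as a readable text diff. The same trick in isolation (hypothetical stand-in dicts):

import yaml

expected = {"pipeline": {"stages": [{"name": "collection"}]}}
got = {"pipeline": {"stages": [{"name": "collection"}]}}  # stand-in for MessageToDict output
# yaml.dump sorts mapping keys by default, so semantically equal dicts
# serialize identically and diverging ones fail with a legible string diff.
assert yaml.dump(expected) == yaml.dump(got)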
""" - Ensure pipeline returns expected results - """ - expected_results = test_dict.get("assert_results", None) - expected_count = test_dict.get("assert_count", None) - pipeline = parse_pipeline(async_client, test_dict["pipeline"]) - # check if server responds as expected - got_results = [snapshot.to_dict() async for snapshot in pipeline.execute()] - if expected_results: - assert got_results == expected_results - if expected_count is not None: - assert len(got_results) == expected_count - diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 22a3663fd..ed525db57 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -80,31 +80,6 @@ def cleanup(): operation() -@pytest.fixture -def verify_pipeline(query): - """ - This fixture ensures a pipeline produces the same - results as the query it is derived from - - It can be attached to existing query tests both - modalities at the same time - """ - query_exception = None - query_results = None - try: - query_results = [s.to_dict() for s in query.get()] - except Exception as e: - query_exception = e - pipeline = query.pipeline() - if query_exception: - # ensure that the pipeline uses same error as query - with pytest.raises(query_exception): - pipeline.execute() - else: - # ensure results match query - pipeline_results = [s.to_dict() for s in pipeline.execute()] - assert query_results == pipeline_results - @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_collections(client, database): collections = list(client.collections()) @@ -129,7 +104,7 @@ def test_collections_w_import(database): ) @pytest.mark.parametrize("method", ["stream", "get"]) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_collection_stream_or_get_w_no_explain_options(database, query_docs, method, verify_pipeline): +def test_collection_stream_or_get_w_no_explain_options(database, query_docs, method): from google.cloud.firestore_v1.query_profile import QueryExplainError collection, _, _ = query_docs @@ -144,7 +119,7 @@ def test_collection_stream_or_get_w_no_explain_options(database, query_docs, met match="explain_options not set on query.", ): results.get_explain_metrics() - verify_pipeline(collection) + @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
@@ -152,7 +127,7 @@ def test_collection_stream_or_get_w_no_explain_options(database, query_docs, met @pytest.mark.parametrize("method", ["get", "stream"]) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_collection_stream_or_get_w_explain_options_analyze_false( - database, method, query_docs, verify_pipeline + database, method, query_docs ): from google.cloud.firestore_v1.query_profile import ( ExplainMetrics, @@ -182,7 +157,6 @@ def test_collection_stream_or_get_w_explain_options_analyze_false( match="execution_stats not available when explain_options.analyze=False", ): explain_metrics.execution_stats - verify_pipeline(collection) @pytest.mark.skipif( @@ -191,7 +165,7 @@ def test_collection_stream_or_get_w_explain_options_analyze_false( @pytest.mark.parametrize("method", ["get", "stream"]) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_collection_stream_or_get_w_explain_options_analyze_true( - database, method, query_docs, verify_pipeline + database, method, query_docs ): from google.cloud.firestore_v1.query_profile import ( ExecutionStats, @@ -241,7 +215,6 @@ def test_collection_stream_or_get_w_explain_options_analyze_true( assert "documents_scanned" in execution_stats.debug_stats assert "index_entries_scanned" in execution_stats.debug_stats assert len(execution_stats.debug_stats) > 0 - verify_pipeline(collection) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) @@ -1140,7 +1113,7 @@ def query(collection): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_legacy_where(query_docs, database, verify_pipeline): +def test_query_stream_legacy_where(query_docs, database): """Assert the legacy code still works and returns value""" collection, stored, allowed_vals = query_docs with pytest.warns( @@ -1153,11 +1126,10 @@ def test_query_stream_legacy_where(query_docs, database, verify_pipeline): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_simple_field_eq_op(query_docs, database, verify_pipeline): +def test_query_stream_w_simple_field_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1165,13 +1137,10 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database, verify_pipeline for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_simple_field_array_contains_op( - query_docs, database, verify_pipeline -): +def test_query_stream_w_simple_field_array_contains_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1179,11 +1148,10 @@ def test_query_stream_w_simple_field_array_contains_op( for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_simple_field_in_op(query_docs, database, verify_pipeline): +def 
test_query_stream_w_simple_field_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("a", "in", [1, num_vals + 100])) @@ -1192,11 +1160,10 @@ def test_query_stream_w_simple_field_in_op(query_docs, database, verify_pipeline for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_not_eq_op(query_docs, database, verify_pipeline): +def test_query_stream_w_not_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", "!=", 4)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1215,11 +1182,10 @@ def test_query_stream_w_not_eq_op(query_docs, database, verify_pipeline): ] ) assert expected_ab_pairs == ab_pairs2 - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_simple_not_in_op(query_docs, database, verify_pipeline): +def test_query_stream_w_simple_not_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1228,13 +1194,10 @@ def test_query_stream_w_simple_not_in_op(query_docs, database, verify_pipeline): values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} assert len(values) == 22 - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_simple_field_array_contains_any_op( - query_docs, database, verify_pipeline -): +def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where( @@ -1245,11 +1208,10 @@ def test_query_stream_w_simple_field_array_contains_any_op( for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_order_by(query_docs, database, verify_pipeline): +def test_query_stream_w_order_by(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()] @@ -1260,11 +1222,10 @@ def test_query_stream_w_order_by(query_docs, database, verify_pipeline): b_vals.append(value["b"]) # Make sure the ``b``-values are in DESCENDING order. 
assert sorted(b_vals, reverse=True) == b_vals - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_field_path(query_docs, database, verify_pipeline): +def test_query_stream_w_field_path(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} @@ -1286,7 +1247,7 @@ def test_query_stream_w_field_path(query_docs, database, verify_pipeline): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_start_end_cursor(query_docs, database, verify_pipeline): +def test_query_stream_w_start_end_cursor(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = ( @@ -1299,21 +1260,19 @@ def test_query_stream_w_start_end_cursor(query_docs, database, verify_pipeline): for key, value in values: assert stored[key] == value assert value["a"] == num_vals - 2 - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_wo_results(query_docs, database, verify_pipeline): +def test_query_stream_wo_results(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) values = list(query.stream()) assert len(values) == 0 - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_projection(query_docs, database, verify_pipeline): +def test_query_stream_w_projection(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "<=", 1)).select( @@ -1327,11 +1286,10 @@ def test_query_stream_w_projection(query_docs, database, verify_pipeline): "stats": {"product": stored[key]["stats"]["product"]}, } assert expected == value - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_multiple_filters(query_docs, database, verify_pipeline): +def test_query_stream_w_multiple_filters(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( filter=FieldFilter("stats.product", "<", 10) @@ -1348,11 +1306,10 @@ def test_query_stream_w_multiple_filters(query_docs, database, verify_pipeline): assert stored[key] == value pair = (value["a"], value["b"]) assert pair in matching_pairs - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_stream_w_offset(query_docs, database, verify_pipeline): +def test_query_stream_w_offset(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) offset = 3 @@ -1365,7 +1322,6 @@ def test_query_stream_w_offset(query_docs, database, verify_pipeline): for key, value in values.items(): assert stored[key] == value assert value["b"] == 2 - verify_pipeline(query) @pytest.mark.skipif( @@ -1495,7 +1451,7 @@ def test_query_stream_or_get_w_explain_options_analyze_false( @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_order_dot_key(client, cleanup, database, verify_pipeline): +def test_query_with_order_dot_key(client, cleanup, database): db = client 
collection_id = "collek" + UNIQUE_RESOURCE_ID collection = db.collection(collection_id) @@ -1507,38 +1463,33 @@ def test_query_with_order_dot_key(client, cleanup, database, verify_pipeline): query = collection.order_by("wordcount.page1").limit(3) data = [doc.to_dict()["wordcount"]["page1"] for doc in query.stream()] assert [100, 110, 120] == data - verify_pipeline(query) - query2 = collection.order_by("wordcount.page1").limit(3) - for snapshot in query2.stream(): + for snapshot in collection.order_by("wordcount.page1").limit(3).stream(): last_value = snapshot.get("wordcount.page1") - verify_pipeline(query2) cursor_with_nested_keys = {"wordcount": {"page1": last_value}} - query3 = ( + found = list( collection.order_by("wordcount.page1") .start_after(cursor_with_nested_keys) .limit(3) + .stream() ) - found = list(query3.stream()) found_data = [ {"count": 30, "wordcount": {"page1": 130}}, {"count": 40, "wordcount": {"page1": 140}}, {"count": 50, "wordcount": {"page1": 150}}, ] assert found_data == [snap.to_dict() for snap in found] - verify_pipeline(query3) cursor_with_dotted_paths = {"wordcount.page1": last_value} - query4 = ( + cursor_with_key_data = list( collection.order_by("wordcount.page1") .start_after(cursor_with_dotted_paths) .limit(3) + .stream() ) - cursor_with_key_data = list(query4.stream()) assert found_data == [snap.to_dict() for snap in cursor_with_key_data] - verify_pipeline(query4) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_unary(client, cleanup, database, verify_pipeline): +def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) field_name = "foo" @@ -1559,7 +1510,6 @@ def test_query_unary(client, cleanup, database, verify_pipeline): snapshot0 = values0[0] assert snapshot0.reference._path == document0._path assert snapshot0.to_dict() == {field_name: None} - verify_pipeline(query0) # 1. Query for a NAN. 
query1 = collection.where(filter=FieldFilter(field_name, "==", nan_val)) @@ -1570,11 +1520,10 @@ def test_query_unary(client, cleanup, database, verify_pipeline): data1 = snapshot1.to_dict() assert len(data1) == 1 assert math.isnan(data1[field_name]) - verify_pipeline(query1) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_collection_group_queries(client, cleanup, database, verify_pipeline): +def test_collection_group_queries(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -1604,13 +1553,10 @@ def test_collection_group_queries(client, cleanup, database, verify_pipeline): found = [snapshot.id for snapshot in snapshots] expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"] assert found == expected - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_collection_group_queries_startat_endat( - client, cleanup, database, verify_pipeline -): +def test_collection_group_queries_startat_endat(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -1640,7 +1586,6 @@ def test_collection_group_queries_startat_endat( snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) - verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -1651,11 +1596,10 @@ def test_collection_group_queries_startat_endat( snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2"]) - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_collection_group_queries_filters(client, cleanup, database, verify_pipeline): +def test_collection_group_queries_filters(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID doc_paths = [ @@ -1697,7 +1641,6 @@ def test_collection_group_queries_filters(client, cleanup, database, verify_pipe snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) - verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -1719,7 +1662,6 @@ def test_collection_group_queries_filters(client, cleanup, database, verify_pipe snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2"]) - verify_pipeline(query) @pytest.mark.skipif( @@ -1985,7 +1927,7 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_watch_query(client, cleanup, database, verify_pipeline): +def test_watch_query(client, cleanup, database): db = client collection_ref = db.collection("wq-users" + UNIQUE_RESOURCE_ID) doc_ref = collection_ref.document("alovelace") @@ -2002,10 +1944,10 @@ def on_snapshot(docs, changes, read_time): on_snapshot.called_count += 1 # A snapshot should return the same thing as if a query ran now. 
- query_ran_query = collection_ref.where(filter=FieldFilter("first", "==", "Ada")) - query_ran = query_ran_query.stream() + query_ran = collection_ref.where( + filter=FieldFilter("first", "==", "Ada") + ).stream() assert len(docs) == len([i for i in query_ran]) - verify_pipeline(query_ran_query) on_snapshot.called_count = 0 @@ -2206,12 +2148,11 @@ def test_recursive_delete_serialized_empty(client, cleanup, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_recursive_query(client, cleanup, database, verify_pipeline): +def test_recursive_query(client, cleanup, database): col_id: str = f"philosophers-recursive-query{UNIQUE_RESOURCE_ID}" _persist_documents(client, col_id, philosophers_data_set, cleanup) - query = client.collection_group(col_id).recursive() - ids = [doc.id for doc in query.get()] + ids = [doc.id for doc in client.collection_group(col_id).recursive().get()] expected_ids = [ # Aristotle doc and subdocs @@ -2243,18 +2184,16 @@ def test_recursive_query(client, cleanup, database, verify_pipeline): f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" ) assert ids[index] == expected_ids[index], error_msg - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_nested_recursive_query(client, cleanup, database, verify_pipeline): +def test_nested_recursive_query(client, cleanup, database): col_id: str = f"philosophers-nested-recursive-query{UNIQUE_RESOURCE_ID}" _persist_documents(client, col_id, philosophers_data_set, cleanup) collection_ref = client.collection(col_id) aristotle = collection_ref.document("Aristotle") - query = aristotle.collection("pets").recursive() - ids = [doc.id for doc in query.get()] + ids = [doc.id for doc in aristotle.collection("pets").recursive().get()] expected_ids = [ # Aristotle pets @@ -2269,7 +2208,7 @@ def test_nested_recursive_query(client, cleanup, database, verify_pipeline): f"Expected '{expected_ids[index]}' at spot {index}, " "got '{ids[index]}'" ) assert ids[index] == expected_ids[index], error_msg - verify_pipeline(query) + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_chunked_query(client, cleanup, database): @@ -2348,7 +2287,7 @@ def test_chunked_and_recursive(client, cleanup, database): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_watch_query_order(client, cleanup, database, verify_pipeline): +def test_watch_query_order(client, cleanup, database): db = client collection_ref = db.collection("users") doc_ref1 = collection_ref.document("alovelace" + UNIQUE_RESOURCE_ID) @@ -2384,7 +2323,6 @@ def on_snapshot(docs, changes, read_time): ), "expect the sort order to match, born" on_snapshot.called_count += 1 on_snapshot.last_doc_count = len(docs) - verify_pipeline(query_ref) except Exception as e: on_snapshot.failed = e @@ -2425,7 +2363,7 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_repro_429(client, cleanup, database, verify_pipeline): +def test_repro_429(client, cleanup, database): # See: https://github.com/googleapis/python-firestore/issues/429 now = datetime.datetime.now(tz=datetime.timezone.utc) collection = client.collection("repro-429" + UNIQUE_RESOURCE_ID) @@ -2450,8 +2388,6 @@ def test_repro_429(client, cleanup, database, verify_pipeline): for snapshot in query2.stream(): print(f"id: {snapshot.id}") - verify_pipeline(query) - 
verify_pipeline(query2) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) @@ -3021,7 +2957,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_false( @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_and_composite_filter(collection, database, verify_pipeline): +def test_query_with_and_composite_filter(collection, database): and_filter = And( filters=[ FieldFilter("stats.product", ">", 5), @@ -3033,11 +2969,10 @@ def test_query_with_and_composite_filter(collection, database, verify_pipeline): for result in query.stream(): assert result.get("stats.product") > 5 assert result.get("stats.product") < 10 - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_or_composite_filter(collection, database, verify_pipeline): +def test_query_with_or_composite_filter(collection, database): or_filter = Or( filters=[ FieldFilter("stats.product", ">", 5), @@ -3057,11 +2992,10 @@ def test_query_with_or_composite_filter(collection, database, verify_pipeline): assert gt_5 > 0 assert lt_10 > 0 - verify_pipeline(query) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_with_complex_composite_filter(collection, database, verify_pipeline): +def test_query_with_complex_composite_filter(collection, database): field_filter = FieldFilter("b", "==", 0) or_filter = Or( filters=[FieldFilter("stats.sum", "==", 0), FieldFilter("stats.sum", "==", 4)] @@ -3081,7 +3015,6 @@ def test_query_with_complex_composite_filter(collection, database, verify_pipeli assert sum_0 > 0 assert sum_4 > 0 - verify_pipeline(query) # b == 3 || (stats.sum == 4 && a == 4) comp_filter = Or( @@ -3104,7 +3037,6 @@ def test_query_with_complex_composite_filter(collection, database, verify_pipeli assert b_3 is True assert b_not_3 is True - verify_pipeline(query) @pytest.mark.parametrize( @@ -3113,7 +3045,7 @@ def test_query_with_complex_composite_filter(collection, database, verify_pipeli ) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_aggregation_query_in_transaction( - client, cleanup, database, aggregation_type, aggregation_args, expected, verify_pipeline + client, cleanup, database, aggregation_type, aggregation_args, expected ): """ Test creating an aggregation query inside a transaction @@ -3147,7 +3079,6 @@ def in_transaction(transaction): assert len(result[0]) == 1 assert result[0][0].value == expected inner_fn_ran = True - verify_pipeline(aggregation_query) in_transaction(transaction) # make sure we didn't skip assertions in inner function @@ -3155,7 +3086,7 @@ def in_transaction(transaction): @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_or_query_in_transaction(client, cleanup, database, verify_pipeline): +def test_or_query_in_transaction(client, cleanup, database): """ Test running or query inside a transaction. Should pass transaction id along with request """ @@ -3193,7 +3124,6 @@ def in_transaction(transaction): result[0].get("b") == 2 and result[1].get("b") == 1 ) inner_fn_ran = True - verify_pipeline(query) in_transaction(transaction) # make sure we didn't skip assertions in inner function @@ -3204,7 +3134,7 @@ def in_transaction(transaction): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_query_in_transaction_with_explain_options(client, cleanup, database, verify_pipeline): +def test_query_in_transaction_with_explain_options(client, cleanup, database): """ Test query profiling in transactions. """ @@ -3249,7 +3179,6 @@ def in_transaction(transaction): assert explain_metrics.execution_stats is not None inner_fn_ran = True - verify_pipeline(query) in_transaction(transaction) # make sure we didn't skip assertions in inner function @@ -3258,7 +3187,7 @@ def in_transaction(transaction): @pytest.mark.parametrize("with_rollback,expected", [(True, 2), (False, 3)]) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_transaction_rollback(client, cleanup, database, with_rollback, expected, verify_pipeline): +def test_transaction_rollback(client, cleanup, database, with_rollback, expected): """ Create a document in a transaction that is rolled back Document should not show up in later queries @@ -3302,4 +3231,3 @@ def in_transaction(transaction, rollback): assert len(result) == 1 assert len(result[0]) == 1 assert result[0][0].value == expected - Vector(query) diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index d6762d846..24caa5e40 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -1952,174 +1952,6 @@ def test__collection_group_query_response_to_snapshot_response(): assert snapshot.update_time == response_pb._pb.document.update_time -def test__query_pipeline_decendants(): - from google.cloud.firestore_v1 import pipeline_stages - - client = make_client() - query = client.collection_group("my_col") - pipeline = query.pipeline() - - assert len(pipeline.stages) == 1 - stage = pipeline.stages[0] - assert isinstance(stage, pipeline_stages.CollectionGroup) - assert stage.collection_id == "my_col" - - -@pytest.mark.parametrize("in_path,out_path",[ - ("my_col/doc/", "/my_col/doc/"), - ("/my_col/doc", "/my_col/doc"), - ("my_col/doc/sub_col", "/my_col/doc/sub_col"), -]) -def test__query_pipeline_no_decendants(in_path, out_path): - from google.cloud.firestore_v1 import pipeline_stages - - client = make_client() - query = client.collection(in_path) - pipeline = query.pipeline() - - assert len(pipeline.stages) == 1 - stage = pipeline.stages[0] - assert isinstance(stage, pipeline_stages.Collection) - assert stage.path == out_path - - -def test__query_pipeline_composite_filter(): - from google.cloud.firestore_v1 import FieldFilter - from google.cloud.firestore_v1 import pipeline_expressions as expr - from google.cloud.firestore_v1 import pipeline_stages - - client = make_client() - in_filter = FieldFilter("field_a", "==", "value_a") - query = client.collection("my_col").where(filter=in_filter) - with mock.patch.object(expr.FilterCondition, "_from_query_filter_pb") as convert_mock: - pipeline = query.pipeline() - convert_mock.assert_called_once_with(in_filter._to_pb(), client) - assert len(pipeline.stages) == 2 - stage = pipeline.stages[1] - assert isinstance(stage, pipeline_stages.Where) - assert stage.condition == convert_mock.return_value - - -def test__query_pipeline_projections(): - from google.cloud.firestore_v1 import pipeline_stages - - client = make_client() - query = client.collection("my_col").select(["field_a", "field_b.c"]) - pipeline = query.pipeline() - - assert len(pipeline.stages) == 2 - stage = pipeline.stages[1] - assert isinstance(stage, pipeline_stages.Select) - assert 
len(stage.projections) == 2 - assert stage.projections[0].path == "field_a" - assert stage.projections[1].path == "field_b.c" - - -def test__query_pipeline_order_exists_multiple(): - from google.cloud.firestore_v1 import pipeline_expressions as expr - from google.cloud.firestore_v1 import pipeline_stages - - client = make_client() - query = client.collection("my_col").order_by("field_a").order_by("field_b") - pipeline = query.pipeline() - - # should have collection, where, and sort - # we're interested in where - assert len(pipeline.stages) == 3 - where_stage = pipeline.stages[1] - assert isinstance(where_stage, pipeline_stages.Where) - # should have and with both orderings - assert isinstance(where_stage.condition, expr.And) - assert len(where_stage.condition.params) == 2 - operands = [p for p in where_stage.condition.params] - assert isinstance(operands[0], expr.Exists) - assert operands[0].params[0].path == "field_a" - assert isinstance(operands[1], expr.Exists) - assert operands[1].params[0].path == "field_b" - -def test__query_pipeline_order_exists_single(): - from google.cloud.firestore_v1 import pipeline_expressions as expr - from google.cloud.firestore_v1 import pipeline_stages - - client = make_client() - query_single = client.collection("my_col").order_by("field_c") - pipeline_single = query_single.pipeline() - - # should have collection, where, and sort - # we're interested in where - assert len(pipeline_single.stages) == 3 - where_stage_single = pipeline_single.stages[1] - assert isinstance(where_stage_single, pipeline_stages.Where) - assert isinstance(where_stage_single.condition, expr.Exists) - assert where_stage_single.condition.params[0].path == "field_c" - - -def test__query_pipeline_order_sorts(): - from google.cloud.firestore_v1 import pipeline_expressions as expr - from google.cloud.firestore_v1 import pipeline_stages - from google.cloud.firestore_v1.base_query import BaseQuery - - client = make_client() - query = ( - client.collection("my_col") - .order_by("field_a", direction=BaseQuery.ASCENDING) - .order_by("field_b", direction=BaseQuery.DESCENDING) - ) - pipeline = query.pipeline() - - assert len(pipeline.stages) == 3 - sort_stage = pipeline.stages[2] - assert isinstance(sort_stage, pipeline_stages.Sort) - assert len(sort_stage.orders) == 2 - assert isinstance(sort_stage.orders[0], expr.Ordering) - assert sort_stage.orders[0].expr.path == "field_a" - assert sort_stage.orders[0].order_dir == expr.Ordering.Direction.ASCENDING - assert isinstance(sort_stage.orders[1], expr.Ordering) - assert sort_stage.orders[1].expr.path == "field_b" - assert sort_stage.orders[1].order_dir == expr.Ordering.Direction.DESCENDING - - -def test__query_pipeline_cursor(): - client = make_client() - query_start = client.collection("my_col").start_at({"field_a": "value"}) - with pytest.raises(NotImplementedError, match="cursors"): - query_start.pipeline() - - query_end = client.collection("my_col").end_at({"field_a": "value"}) - with pytest.raises(NotImplementedError, match="cursors"): - query_end.pipeline() - - query_limit_last = client.collection("my_col").limit_to_last(10) - with pytest.raises(NotImplementedError, match="limitToLast"): - query_limit_last.pipeline() - - -def test__query_pipeline_limit(): - from google.cloud.firestore_v1 import pipeline_stages - - client = make_client() - query = client.collection("my_col").limit(15) - pipeline = query.pipeline() - - assert len(pipeline.stages) == 2 - stage = pipeline.stages[1] - assert isinstance(stage, pipeline_stages.Limit) - assert 
stage.limit == 15 - - -def test__query_pipeline_offset(): - from google.cloud.firestore_v1 import pipeline_stages - - client = make_client() - query = client.collection("my_col").offset(5) - pipeline = query.pipeline() - - assert len(pipeline.stages) == 2 - stage = pipeline.stages[1] - assert isinstance(stage, pipeline_stages.Offset) - assert stage.offset == 5 - - def _make_order_pb(field_path, direction): from google.cloud.firestore_v1.types import query diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py deleted file mode 100644 index 8d46e6c53..000000000 --- a/tests/unit/v1/test_pipeline_expressions.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright 2025 Google LLC All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# limitations under the License. - -import pytest -import mock - -from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1.types import document as document_pb -from google.cloud.firestore_v1.types import query as query_pb -from google.cloud.firestore_v1.pipeline_expressions import FilterCondition -from google.cloud.firestore_v1 import pipeline_expressions as expr - - -@pytest.fixture -def mock_client(): - client = mock.Mock(spec=["_database_string", "collection"]) - client._database_string = "projects/p/databases/d" - return client - - -class TestFilterCondition: - - def test__from_query_filter_pb_composite_filter_or(self, mock_client): - """ - test composite OR filters - - should create an or statement, made up of ands checking of existance of relevant fields - """ - filter1_pb = query_pb.StructuredQuery.FieldFilter( - field=query_pb.StructuredQuery.FieldReference(field_path="field1"), - op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, - value=_helpers.encode_value("val1"), - ) - filter2_pb = query_pb.StructuredQuery.UnaryFilter( - field=query_pb.StructuredQuery.FieldReference(field_path="field2"), - op=query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, - ) - - composite_pb = query_pb.StructuredQuery.CompositeFilter( - op=query_pb.StructuredQuery.CompositeFilter.Operator.OR, - filters=[ - query_pb.StructuredQuery.Filter(field_filter=filter1_pb), - query_pb.StructuredQuery.Filter(unary_filter=filter2_pb), - ], - ) - wrapped_filter_pb = query_pb.StructuredQuery.Filter(composite_filter=composite_pb) - - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) - - # should include existance checks - expected_cond1 = expr.And(expr.Exists(expr.Field.of("field1")), expr.Eq(expr.Field.of("field1"), expr.Constant("val1"))) - expected_cond2 = expr.And(expr.Exists(expr.Field.of("field2")), expr.Eq(expr.Field.of("field2"), expr.Constant(None))) - expected = expr.Or(expected_cond1, expected_cond2) - - assert repr(result) == repr(expected) - - def test__from_query_filter_pb_composite_filter_and(self, mock_client): - """ - test composite AND filters - - should create an and statement, made up of ands checking of existance of relevant fields - """ - filter1_pb = query_pb.StructuredQuery.FieldFilter( - field=query_pb.StructuredQuery.FieldReference(field_path="field1"), - 
op=query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=_helpers.encode_value(100), - ) - filter2_pb = query_pb.StructuredQuery.FieldFilter( - field=query_pb.StructuredQuery.FieldReference(field_path="field2"), - op=query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, - value=_helpers.encode_value(200), - ) - - composite_pb = query_pb.StructuredQuery.CompositeFilter( - op=query_pb.StructuredQuery.CompositeFilter.Operator.AND, - filters=[ - query_pb.StructuredQuery.Filter(field_filter=filter1_pb), - query_pb.StructuredQuery.Filter(field_filter=filter2_pb), - ], - ) - wrapped_filter_pb = query_pb.StructuredQuery.Filter(composite_filter=composite_pb) - - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) - - # should include existance checks - expected_cond1 = expr.And(expr.Exists(expr.Field.of("field1")), expr.Gt(expr.Field.of("field1"), expr.Constant(100))) - expected_cond2 = expr.And(expr.Exists(expr.Field.of("field2")), expr.Lt(expr.Field.of("field2"), expr.Constant(200))) - expected = expr.And(expected_cond1, expected_cond2) - assert repr(result) == repr(expected) - - def test__from_query_filter_pb_composite_filter_nested(self, mock_client): - """ - test composite filter with complex nested checks - """ - # OR (field1 == "val1", AND(field2 > 10, field3 IS NOT NULL)) - filter1_pb = query_pb.StructuredQuery.FieldFilter( - field=query_pb.StructuredQuery.FieldReference(field_path="field1"), - op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, - value=_helpers.encode_value("val1"), - ) - filter2_pb = query_pb.StructuredQuery.FieldFilter( - field=query_pb.StructuredQuery.FieldReference(field_path="field2"), - op=query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=_helpers.encode_value(10), - ) - filter3_pb = query_pb.StructuredQuery.UnaryFilter( - field=query_pb.StructuredQuery.FieldReference(field_path="field3"), - op=query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, - ) - inner_and_pb = query_pb.StructuredQuery.CompositeFilter( - op=query_pb.StructuredQuery.CompositeFilter.Operator.AND, - filters=[ - query_pb.StructuredQuery.Filter(field_filter=filter2_pb), - query_pb.StructuredQuery.Filter(unary_filter=filter3_pb), - ], - ) - outer_or_pb = query_pb.StructuredQuery.CompositeFilter( - op=query_pb.StructuredQuery.CompositeFilter.Operator.OR, - filters=[ - query_pb.StructuredQuery.Filter(field_filter=filter1_pb), - query_pb.StructuredQuery.Filter(composite_filter=inner_and_pb), - ], - ) - wrapped_filter_pb = query_pb.StructuredQuery.Filter(composite_filter=outer_or_pb) - - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) - - expected_cond1 = expr.And(expr.Exists(expr.Field.of("field1")), expr.Eq(expr.Field.of("field1"), expr.Constant("val1"))) - expected_cond2 = expr.And(expr.Exists(expr.Field.of("field2")), expr.Gt(expr.Field.of("field2"), expr.Constant(10))) - expected_cond3 = expr.And(expr.Exists(expr.Field.of("field3")), expr.Not(expr.Eq(expr.Field.of("field3"), expr.Constant(None)))) - expected_inner_and = expr.And(expected_cond2, expected_cond3) - expected_outer_or = expr.Or(expected_cond1, expected_inner_and) - - assert repr(result) == repr(expected_outer_or) - - - def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): - """ - check composite filter with unsupported operator type - """ - filter1_pb = query_pb.StructuredQuery.FieldFilter( - field=query_pb.StructuredQuery.FieldReference(field_path="field1"), - 
op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, - value=_helpers.encode_value("val1"), - ) - composite_pb = query_pb.StructuredQuery.CompositeFilter( - op=query_pb.StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED, - filters=[query_pb.StructuredQuery.Filter(field_filter=filter1_pb)], - ) - wrapped_filter_pb = query_pb.StructuredQuery.Filter(composite_filter=composite_pb) - - with pytest.raises(TypeError, match="Unexpected CompositeFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) - - @pytest.mark.parametrize( - "op_enum, expected_expr_func", - [ - (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, expr.IsNaN), - (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, lambda f: expr.Not(f.is_nan())), - (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, lambda f: f.eq(None)), - (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, lambda f: expr.Not(f.eq(None))), - ], - ) - def test__from_query_filter_pb_unary_filter(self, mock_client, op_enum, expected_expr_func): - """ - test supported unary filters - """ - field_path = "unary_field" - filter_pb = query_pb.StructuredQuery.UnaryFilter( - field=query_pb.StructuredQuery.FieldReference(field_path=field_path), - op=op_enum, - ) - wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) - - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) - - field_expr_inst = expr.Field.of(field_path) - expected_condition = expected_expr_func(field_expr_inst) - # should include existance checks - expected = expr.And(expr.Exists(field_expr_inst), expected_condition) - - assert repr(result) == repr(expected) - - def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): - """ - check unary filter with unsupported operator type - """ - field_path = "unary_field" - filter_pb = query_pb.StructuredQuery.UnaryFilter( - field=query_pb.StructuredQuery.FieldReference(field_path=field_path), - op=query_pb.StructuredQuery.UnaryFilter.Operator.OPERATOR_UNSPECIFIED, # Unknown op - ) - wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) - - with pytest.raises(TypeError, match="Unexpected UnaryFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) - - - @pytest.mark.parametrize( - "op_enum, value, expected_expr_func", - [ - (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, expr.Lt), - (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL, 10, expr.Lte), - (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, expr.Gt), - (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL, 10, expr.Gte), - (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, expr.Eq), - (query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, expr.Neq), - (query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS, 10, expr.ArrayContains), - (query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY, [10, 20], expr.ArrayContainsAny), - (query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], expr.In), - (query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, [10, 20], lambda f, v: expr.Not(f.in_any(v))), - ], - ) - def test__from_query_filter_pb_field_filter(self, mock_client, op_enum, value, expected_expr_func): - """ - test supported field filters - """ - field_path = "test_field" - value_pb = _helpers.encode_value(value) - filter_pb = query_pb.StructuredQuery.FieldFilter( - 
field=query_pb.StructuredQuery.FieldReference(field_path=field_path),
-            op=op_enum,
-            value=value_pb,
-        )
-        wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb)
-
-        result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
-
-        field_expr = expr.Field.of(field_path)
-        # convert values into constants
-        value = [expr.Constant(e) for e in value] if isinstance(value, list) else expr.Constant(value)
-        expected_condition = expected_expr_func(field_expr, value)
-        # should include existance checks
-        expected = expr.And(expr.Exists(field_expr), expected_condition)
-
-        assert repr(result) == repr(expected)
-
-    def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client):
-        """
-        check field filter with unsupported operator type
-        """
-        field_path = "test_field"
-        value_pb = _helpers.encode_value(10)
-        filter_pb = query_pb.StructuredQuery.FieldFilter(
-            field=query_pb.StructuredQuery.FieldReference(field_path=field_path),
-            op=query_pb.StructuredQuery.FieldFilter.Operator.OPERATOR_UNSPECIFIED,  # Unknown op
-            value=value_pb,
-        )
-        wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb)
-
-        with pytest.raises(TypeError, match="Unexpected FieldFilter operator type"):
-            FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
-
-    def test__from_query_filter_pb_unknown_filter_type(self, mock_client):
-        """
-        test with unsupported filter type
-        """
-        # Test with an unexpected protobuf type
-        with pytest.raises(TypeError, match="Unexpected filter type"):
-            FilterCondition._from_query_filter_pb(document_pb.Value(), mock_client)

From d8dc10f383fe14076da72487e2b0bfa113f0ff09 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 6 May 2025 14:49:25 -0700
Subject: [PATCH 103/131] added PipelineStages

---
 google/cloud/firestore_v1/async_client.py    |  5 +-
 google/cloud/firestore_v1/base_client.py     |  9 ++-
 google/cloud/firestore_v1/client.py          |  5 +-
 google/cloud/firestore_v1/pipeline_source.py | 85 ++++++++++++++++++++
 google/cloud/firestore_v1/pipeline_stages.py | 54 +++++++++++++
 noxfile.py                                   |  1 +
 6 files changed, 153 insertions(+), 6 deletions(-)
 create mode 100644 google/cloud/firestore_v1/pipeline_source.py

diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py
index 10aa02c69..13a00398e 100644
--- a/google/cloud/firestore_v1/async_client.py
+++ b/google/cloud/firestore_v1/async_client.py
@@ -414,5 +414,6 @@ def transaction(self, **kwargs) -> AsyncTransaction:
         """
         return AsyncTransaction(self, **kwargs)
 
-    def pipeline(self, *stages) -> AsyncPipeline:
-        return AsyncPipeline(self, *stages)
+    @property
+    def _pipeline_cls(self):
+        return AsyncPipeline
\ No newline at end of file
diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py
index 585de4ce2..767c577d3 100644
--- a/google/cloud/firestore_v1/base_client.py
+++ b/google/cloud/firestore_v1/base_client.py
@@ -34,6 +34,7 @@
     Optional,
     Tuple,
     Union,
+    Type
 )
 
 import google.api_core.client_options
@@ -57,6 +58,7 @@
 from google.cloud.firestore_v1.base_transaction import BaseTransaction
 from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions
 from google.cloud.firestore_v1.field_path import render_field_path
+from google.cloud.firestore_v1.pipeline import PipelineSource, _BasePipeline
 
 DEFAULT_DATABASE = "(default)"
 """str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`."""
@@ -475,9 +477,12 @@ def batch(self) -> BaseWriteBatch:
     def transaction(self, **kwargs) -> BaseTransaction:
         raise NotImplementedError
 
-    def pipeline(self, *stages):
-        raise NotImplementedError
+    def pipeline(self) -> PipelineSource:
+        return PipelineSource(self)
 
+    @property
+    def _pipeline_cls(self) -> Type["_BasePipeline"]:
+        raise NotImplementedError
 
 def _reference_info(references: list) -> Tuple[list, dict]:
     """Get information about document references.
diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py
index ed1a2543f..f28b8eed5 100644
--- a/google/cloud/firestore_v1/client.py
+++ b/google/cloud/firestore_v1/client.py
@@ -406,5 +406,6 @@ def transaction(self, **kwargs) -> Transaction:
         """
         return Transaction(self, **kwargs)
 
-    def pipeline(self, *stages) -> Pipeline:
-        return Pipeline(self, *stages)
+    @property
+    def _pipeline_cls(self):
+        return Pipeline
\ No newline at end of file
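For orientation, the three client diffs above wire up a single factory protocol: `BaseClient.pipeline()` hands back a `PipelineSource` bound to the client, and each concrete client names the pipeline class it builds through the `_pipeline_cls` property (the concrete overrides must `return` the class; the abstract base keeps `raise NotImplementedError`). Below is a minimal, runnable sketch of that dispatch — the `_Stub*` names are illustrative stand-ins, not the real classes in this patch:

    # Illustrative stand-ins only -- not the real firestore classes.
    class _StubPipeline:
        """Plays the role of Pipeline / AsyncPipeline."""

        def __init__(self, client, *stages):
            self.client = client
            self.stages = stages

    class _StubSource:
        """Plays the role of PipelineSource."""

        def __init__(self, client):
            self.client = client

        def collection(self, path):
            # The source never hard-codes a pipeline type: asking the client
            # lets sync and async clients share one factory implementation.
            return self.client._pipeline_cls(self.client, ("collection", path))

    class _StubClient:
        """Plays the role of Client / AsyncClient."""

        def pipeline(self):
            return _StubSource(self)

        @property
        def _pipeline_cls(self):
            return _StubPipeline

    pipeline = _StubClient().pipeline().collection("books")
    assert isinstance(pipeline, _StubPipeline)

The design keeps all source-building logic in one shared class while still returning the right concrete pipeline type for each client flavor.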
diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py
new file mode 100644
index 000000000..f1ede9302
--- /dev/null
+++ b/google/cloud/firestore_v1/pipeline_source.py
@@ -0,0 +1,85 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+from typing import Generic, TypeVar, TYPE_CHECKING
+from google.cloud.firestore_v1 import pipeline_stages as stages
+from google.cloud.firestore_v1.base_pipeline import _BasePipeline
+
+if TYPE_CHECKING:
+    from google.cloud.firestore_v1.client import Client
+    from google.cloud.firestore_v1.async_client import AsyncClient
+    from google.cloud.firestore_v1.document import DocumentReference
+
+
+
+PipelineType = TypeVar('PipelineType', bound=_BasePipeline)
+
+class PipelineSource(Generic[PipelineType]):
+    """
+    A factory for creating Pipeline instances, which provide a framework for building data
+    transformation and query pipelines for Firestore.
+
+    Start by calling client.pipeline() to obtain an instance of PipelineSource.
+    From there, you can use the provided methods, such as .collection(), to
+    specify the data source for your pipeline.
+
+    This class is typically used to start building Firestore pipelines. It allows you to define
+    the initial data source for a pipeline.
+    """
+
+    def __init__(self, client: Client | AsyncClient):
+        self.client = client
+
+    def collection(self, path: str) -> PipelineType:
+        """
+        Creates a new Pipeline that operates on a specified Firestore collection.
+
+        Args:
+            path: The path to the Firestore collection (e.g., "users")
+        Returns:
+            a new pipeline instance targeting the specified collection
+        """
+        return self.client._pipeline_cls(self.client, stages.Collection(path))
+
+    def collection_group(self, collection_id: str) -> PipelineType:
+        """
+        Creates a new Pipeline that operates on all documents in a collection group.
+
+        Args:
+            collection_id: The ID of the collection group
+        Returns:
+            a new pipeline instance targeting the specified collection group
+        """
+        return self.client._pipeline_cls(self.client, stages.CollectionGroup(collection_id))
+
+    def database(self) -> PipelineType:
+        """
+        Creates a new Pipeline that operates on all documents in the Firestore database.
+
+        Returns:
+            a new pipeline instance targeting the entire database
+        """
+        return self.client._pipeline_cls(self.client, stages.Database())
+
+    def documents(self, *docs: "DocumentReference") -> PipelineType:
+        """
+        Creates a new Pipeline that operates on a specific set of Firestore documents.
+
+        Args:
+            docs: The DocumentReference instances representing the documents to include in the pipeline.
+        Returns:
+            a new pipeline instance targeting the specified documents
+        """
+        return self.client._pipeline_cls(self.client, stages.Documents.of(*docs))
\ No newline at end of file
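Read from application code, the four entry points above are meant to compose like this — a hedged sketch assuming a configured synchronous `Client` with ambient credentials; the collection and document names are made up:

    from google.cloud import firestore

    client = firestore.Client()

    # Each factory call seeds a new pipeline with exactly one source stage:
    books = client.pipeline().collection("books")                 # Collection
    landmarks = client.pipeline().collection_group("landmarks")   # CollectionGroup
    everything = client.pipeline().database()                     # Database
    pair = client.pipeline().documents(
        client.document("books/b1"),
        client.document("books/b2"),
    )                                                             # Documents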
diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py
index 8e51fb1c1..3f796ddc5 100644
--- a/google/cloud/firestore_v1/pipeline_stages.py
+++ b/google/cloud/firestore_v1/pipeline_stages.py
@@ -53,6 +53,60 @@ def __repr__(self):
         return f"{self.__class__.__name__}({', '.join(items)})"
 
 
+class Collection(Stage):
+    """Specifies a collection as the initial data source."""
+
+    def __init__(self, path: str):
+        super().__init__()
+        if not path.startswith("/"):
+            path = f"/{path}"
+        self.path = path
+
+    def _pb_args(self):
+        return [Value(reference_value=self.path)]
+
+
+class CollectionGroup(Stage):
+    """Specifies a collection group as the initial data source."""
+
+    def __init__(self, collection_id: str):
+        super().__init__("collection_group")
+        self.collection_id = collection_id
+
+    def _pb_args(self):
+        return [Value(string_value=self.collection_id)]
+
+
+class Database(Stage):
+    """Specifies the default database as the initial data source."""
+
+    def __init__(self):
+        super().__init__()
+
+    def _pb_args(self):
+        return []
+
+
+class Documents(Stage):
+    """Specifies specific documents as the initial data source."""
+
+    def __init__(self, *paths: str):
+        super().__init__()
+        self.paths = paths
+
+    @staticmethod
+    def of(*documents: "DocumentReference") -> "Documents":
+        doc_paths = ["/" + doc.path for doc in documents]
+        return Documents(*doc_paths)
+
+    def _pb_args(self):
+        return [
+            Value(
+                list_value={"values": [Value(string_value=path) for path in self.paths]}
+            )
+        ]
+
+
 class GenericStage(Stage):
     """Represents a generic, named stage with parameters."""
 
diff --git a/noxfile.py b/noxfile.py
index 503b049ac..47ad67649 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -153,6 +153,7 @@ def mypy(session):
     session.run("mypy",
         "-p", "google.cloud.firestore_v1.pipeline_expressions",
         "-p", "google.cloud.firestore_v1.pipeline_stages",
+        "-p", "google.cloud.firestore_v1.pipeline_source",
         "-p", "google.cloud.firestore_v1.pipeline",
         "--no-incremental")
 
From 39f261a5c30d4ecca8f07d58d963e738e50bbb10 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 6 May 2025 14:59:23 -0700
Subject: [PATCH 104/131] removed collection.pipeline

---
 google/cloud/firestore_v1/base_collection.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py
index f0e3a7b6a..1ac1ba318 100644
--- a/google/cloud/firestore_v1/base_collection.py
+++ b/google/cloud/firestore_v1/base_collection.py
@@ -590,9 +590,6 @@ def find_nearest(
             distance_threshold=distance_threshold,
) - def pipeline(self): - return self._query().pipeline() - def _auto_id() -> str: """Generate a "random" automatically generated ID. From 39f261a5c30d4ecca8f07d58d963e738e50bbb10 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 6 May 2025 15:14:07 -0700 Subject: [PATCH 105/131] fixed docstrings --- google/cloud/firestore_v1/async_pipeline.py | 6 +++--- google/cloud/firestore_v1/pipeline.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 0e3453c94..8c1340379 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -37,14 +37,14 @@ class AsyncPipeline(_BasePipeline): >>> >>> async def run_pipeline(): ... client = AsyncClient(...) - ... pipeline = client.collection("books") - ... .pipeline() + ... pipeline = client.pipeline() + ... .collection("books") ... .where(Field.of("published").gt(1980)) ... .select("title", "author") ... async for result in pipeline.execute_async(): ... print(result) - Use `client.collection("...").pipeline()` to create instances of this class. + Use `client.pipeline()` to create instances of this class. """ def __init__(self, client: AsyncClient, *stages: stages.Stage): diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 67fa11fba..86cff0c10 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -34,14 +34,14 @@ class Pipeline(_BasePipeline): >>> >>> def run_pipeline(): ... client = Client(...) - ... pipeline = client.collection("books") - ... .pipeline() + ... pipeline = client.pipeline() + ... .collection("books") ... .where(Field.of("published").gt(1980)) ... .select("title", "author") ... for result in pipeline.execute(): ... print(result) - Use `client.collection("...").pipeline()` to create instances of this class. + Use `client.pipeline()` to create instances of this class. 
""" def __init__(self, client: Client, *stages: stages.Stage): From 48bf9f76bd5b470e89886f1047e94f5a5051b9d6 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 6 May 2025 16:45:24 -0700 Subject: [PATCH 106/131] added pipeline_result --- google/cloud/firestore_v1/async_pipeline.py | 17 ++- google/cloud/firestore_v1/base_pipeline.py | 25 ++-- google/cloud/firestore_v1/pipeline.py | 16 ++- google/cloud/firestore_v1/pipeline_result.py | 138 +++++++++++++++++++ google/cloud/firestore_v1/pipeline_source.py | 4 +- google/cloud/firestore_v1/pipeline_stages.py | 8 +- noxfile.py | 5 +- 7 files changed, 185 insertions(+), 28 deletions(-) create mode 100644 google/cloud/firestore_v1/pipeline_result.py diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 8c1340379..34cd3d4e2 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -17,8 +17,9 @@ from typing import AsyncIterable, TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest -from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse +from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.base_pipeline import _BasePipeline +from google.cloud.firestore_v1.pipeline_result import PipelineResult if TYPE_CHECKING: from google.cloud.firestore_v1.async_client import AsyncClient @@ -57,7 +58,7 @@ def __init__(self, client: AsyncClient, *stages: stages.Stage): """ super().__init__(client, *stages) - async def execute(self) -> AsyncIterable["DocumentSnapshot"]: + async def execute(self) -> AsyncIterable[PipelineResult]: database_name = ( f"projects/{self._client.project}/databases/{self._client._database}" ) @@ -69,5 +70,13 @@ async def execute(self) -> AsyncIterable["DocumentSnapshot"]: async for response in await self._client._firestore_api.execute_pipeline( request ): - for snapshot in self._parse_response(response, self._client): - yield snapshot + for doc in response.results: + doc_ref = AsyncDocumentReference(doc.name, client=self._client) if doc.name else None + yield PipelineResult( + self._client, + doc.fields, + doc_ref, + response._pb.execution_time, + doc.create_time, + doc.update_tiem, + ) \ No newline at end of file diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 0793e58f9..bf4044a54 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -13,13 +13,19 @@ # limitations under the License. from __future__ import annotations +from typing import TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages -from google.cloud.firestore_v1.base_client import BaseClient +from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.types.pipeline import ( StructuredPipeline as StructuredPipeline_pb, ) +from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1 import _helpers, document +if TYPE_CHECKING: + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.async_client import AsyncClient + class _BasePipeline: """ @@ -29,7 +35,7 @@ class _BasePipeline: Use `client.collection.("...").pipeline()` to create pipeline instances. 
""" - def __init__(self, client: BaseClient, *stages: stages.Stage): + def __init__(self, client: Client | AsyncClient, *stages: stages.Stage): """ Initializes a new pipeline with the given stages. @@ -60,17 +66,4 @@ def _append(self, new_stage): """ Create a new Pipeline object with a new stage appended """ - return self.__class__(self._client, *self.stages, new_stage) - - @staticmethod - def _parse_response(response_pb, client): - for doc in response_pb.results: - data = _helpers.decode_dict(doc.fields, client) - yield document.DocumentSnapshot( - None, - data, - exists=True, - read_time=response_pb._pb.execution_time, - create_time=doc.create_time, - update_time=doc.update_time, - ) \ No newline at end of file + return self.__class__(self._client, *self.stages, new_stage) \ No newline at end of file diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 86cff0c10..68771496e 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -17,8 +17,9 @@ from typing import AsyncIterable, Iterable, TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest -from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse +from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.base_pipeline import _BasePipeline +from google.cloud.firestore_v1.pipeline_result import PipelineResult if TYPE_CHECKING: from google.cloud.firestore_v1.client import Client @@ -54,7 +55,7 @@ def __init__(self, client: Client, *stages: stages.Stage): """ super().__init__(client, *stages) - def execute(self) -> Iterable["DocumentSnapshot"]: + def execute(self) -> Iterable[PipelineResult]: database_name = ( f"projects/{self._client.project}/databases/{self._client._database}" ) @@ -63,4 +64,13 @@ def execute(self) -> Iterable["DocumentSnapshot"]: structured_pipeline=self._to_pb(), ) for response in self._client._firestore_api.execute_pipeline(request): - yield from self._parse_response(response, self) + for doc in response.results: + doc_ref = DocumentReference(doc.name, client=self._client) if doc.name else None + yield PipelineResult( + self._client, + doc.fields, + doc_ref, + response._pb.execution_time, + doc.create_time, + doc.update_tiem, + ) \ No newline at end of file diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py new file mode 100644 index 000000000..56a7c2ea4 --- /dev/null +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -0,0 +1,138 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py
index 86cff0c10..68771496e 100644
--- a/google/cloud/firestore_v1/pipeline.py
+++ b/google/cloud/firestore_v1/pipeline.py
@@ -17,8 +17,9 @@
 from typing import AsyncIterable, Iterable, TYPE_CHECKING
 from google.cloud.firestore_v1 import pipeline_stages as stages
 from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest
-from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse
+from google.cloud.firestore_v1.document import DocumentReference
 from google.cloud.firestore_v1.base_pipeline import _BasePipeline
+from google.cloud.firestore_v1.pipeline_result import PipelineResult
 
 if TYPE_CHECKING:
     from google.cloud.firestore_v1.client import Client
@@ -54,7 +55,7 @@ def __init__(self, client: Client, *stages: stages.Stage):
         """
         super().__init__(client, *stages)
 
-    def execute(self) -> Iterable["DocumentSnapshot"]:
+    def execute(self) -> Iterable[PipelineResult]:
         database_name = (
             f"projects/{self._client.project}/databases/{self._client._database}"
         )
@@ -63,4 +64,13 @@ def execute(self) -> Iterable["DocumentSnapshot"]:
             structured_pipeline=self._to_pb(),
         )
         for response in self._client._firestore_api.execute_pipeline(request):
-            yield from self._parse_response(response, self)
+            for doc in response.results:
+                doc_ref = DocumentReference(doc.name, client=self._client) if doc.name else None
+                yield PipelineResult(
+                    self._client,
+                    doc.fields,
+                    doc_ref,
+                    response._pb.execution_time,
+                    doc.create_time,
+                    doc.update_time,
+                )
\ No newline at end of file
diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py
new file mode 100644
index 000000000..56a7c2ea4
--- /dev/null
+++ b/google/cloud/firestore_v1/pipeline_result.py
@@ -0,0 +1,138 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+from typing import Any, TYPE_CHECKING
+from google.cloud.firestore_v1._helpers import decode_dict
+from google.cloud.firestore_v1._helpers import decode_value
+from google.cloud.firestore_v1.field_path import get_nested_value
+from google.cloud.firestore_v1.field_path import FieldPath
+
+if TYPE_CHECKING:
+    from google.cloud.firestore_v1.base_client import BaseClient
+    from google.cloud.firestore_v1.base_document import BaseDocumentReference
+    from google.protobuf.timestamp_pb2 import Timestamp
+    from google.cloud.firestore_v1.types.document import Value as ValueProto
+
+
+class PipelineResult:
+    """
+    Contains data read from a Firestore Pipeline. The data can be extracted with
+    the `data()` or `get()` methods.
+
+    If the PipelineResult represents a non-document result, `ref` may be `None`.
+    """
+
+    def __init__(
+        self,
+        client: BaseClient,
+        fields_pb: dict[str, ValueProto],
+        ref: BaseDocumentReference | None = None,
+        execution_time: Timestamp | None = None,
+        create_time: Timestamp | None = None,
+        update_time: Timestamp | None = None,
+    ):
+        """
+        PipelineResult should be returned from `pipeline.execute()`, not constructed manually.
+
+        Args:
+            client: The Firestore client instance.
+            fields_pb: A map of field names to their protobuf Value representations.
+            ref: The DocumentReference or AsyncDocumentReference if this result corresponds to a document.
+            execution_time: The time at which the pipeline execution producing this result occurred.
+            create_time: The creation time of the document, if applicable.
+            update_time: The last update time of the document, if applicable.
+        """
+        self._client = client
+        self._fields_pb = fields_pb
+        self._ref = ref
+        self._execution_time = execution_time
+        self._create_time = create_time
+        self._update_time = update_time
+
+    @property
+    def ref(self) -> BaseDocumentReference | None:
+        """
+        The `BaseDocumentReference` if this result represents a document, else `None`.
+        """
+        return self._ref
+
+    @property
+    def id(self) -> str | None:
+        """The ID of the document if this result represents a document, else `None`."""
+        return self._ref.id if self._ref else None
+
+    @property
+    def create_time(self) -> Timestamp | None:
+        """The creation time of the document. `None` if not applicable (e.g., not a document result or document doesn't exist)."""
+        return self._create_time
+
+    @property
+    def update_time(self) -> Timestamp | None:
+        """The last update time of the document. `None` if not applicable."""
+        return self._update_time
+
+    @property
+    def execution_time(self) -> Timestamp:
+        """
+        The time at which the pipeline producing this result was executed.
+
+        Raises:
+            ValueError: if not set
+        """
+        if self._execution_time is None:
+            raise ValueError("'execution_time' is expected to exist, but it is None.")
+        return self._execution_time
+
+    def __eq__(self, other: object) -> bool:
+        """
+        Compares this `PipelineResult` to another object for equality.
+
+        Two `PipelineResult` instances are considered equal if their document
+        references (if any) are equal and their underlying field data
+        (protobuf representation) is identical.
+        """
+        if not isinstance(other, PipelineResult):
+            return NotImplemented
+        return (self._ref == other._ref) and (self._fields_pb == other._fields_pb)
+
+    def data(self) -> Any:
+        """
+        Retrieves all fields in the result.
+
+        If a converter was provided to this `PipelineResult`, the result of the
+        converter's `from_firestore` method is returned.
+ + Returns: + The data, either as a custom object (if a converter is used) or a dictionary. + Returns `None` if the document doesn't exist. + """ + if self._fields_pb is None: + return None + + return decode_dict(self._fields_pb, self._client) + + def get(self, field_path: str | FieldPath) -> Any: + """ + Retrieves the field specified by `field_path`. + + Args: + field_path: The field path (e.g. 'foo' or 'foo.bar') to a specific field. + + Returns: + The data at the specified field location, decoded to Python types. + """ + str_path = field_path if isinstance(field_path, str) else field_path.to_api_repr() + value = get_nested_value(str_path, self._fields_pb) + return decode_value(value, self._client) diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index f1ede9302..b1a49c0f6 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient - from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.base_document import BaseDocumentReference @@ -73,7 +73,7 @@ def database(self) -> PipelineType: """ return self.client._pipeline_cls(self.client, stages.Database()) - def documents(self, *docs: "DocumentReference") -> PipelineType: + def documents(self, *docs: "BaseDocumentReference") -> PipelineType: """ Creates a new Pipeline that operates on a specific set of Firestore documents. diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 3f796ddc5..242671756 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import Optional +from typing import Optional, TYPE_CHECKING from abc import ABC from abc import abstractmethod @@ -23,6 +23,10 @@ Expr, ) +if TYPE_CHECKING: + from google.cloud.firestore_v1.base_document import BaseDocumentReference + + class Stage(ABC): """Base class for all pipeline stages. 
@@ -95,7 +99,7 @@ def __init__(self, *paths: str): self.paths = paths @staticmethod - def of(*documents: "DocumentReference") -> "Documents": + def of(*documents: "BaseDocumentReference") -> "Documents": doc_paths = ["/" + doc.path for doc in documents] return Documents(*doc_paths) diff --git a/noxfile.py b/noxfile.py index 47ad67649..de5dd7c22 100644 --- a/noxfile.py +++ b/noxfile.py @@ -148,12 +148,15 @@ def pytype(session): def mypy(session): """Verify type hints are mypy compatible.""" session.install("-e", ".") - session.install("mypy", "types-setuptools") + session.install("mypy", "types-setuptools", "types-protobuf") # TODO: also verify types on tests, all of google package session.run("mypy", "-p", "google.cloud.firestore_v1.pipeline_expressions", "-p", "google.cloud.firestore_v1.pipeline_stages", "-p", "google.cloud.firestore_v1.pipeline_source", + "-p", "google.cloud.firestore_v1.pipeline_result", + "-p", "google.cloud.firestore_v1.base_pipeline", + "-p", "google.cloud.firestore_v1.async_pipeline", "-p", "google.cloud.firestore_v1.pipeline", "--no-incremental") From b57424d9e786250669a6105a114c176b40e9f9c4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 8 May 2025 14:07:16 -0700 Subject: [PATCH 107/131] chore: updated gapic layer for execute_query --- .../services/firestore/async_client.py | 107 +++++++ .../firestore_v1/services/firestore/client.py | 105 +++++++ .../services/firestore/transports/base.py | 17 ++ .../services/firestore/transports/grpc.py | 28 ++ .../firestore/transports/grpc_asyncio.py | 33 +++ .../services/firestore/transports/rest.py | 266 ++++++++++++++++-- .../firestore/transports/rest_base.py | 120 +++++--- google/cloud/firestore_v1/types/__init__.py | 16 ++ google/cloud/firestore_v1/types/document.py | 165 +++++++++++ .../cloud/firestore_v1/types/explain_stats.py | 53 ++++ google/cloud/firestore_v1/types/firestore.py | 145 ++++++++++ google/cloud/firestore_v1/types/pipeline.py | 61 ++++ 12 files changed, 1048 insertions(+), 68 deletions(-) create mode 100644 google/cloud/firestore_v1/types/explain_stats.py create mode 100644 google/cloud/firestore_v1/types/pipeline.py diff --git a/google/cloud/firestore_v1/services/firestore/async_client.py b/google/cloud/firestore_v1/services/firestore/async_client.py index 56cf7d3af..916914969 100644 --- a/google/cloud/firestore_v1/services/firestore/async_client.py +++ b/google/cloud/firestore_v1/services/firestore/async_client.py @@ -52,6 +52,7 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile @@ -236,6 +237,9 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. + NOTE: "rest" transport functionality is currently in a + beta state (preview). We welcome your feedback via an + issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. @@ -1247,6 +1251,109 @@ async def sample_run_query(): # Done; return the response. 
return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Awaitable[AsyncIterable[firestore.ExecutePipelineResponse]]: + r"""Executes a pipeline query. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_v1 + + async def sample_execute_pipeline(): + # Create a client + client = firestore_v1.FirestoreAsyncClient() + + # Initialize request argument(s) + structured_pipeline = firestore_v1.StructuredPipeline() + structured_pipeline.pipeline.stages.name = "name_value" + + request = firestore_v1.ExecutePipelineRequest( + structured_pipeline=structured_pipeline, + transaction=b'transaction_blob', + database="database_value", + ) + + # Make the request + stream = await client.execute_pipeline(request=request) + + # Handle the response + async for response in stream: + print(response) + + Args: + request (Optional[Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]]): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + AsyncIterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]: + The response for [Firestore.Execute][]. + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, firestore.ExecutePipelineRequest): + request = firestore.ExecutePipelineRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.execute_pipeline + ] + + header_params = {} + + routing_param_regex = re.compile("^projects/(?P[^/]+)(?:/.*)?$") + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("project_id"): + header_params["project_id"] = regex_match.group("project_id") + + routing_param_regex = re.compile( + "^projects/[^/]+/databases/(?P[^/]+)(?:/.*)?$" + ) + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("database_id"): + header_params["database_id"] = regex_match.group("database_id") + + if header_params: + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(header_params), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. 
+ response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + def run_aggregation_query( self, request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None, diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py index 1fb800e61..340cd5ef2 100644 --- a/google/cloud/firestore_v1/services/firestore/client.py +++ b/google/cloud/firestore_v1/services/firestore/client.py @@ -67,6 +67,7 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile @@ -551,6 +552,9 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. + NOTE: "rest" transport functionality is currently in a + beta state (preview). We welcome your feedback via an + issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. @@ -1630,6 +1634,107 @@ def sample_run_query(): # Done; return the response. return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Iterable[firestore.ExecutePipelineResponse]: + r"""Executes a pipeline query. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_v1 + + def sample_execute_pipeline(): + # Create a client + client = firestore_v1.FirestoreClient() + + # Initialize request argument(s) + structured_pipeline = firestore_v1.StructuredPipeline() + structured_pipeline.pipeline.stages.name = "name_value" + + request = firestore_v1.ExecutePipelineRequest( + structured_pipeline=structured_pipeline, + transaction=b'transaction_blob', + database="database_value", + ) + + # Make the request + stream = client.execute_pipeline(request=request) + + # Handle the response + for response in stream: + print(response) + + Args: + request (Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
+ + Returns: + Iterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]: + The response for [Firestore.Execute][]. + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, firestore.ExecutePipelineRequest): + request = firestore.ExecutePipelineRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.execute_pipeline] + + header_params = {} + + routing_param_regex = re.compile("^projects/(?P[^/]+)(?:/.*)?$") + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("project_id"): + header_params["project_id"] = regex_match.group("project_id") + + routing_param_regex = re.compile( + "^projects/[^/]+/databases/(?P[^/]+)(?:/.*)?$" + ) + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("database_id"): + header_params["database_id"] = regex_match.group("database_id") + + if header_params: + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(header_params), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + def run_aggregation_query( self, request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None, diff --git a/google/cloud/firestore_v1/services/firestore/transports/base.py b/google/cloud/firestore_v1/services/firestore/transports/base.py index 862b098d1..50e0b6dd3 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/base.py @@ -286,6 +286,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), + self.execute_pipeline: gapic_v1.method.wrap_method( + self.execute_pipeline, + default_timeout=None, + client_info=client_info, + ), self.run_aggregation_query: gapic_v1.method.wrap_method( self.run_aggregation_query, default_retry=retries.Retry( @@ -509,6 +514,18 @@ def run_query( ]: raise NotImplementedError() + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], + Union[ + firestore.ExecutePipelineResponse, + Awaitable[firestore.ExecutePipelineResponse], + ], + ]: + raise NotImplementedError() + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc.py b/google/cloud/firestore_v1/services/firestore/transports/grpc.py index c302a73c2..2a8f4caf9 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc.py @@ -571,6 +571,34 @@ def run_query( ) return self._stubs["run_query"] + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse + ]: + r"""Return a callable for the execute pipeline method over gRPC. + + Executes a pipeline query. + + Returns: + Callable[[~.ExecutePipelineRequest], + ~.ExecutePipelineResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py index f46162296..8801dc45a 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py @@ -587,6 +587,34 @@ def run_query( ) return self._stubs["run_query"] + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], Awaitable[firestore.ExecutePipelineResponse] + ]: + r"""Return a callable for the execute pipeline method over gRPC. + + Executes a pipeline query. + + Returns: + Callable[[~.ExecutePipelineRequest], + Awaitable[~.ExecutePipelineResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, @@ -962,6 +990,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), + self.execute_pipeline: self._wrap_method( + self.execute_pipeline, + default_timeout=None, + client_info=client_info, + ), self.run_aggregation_query: self._wrap_method( self.run_aggregation_query, default_retry=retries.AsyncRetry( diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest.py b/google/cloud/firestore_v1/services/firestore/transports/rest.py index 3794ecea3..4bd282fe6 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest.py @@ -123,6 +123,14 @@ def pre_delete_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata + def pre_execute_pipeline(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_execute_pipeline(self, response): + logging.log(f"Received response: {response}") + return response + def pre_get_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -441,6 +449,56 @@ def pre_delete_document( """ return request, metadata + def pre_execute_pipeline( + self, + request: firestore.ExecutePipelineRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.ExecutePipelineRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for execute_pipeline + + Override in a subclass to manipulate the request or metadata + before they are sent to the Firestore 
server. + """ + return request, metadata + + def post_execute_pipeline( + self, response: rest_streaming.ResponseIterator + ) -> rest_streaming.ResponseIterator: + """Post-rpc interceptor for execute_pipeline + + DEPRECATED. Please use the `post_execute_pipeline_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response + after it is returned by the Firestore server but before + it is returned to user code. This `post_execute_pipeline` interceptor runs + before the `post_execute_pipeline_with_metadata` interceptor. + """ + return response + + def post_execute_pipeline_with_metadata( + self, + response: rest_streaming.ResponseIterator, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + rest_streaming.ResponseIterator, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for execute_pipeline + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_execute_pipeline_with_metadata` + interceptor in new development instead of the `post_execute_pipeline` interceptor. + When both interceptors are used, this `post_execute_pipeline_with_metadata` interceptor runs after the + `post_execute_pipeline` interceptor. The (possibly modified) response returned by + `post_execute_pipeline` will be passed to + `post_execute_pipeline_with_metadata`. + """ + return response, metadata + def pre_get_document( self, request: firestore.GetDocumentRequest, @@ -932,35 +990,39 @@ def __init__( ) -> None: """Instantiate the transport. - Args: - host (Optional[str]): - The hostname to connect to (default: 'firestore.googleapis.com'). - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client - certificate to configure mutual TLS HTTP channel. It is ignored - if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you are developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - url_scheme: the protocol scheme for the API endpoint. Normally - "https", but for testing or local servers, - "http" can be specified. + NOTE: This REST transport functionality is currently in a beta + state (preview). We welcome your feedback via a GitHub issue in + this library's repository. Thank you! + + Args: + host (Optional[str]): + The hostname to connect to (default: 'firestore.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. 
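A hedged sketch of the interceptor hooks documented above, wired into a subclass; the class name is hypothetical and both hooks simply pass data through:

.. code-block:: python

    class MyFirestoreInterceptor(FirestoreRestInterceptor):
        def pre_execute_pipeline(self, request, metadata):
            # Runs before the RPC is sent; may rewrite the request or metadata.
            return request, metadata

        def post_execute_pipeline_with_metadata(self, response, metadata):
            # Preferred over post_execute_pipeline in new code; runs after it.
            return response, metadata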
These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. """ # Run the base constructor # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. @@ -1852,6 +1914,142 @@ def __call__( if response.status_code >= 400: raise core_exceptions.from_http_response(response) + class _ExecutePipeline( + _BaseFirestoreRestTransport._BaseExecutePipeline, FirestoreRestStub + ): + def __hash__(self): + return hash("FirestoreRestTransport.ExecutePipeline") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + stream=True, + ) + return response + + def __call__( + self, + request: firestore.ExecutePipelineRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> rest_streaming.ResponseIterator: + r"""Call the execute pipeline method over HTTP. + + Args: + request (~.firestore.ExecutePipelineRequest): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.firestore.ExecutePipelineResponse: + The response for [Firestore.Execute][]. 
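As a hedged usage sketch, the returned iterator is consumed like any other server-streaming result (the request is assumed to be built as in the samples above):

.. code-block:: python

    for response in client.execute_pipeline(request=request):
        # Each response carries a (possibly empty) batch of result documents.
        for doc in response.results:
            print(doc.name, doc.fields)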
+ """ + + http_options = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_http_options() + ) + + request, metadata = self._interceptor.pre_execute_pipeline( + request, metadata + ) + transcoded_request = _BaseFirestoreRestTransport._BaseExecutePipeline._get_transcoded_request( + http_options, request + ) + + body = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_request_body_json( + transcoded_request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.ExecutePipeline", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ExecutePipeline", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FirestoreRestTransport._ExecutePipeline._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = rest_streaming.ResponseIterator( + response, firestore.ExecutePipelineResponse + ) + + resp = self._interceptor.post_execute_pipeline(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_execute_pipeline_with_metadata( + resp, response_metadata + ) + return resp + class _GetDocument(_BaseFirestoreRestTransport._BaseGetDocument, FirestoreRestStub): def __hash__(self): return hash("FirestoreRestTransport.GetDocument") @@ -3090,6 +3288,16 @@ def delete_document( # In C++ this would require a dynamic_cast return self._DeleteDocument(self._session, self._host, self._interceptor) # type: ignore + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._ExecutePipeline(self._session, self._host, self._interceptor) # type: ignore + @property def get_document( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py index 1d95cd16e..721f0792f 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py @@ -130,7 +130,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -139,7 +139,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -148,7 +148,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseBatchWrite: @@ -187,7 +186,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -196,7 +195,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -205,7 +204,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseBeginTransaction: @@ -244,7 +242,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -253,7 +251,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -262,7 +260,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseCommit: @@ -301,7 +298,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -310,7 +307,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -319,7 +316,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseCreateDocument: @@ -358,7 +354,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -367,7 +363,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( 
json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -376,7 +372,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseDeleteDocument: @@ -414,7 +409,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -423,7 +418,62 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseExecutePipeline: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.ExecutePipelineRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_unset_required_fields( + query_params + ) + ) + return query_params class _BaseGetDocument: @@ -461,7 +511,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -470,7 +520,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListCollectionIds: @@ -514,7 +563,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -523,7 +572,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -532,7 +581,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListDocuments: @@ -574,7 +622,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -583,7 +631,6 @@ def _get_query_params_json(transcoded_request): ) ) - 
query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListen: @@ -631,7 +678,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -640,7 +687,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -649,7 +696,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRollback: @@ -688,7 +734,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -697,7 +743,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -706,7 +752,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRunAggregationQuery: @@ -750,7 +795,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -759,7 +804,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -768,7 +813,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRunQuery: @@ -812,7 +856,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -821,7 +865,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -830,7 +874,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseUpdateDocument: @@ -869,7 +912,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -878,7 +921,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -887,7 +930,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseWrite: diff --git a/google/cloud/firestore_v1/types/__init__.py b/google/cloud/firestore_v1/types/__init__.py index 
ae1004e13..ed1965d7f 100644 --- a/google/cloud/firestore_v1/types/__init__.py +++ b/google/cloud/firestore_v1/types/__init__.py @@ -28,9 +28,14 @@ from .document import ( ArrayValue, Document, + Function, MapValue, + Pipeline, Value, ) +from .explain_stats import ( + ExplainStats, +) from .firestore import ( BatchGetDocumentsRequest, BatchGetDocumentsResponse, @@ -42,6 +47,8 @@ CommitResponse, CreateDocumentRequest, DeleteDocumentRequest, + ExecutePipelineRequest, + ExecutePipelineResponse, GetDocumentRequest, ListCollectionIdsRequest, ListCollectionIdsResponse, @@ -62,6 +69,9 @@ WriteRequest, WriteResponse, ) +from .pipeline import ( + StructuredPipeline, +) from .query import ( Cursor, StructuredAggregationQuery, @@ -92,8 +102,11 @@ "TransactionOptions", "ArrayValue", "Document", + "Function", "MapValue", + "Pipeline", "Value", + "ExplainStats", "BatchGetDocumentsRequest", "BatchGetDocumentsResponse", "BatchWriteRequest", @@ -104,6 +117,8 @@ "CommitResponse", "CreateDocumentRequest", "DeleteDocumentRequest", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "GetDocumentRequest", "ListCollectionIdsRequest", "ListCollectionIdsResponse", @@ -123,6 +138,7 @@ "UpdateDocumentRequest", "WriteRequest", "WriteResponse", + "StructuredPipeline", "Cursor", "StructuredAggregationQuery", "StructuredQuery", diff --git a/google/cloud/firestore_v1/types/document.py b/google/cloud/firestore_v1/types/document.py index 0942354f5..1757571b1 100644 --- a/google/cloud/firestore_v1/types/document.py +++ b/google/cloud/firestore_v1/types/document.py @@ -31,6 +31,8 @@ "Value", "ArrayValue", "MapValue", + "Function", + "Pipeline", }, ) @@ -183,6 +185,37 @@ class Value(proto.Message): map_value (google.cloud.firestore_v1.types.MapValue): A map value. + This field is a member of `oneof`_ ``value_type``. + field_reference_value (str): + Value which references a field. + + This is considered relative (vs absolute) since it only + refers to a field and not a field within a particular + document. + + **Requires:** + + - Must follow [field reference][FieldReference.field_path] + limitations. + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + function_value (google.cloud.firestore_v1.types.Function): + A value that represents an unevaluated expression. + + **Requires:** + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + pipeline_value (google.cloud.firestore_v1.types.Pipeline): + A value that represents an unevaluated pipeline. + + **Requires:** + + - Not allowed to be used when writing documents. + This field is a member of `oneof`_ ``value_type``. """ @@ -246,6 +279,23 @@ class Value(proto.Message): oneof="value_type", message="MapValue", ) + field_reference_value: str = proto.Field( + proto.STRING, + number=19, + oneof="value_type", + ) + function_value: "Function" = proto.Field( + proto.MESSAGE, + number=20, + oneof="value_type", + message="Function", + ) + pipeline_value: "Pipeline" = proto.Field( + proto.MESSAGE, + number=21, + oneof="value_type", + message="Pipeline", + ) class ArrayValue(proto.Message): @@ -285,4 +335,119 @@ class MapValue(proto.Message): ) +class Function(proto.Message): + r"""Represents an unevaluated scalar expression. + + For example, the expression ``like(user_name, "%alice%")`` is + represented as: + + :: + + name: "like" + args { field_reference: "user_name" } + args { string_value: "%alice%" } + + Attributes: + name (str): + Required. 
The name of the function to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + function expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + +class Pipeline(proto.Message): + r"""A Firestore query represented as an ordered list of + operations / stages. + + Attributes: + stages (MutableSequence[google.cloud.firestore_v1.types.Pipeline.Stage]): + Required. Ordered list of stages to evaluate. + """ + + class Stage(proto.Message): + r"""A single operation within a pipeline. + + A stage is made up of a unique name, and a list of arguments. The + exact number of arguments & types is dependent on the stage type. + + To give an example, the stage ``filter(state = "MD")`` would be + encoded as: + + :: + + name: "filter" + args { + function_value { + name: "eq" + args { field_reference_value: "state" } + args { string_value: "MD" } + } + } + + See public documentation for the full list. + + Attributes: + name (str): + Required. The name of the stage to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + stage expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + stages: MutableSequence[Stage] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=Stage, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/explain_stats.py b/google/cloud/firestore_v1/types/explain_stats.py new file mode 100644 index 000000000..1fda228b6 --- /dev/null +++ b/google/cloud/firestore_v1/types/explain_stats.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
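A minimal construction sketch, assuming the proto-plus ``Function`` and ``Pipeline.Stage`` messages defined above, for the documented ``filter(state = "MD")`` stage:

.. code-block:: python

    from google.cloud.firestore_v1.types import document

    stage = document.Pipeline.Stage(
        name="filter",
        args=[
            document.Value(
                function_value=document.Function(
                    name="eq",
                    args=[
                        document.Value(field_reference_value="state"),
                        document.Value(string_value="MD"),
                    ],
                )
            )
        ],
    )
    query = document.Pipeline(stages=[stage])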
+# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.protobuf import any_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "ExplainStats", + }, +) + + +class ExplainStats(proto.Message): + r"""Explain stats for an RPC request, includes both the optimized + plan and execution stats. + + Attributes: + data (google.protobuf.any_pb2.Any): + The format depends on the ``output_format`` options in the + request. + + The only option today is ``TEXT``, which is a + ``google.protobuf.StringValue``. + """ + + data: any_pb2.Any = proto.Field( + proto.MESSAGE, + number=1, + message=any_pb2.Any, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/firestore.py b/google/cloud/firestore_v1/types/firestore.py index 53a6c6e7a..f1753c92f 100644 --- a/google/cloud/firestore_v1/types/firestore.py +++ b/google/cloud/firestore_v1/types/firestore.py @@ -22,6 +22,8 @@ from google.cloud.firestore_v1.types import aggregation_result from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats as gf_explain_stats +from google.cloud.firestore_v1.types import pipeline from google.cloud.firestore_v1.types import query as gf_query from google.cloud.firestore_v1.types import query_profile from google.cloud.firestore_v1.types import write @@ -48,6 +50,8 @@ "RollbackRequest", "RunQueryRequest", "RunQueryResponse", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "RunAggregationQueryRequest", "RunAggregationQueryResponse", "PartitionQueryRequest", @@ -835,6 +839,147 @@ class RunQueryResponse(proto.Message): ) +class ExecutePipelineRequest(proto.Message): + r"""The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + database (str): + Required. Database identifier, in the form + ``projects/{project}/databases/{database}``. + structured_pipeline (google.cloud.firestore_v1.types.StructuredPipeline): + A pipelined operation. + + This field is a member of `oneof`_ ``pipeline_type``. + transaction (bytes): + Run the query within an already active + transaction. + The value here is the opaque transaction ID to + execute the query in. + + This field is a member of `oneof`_ ``consistency_selector``. + new_transaction (google.cloud.firestore_v1.types.TransactionOptions): + Execute the pipeline in a new transaction. + + The identifier of the newly created transaction + will be returned in the first response on the + stream. This defaults to a read-only + transaction. + + This field is a member of `oneof`_ ``consistency_selector``. + read_time (google.protobuf.timestamp_pb2.Timestamp): + Execute the pipeline in a snapshot + transaction at the given time. + This must be a microsecond precision timestamp + within the past one hour, or if Point-in-Time + Recovery is enabled, can additionally be a whole + minute timestamp within the past 7 days. + + This field is a member of `oneof`_ ``consistency_selector``. 
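A hedged assembly sketch for this request; the database name is a hypothetical example and the pipeline stages are elided:

.. code-block:: python

    from google.cloud.firestore_v1.types import document, firestore, pipeline

    request = firestore.ExecutePipelineRequest(
        database="projects/my-project/databases/(default)",
        structured_pipeline=pipeline.StructuredPipeline(
            pipeline=document.Pipeline(stages=[]),  # stages elided in this sketch
        ),
    )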
+ """ + + database: str = proto.Field( + proto.STRING, + number=1, + ) + structured_pipeline: pipeline.StructuredPipeline = proto.Field( + proto.MESSAGE, + number=2, + oneof="pipeline_type", + message=pipeline.StructuredPipeline, + ) + transaction: bytes = proto.Field( + proto.BYTES, + number=5, + oneof="consistency_selector", + ) + new_transaction: common.TransactionOptions = proto.Field( + proto.MESSAGE, + number=6, + oneof="consistency_selector", + message=common.TransactionOptions, + ) + read_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=7, + oneof="consistency_selector", + message=timestamp_pb2.Timestamp, + ) + + +class ExecutePipelineResponse(proto.Message): + r"""The response for [Firestore.Execute][]. + + Attributes: + transaction (bytes): + Newly created transaction identifier. + + This field is only specified as part of the first response + from the server, alongside the ``results`` field when the + original request specified + [ExecuteRequest.new_transaction][]. + results (MutableSequence[google.cloud.firestore_v1.types.Document]): + An ordered batch of results returned executing a pipeline. + + The batch size is variable, and can even be zero for when + only a partial progress message is returned. + + The fields present in the returned documents are only those + that were explicitly requested in the pipeline, this include + those like [``__name__``][google.firestore.v1.Document.name] + & + [``__update_time__``][google.firestore.v1.Document.update_time]. + This is explicitly a divergence from ``Firestore.RunQuery`` + / ``Firestore.GetDocument`` RPCs which always return such + fields even when they are not specified in the + [``mask``][google.firestore.v1.DocumentMask]. + execution_time (google.protobuf.timestamp_pb2.Timestamp): + The time at which the document(s) were read. + + This may be monotonically increasing; in this case, the + previous documents in the result stream are guaranteed not + to have changed between their ``execution_time`` and this + one. + + If the query returns no results, a response with + ``execution_time`` and no ``results`` will be sent, and this + represents the time at which the operation was run. + explain_stats (google.cloud.firestore_v1.types.ExplainStats): + Query explain stats. + + Contains all metadata related to pipeline + planning and execution, specific contents depend + on the supplied pipeline options. + """ + + transaction: bytes = proto.Field( + proto.BYTES, + number=1, + ) + results: MutableSequence[gf_document.Document] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=gf_document.Document, + ) + execution_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=3, + message=timestamp_pb2.Timestamp, + ) + explain_stats: gf_explain_stats.ExplainStats = proto.Field( + proto.MESSAGE, + number=4, + message=gf_explain_stats.ExplainStats, + ) + + class RunAggregationQueryRequest(proto.Message): r"""The request for [Firestore.RunAggregationQuery][google.firestore.v1.Firestore.RunAggregationQuery]. diff --git a/google/cloud/firestore_v1/types/pipeline.py b/google/cloud/firestore_v1/types/pipeline.py new file mode 100644 index 000000000..29fbe884b --- /dev/null +++ b/google/cloud/firestore_v1/types/pipeline.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.cloud.firestore_v1.types import document + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "StructuredPipeline", + }, +) + + +class StructuredPipeline(proto.Message): + r"""A Firestore query represented as an ordered list of operations / + stages. + + This is considered the top-level function which plans & executes a + query. It is logically equivalent to ``query(stages, options)``, but + prevents the client from having to build a function wrapper. + + Attributes: + pipeline (google.cloud.firestore_v1.types.Pipeline): + Required. The pipeline query to execute. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional query-level arguments. + """ + + pipeline: document.Pipeline = proto.Field( + proto.MESSAGE, + number=1, + message=document.Pipeline, + ) + options: MutableMapping[str, document.Value] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=2, + message=document.Value, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) From 8fde414d2862cff70fa102a4ce9c6b2e8d50ee8f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 8 May 2025 14:09:56 -0700 Subject: [PATCH 108/131] updated gapics --- google/cloud/firestore_v1/async_client.py | 2 +- google/cloud/firestore_v1/async_pipeline.py | 8 +- google/cloud/firestore_v1/base_client.py | 3 +- google/cloud/firestore_v1/base_pipeline.py | 2 +- google/cloud/firestore_v1/base_query.py | 1 - google/cloud/firestore_v1/client.py | 2 +- google/cloud/firestore_v1/pipeline.py | 8 +- .../firestore_v1/pipeline_expressions.py | 2 +- google/cloud/firestore_v1/pipeline_result.py | 4 +- google/cloud/firestore_v1/pipeline_source.py | 8 +- google/cloud/firestore_v1/pipeline_stages.py | 2 +- .../cloud/firestore_v1/services/__init__.py | 2 +- .../services/firestore/__init__.py | 2 +- .../services/firestore/async_client.py | 330 +- .../firestore_v1/services/firestore/client.py | 416 ++- .../firestore_v1/services/firestore/pagers.py | 50 +- .../services/firestore/transports/README.rst | 9 + .../services/firestore/transports/__init__.py | 2 +- .../services/firestore/transports/base.py | 34 +- .../services/firestore/transports/grpc.py | 133 +- .../firestore/transports/grpc_asyncio.py | 203 +- .../services/firestore/transports/rest.py | 3181 ++++++++++++----- .../firestore/transports/rest_base.py | 1046 ++++++ google/cloud/firestore_v1/types/__init__.py | 2 +- .../firestore_v1/types/aggregation_result.py | 2 +- .../cloud/firestore_v1/types/bloom_filter.py | 2 +- google/cloud/firestore_v1/types/common.py | 2 +- google/cloud/firestore_v1/types/document.py | 33 +- .../cloud/firestore_v1/types/explain_stats.py | 2 +- google/cloud/firestore_v1/types/firestore.py | 4 +- google/cloud/firestore_v1/types/pipeline.py | 2 +- google/cloud/firestore_v1/types/query.py | 7 +- .../cloud/firestore_v1/types/query_profile.py | 2 +- google/cloud/firestore_v1/types/write.py | 2 +- 34 files changed, 4251 insertions(+), 1259 deletions(-) create mode 100644 
google/cloud/firestore_v1/services/firestore/transports/README.rst create mode 100644 google/cloud/firestore_v1/services/firestore/transports/rest_base.py diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 13a00398e..738121a8e 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -416,4 +416,4 @@ def transaction(self, **kwargs) -> AsyncTransaction: @property def _pipeline_cls(self): - raise AsyncPipeline \ No newline at end of file + raise AsyncPipeline diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 34cd3d4e2..5e4fc05e7 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -71,7 +71,11 @@ async def execute(self) -> AsyncIterable[PipelineResult]: request ): for doc in response.results: - doc_ref = AsyncDocumentReference(doc.name, client=self._client) if doc.name else None + doc_ref = ( + AsyncDocumentReference(doc.name, client=self._client) + if doc.name + else None + ) yield PipelineResult( self._client, doc.fields, @@ -79,4 +83,4 @@ async def execute(self) -> AsyncIterable[PipelineResult]: response._pb.execution_time, doc.create_time, doc.update_tiem, - ) \ No newline at end of file + ) diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 767c577d3..dfe6b9e65 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -34,7 +34,7 @@ Optional, Tuple, Union, - Type + Type, ) import google.api_core.client_options @@ -484,6 +484,7 @@ def pipeline(self) -> PipelineSource: def _pipeline_cls(self) -> Type["_BasePipeline"]: raise NotImplementedError + def _reference_info(references: list) -> Tuple[list, dict]: """Get information about document references. 
diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index bf4044a54..91da9ac2a 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -66,4 +66,4 @@ def _append(self, new_stage): """ Create a new Pipeline object with a new stage appended """ - return self.__class__(self._client, *self.stages, new_stage) \ No newline at end of file + return self.__class__(self._client, *self.stages, new_stage) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index cec7ce02b..de1a0d5b5 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1108,7 +1108,6 @@ def recursive(self: QueryType) -> QueryType: def pipeline(self): raise NotImplementedError - def _comparator(self, doc1, doc2) -> int: _orders = self._orders diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index f28b8eed5..6622d70d5 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -408,4 +408,4 @@ def transaction(self, **kwargs) -> Transaction: @property def _pipeline_cls(self): - raise Pipeline \ No newline at end of file + raise Pipeline diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 68771496e..8ddde7ec9 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -65,7 +65,11 @@ def execute(self) -> Iterable[PipelineResult]: ) for response in self._client._firestore_api.execute_pipeline(request): for doc in response.results: - doc_ref = DocumentReference(doc.name, client=self._client) if doc.name else None + doc_ref = ( + DocumentReference(doc.name, client=self._client) + if doc.name + else None + ) yield PipelineResult( self._client, doc.fields, @@ -73,4 +77,4 @@ def execute(self) -> Iterable[PipelineResult]: response._pb.execution_time, doc.create_time, doc.update_tiem, - ) \ No newline at end of file + ) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 4a560748f..0219a29c6 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -86,4 +86,4 @@ def __repr__(self): return f"Constant.of({self.value!r})" def _to_pb(self) -> Value: - return encode_value(self.value) \ No newline at end of file + return encode_value(self.value) diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py index 56a7c2ea4..61341db4d 100644 --- a/google/cloud/firestore_v1/pipeline_result.py +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -133,6 +133,8 @@ def get(self, field_path: str | FieldPath) -> Any: Returns: The data at the specified field location, decoded to Python types. 
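For example (a hedged sketch; the field path is hypothetical):

.. code-block:: python

    from google.cloud.firestore_v1.field_path import FieldPath

    population = result.get("stats.population")
    # Or equivalently, with a FieldPath instance:
    population = result.get(FieldPath("stats", "population"))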
""" - str_path = field_path if isinstance(field_path, str) else field_path.to_api_repr() + str_path = ( + field_path if isinstance(field_path, str) else field_path.to_api_repr() + ) value = get_nested_value(str_path, self._fields_pb) return decode_value(value, self._client) diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index b1a49c0f6..97044e471 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -23,8 +23,8 @@ from google.cloud.firestore_v1.base_document import BaseDocumentReference +PipelineType = TypeVar("PipelineType", bound=_BasePipeline) -PipelineType = TypeVar('PipelineType', bound=_BasePipeline) class PipelineSource(Generic[PipelineType]): """ @@ -62,7 +62,9 @@ def collection_group(self, collection_id: str) -> PipelineType: Returns: a new pipeline instance targeting the specified collection group """ - return self.client._pipeline_cls(self.client, stages.CollectionGroup(collection_id)) + return self.client._pipeline_cls( + self.client, stages.CollectionGroup(collection_id) + ) def database(self) -> PipelineType: """ @@ -82,4 +84,4 @@ def documents(self, *docs: "BaseDocumentReference") -> PipelineType: Returns: a new pipeline instance targeting the specified documents """ - return self.client._pipeline_cls(self.client, stages.Documents.of(*docs)) \ No newline at end of file + return self.client._pipeline_cls(self.client, stages.Documents.of(*docs)) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 242671756..0eaefd7d0 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -121,4 +121,4 @@ def __init__(self, name: str, *params: Expr | Value): ] def _pb_args(self): - return self.params \ No newline at end of file + return self.params diff --git a/google/cloud/firestore_v1/services/__init__.py b/google/cloud/firestore_v1/services/__init__.py index 8f6cf0682..cbf94b283 100644 --- a/google/cloud/firestore_v1/services/__init__.py +++ b/google/cloud/firestore_v1/services/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/services/firestore/__init__.py b/google/cloud/firestore_v1/services/firestore/__init__.py index a33859857..a69a11b29 100644 --- a/google/cloud/firestore_v1/services/firestore/__init__.py +++ b/google/cloud/firestore_v1/services/firestore/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/services/firestore/async_client.py b/google/cloud/firestore_v1/services/firestore/async_client.py index a11916f2d..916914969 100644 --- a/google/cloud/firestore_v1/services/firestore/async_client.py +++ b/google/cloud/firestore_v1/services/firestore/async_client.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging as std_logging from collections import OrderedDict import re from typing import ( @@ -64,6 +65,15 @@ from .transports.grpc_asyncio import FirestoreGrpcAsyncIOTransport from .client import FirestoreClient +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + class FirestoreAsyncClient: """The Cloud Firestore service. @@ -227,6 +237,9 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. + NOTE: "rest" transport functionality is currently in a + beta state (preview). We welcome your feedback via an + issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. @@ -270,13 +283,35 @@ def __init__( client_info=client_info, ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.firestore_v1.FirestoreAsyncClient`.", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "universeDomain": getattr( + self._client._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._client._transport._credentials).__module__}.{type(self._client._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._client._transport, "_credentials") + else { + "serviceName": "google.firestore.v1.Firestore", + "credentialsType": None, + }, + ) + async def get_document( self, request: Optional[Union[firestore.GetDocumentRequest, dict]] = None, *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> document.Document: r"""Gets a single document. @@ -314,8 +349,10 @@ async def sample_get_document(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.Document: @@ -362,7 +399,7 @@ async def list_documents( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> pagers.ListDocumentsAsyncPager: r"""Lists documents. @@ -401,8 +438,10 @@ async def sample_list_documents(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. 
- metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.services.firestore.pagers.ListDocumentsAsyncPager: @@ -469,7 +508,7 @@ async def update_document( update_mask: Optional[common.DocumentMask] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> gf_document.Document: r"""Updates or inserts a document. @@ -528,8 +567,10 @@ async def sample_update_document(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.Document: @@ -541,7 +582,10 @@ async def sample_update_document(): # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([document, update_mask]) + flattened_params = [document, update_mask] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -595,7 +639,7 @@ async def delete_document( name: Optional[str] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> None: r"""Deletes a document. @@ -637,13 +681,18 @@ async def sample_delete_document(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. """ # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- has_flattened_params = any([name]) + flattened_params = [name] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -689,7 +738,7 @@ def batch_get_documents( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Awaitable[AsyncIterable[firestore.BatchGetDocumentsResponse]]: r"""Gets multiple documents. @@ -731,8 +780,10 @@ async def sample_batch_get_documents(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: AsyncIterable[google.cloud.firestore_v1.types.BatchGetDocumentsResponse]: @@ -779,7 +830,7 @@ async def begin_transaction( database: Optional[str] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.BeginTransactionResponse: r"""Starts a new transaction. @@ -823,8 +874,10 @@ async def sample_begin_transaction(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.BeginTransactionResponse: @@ -835,7 +888,10 @@ async def sample_begin_transaction(): # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([database]) + flattened_params = [database] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -886,7 +942,7 @@ async def commit( writes: Optional[MutableSequence[gf_write.Write]] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.CommitResponse: r"""Commits a transaction, while optionally updating documents. @@ -939,8 +995,10 @@ async def sample_commit(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. 
- metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.CommitResponse: @@ -951,7 +1009,10 @@ async def sample_commit(): # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([database, writes]) + flattened_params = [database, writes] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -1002,7 +1063,7 @@ async def rollback( transaction: Optional[bytes] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> None: r"""Rolls back a transaction. @@ -1051,13 +1112,18 @@ async def sample_rollback(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. """ # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([database, transaction]) + flattened_params = [database, transaction] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -1103,7 +1169,7 @@ def run_query( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Awaitable[AsyncIterable[firestore.RunQueryResponse]]: r"""Runs a query. @@ -1142,8 +1208,10 @@ async def sample_run_query(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
Returns: AsyncIterable[google.cloud.firestore_v1.types.RunQueryResponse]: @@ -1189,40 +1257,91 @@ def execute_pipeline( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Awaitable[AsyncIterable[firestore.ExecutePipelineResponse]]: r"""Executes a pipeline query. + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_v1 + + async def sample_execute_pipeline(): + # Create a client + client = firestore_v1.FirestoreAsyncClient() + + # Initialize request argument(s) + structured_pipeline = firestore_v1.StructuredPipeline() + structured_pipeline.pipeline.stages.name = "name_value" + + request = firestore_v1.ExecutePipelineRequest( + structured_pipeline=structured_pipeline, + transaction=b'transaction_blob', + database="database_value", + ) + + # Make the request + stream = await client.execute_pipeline(request=request) + + # Handle the response + async for response in stream: + print(response) + Args: request (Optional[Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]]): The request object. The request for [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. - retry (google.api_core.retry.Retry): Designation of what errors, if any, + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: AsyncIterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]: The response for [Firestore.Execute][]. """ # Create or coerce a protobuf request object. - request = firestore.ExecutePipelineRequest(request) + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, firestore.ExecutePipelineRequest): + request = firestore.ExecutePipelineRequest(request) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.execute_pipeline, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._client._transport._wrapped_methods[ + self._client._transport.execute_pipeline + ] - # Certain fields should be provided within the metadata header; - # add these here. 
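The hunk below replaces the fixed `database` routing header with regex-based extraction of the `project_id` and `database_id` routing parameters. A standalone sketch of that extraction, run against a sample database path:

.. code-block:: python

    import re

    database = "projects/my-project/databases/my-db"  # sample path
    header_params = {}

    m = re.match(r"^projects/(?P<project_id>[^/]+)(?:/.*)?$", database)
    if m and m.group("project_id"):
        header_params["project_id"] = m.group("project_id")

    m = re.match(r"^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$", database)
    if m and m.group("database_id"):
        header_params["database_id"] = m.group("database_id")

    assert header_params == {"project_id": "my-project", "database_id": "my-db"}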
- metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("database", request.database),)), + header_params = {} + + routing_param_regex = re.compile("^projects/(?P<project_id>[^/]+)(?:/.*)?$") + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("project_id"): + header_params["project_id"] = regex_match.group("project_id") + + routing_param_regex = re.compile( + "^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$" + ) + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("database_id"): + header_params["database_id"] = regex_match.group("database_id") + + if header_params: + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(header_params), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() # Send the request. response = rpc( @@ -1241,7 +1360,7 @@ def run_aggregation_query( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Awaitable[AsyncIterable[firestore.RunAggregationQueryResponse]]: r"""Runs an aggregation query. @@ -1294,8 +1413,10 @@ async def sample_run_aggregation_query(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: AsyncIterable[google.cloud.firestore_v1.types.RunAggregationQueryResponse]: @@ -1341,7 +1462,7 @@ async def partition_query( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> pagers.PartitionQueryAsyncPager: r"""Partitions a query by returning partition cursors that can be used to run the query in parallel. The @@ -1383,8 +1504,10 @@ async def sample_partition_query(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.services.firestore.pagers.PartitionQueryAsyncPager: @@ -1444,7 +1567,7 @@ def write( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Awaitable[AsyncIterable[firestore.WriteResponse]]: r"""Streams batches of document updates and deletes, in order.
This method is only available via gRPC or @@ -1506,8 +1629,10 @@ def request_generator(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: AsyncIterable[google.cloud.firestore_v1.types.WriteResponse]: @@ -1544,7 +1669,7 @@ def listen( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Awaitable[AsyncIterable[firestore.ListenResponse]]: r"""Listens to changes. This method is only available via gRPC or WebChannel (not REST). @@ -1597,8 +1722,10 @@ def request_generator(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: AsyncIterable[google.cloud.firestore_v1.types.ListenResponse]: @@ -1636,7 +1763,7 @@ async def list_collection_ids( parent: Optional[str] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> pagers.ListCollectionIdsAsyncPager: r"""Lists all the collection IDs underneath a document. @@ -1683,8 +1810,10 @@ async def sample_list_collection_ids(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.services.firestore.pagers.ListCollectionIdsAsyncPager: @@ -1698,7 +1827,10 @@ async def sample_list_collection_ids(): # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
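The bidirectional streaming methods above (`write` and `listen`) consume an iterator of requests, as the `request_generator()` in their samples suggests. A hedged sketch of that calling pattern, with placeholder field values:

.. code-block:: python

    from google.cloud import firestore_v1

    def request_generator():
        # The first request names the database; later requests carry writes.
        yield firestore_v1.WriteRequest(database="projects/p/databases/d")

    client = firestore_v1.FirestoreClient()
    for response in client.write(requests=request_generator()):
        print(response)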
- has_flattened_params = any([parent]) + flattened_params = [parent] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -1758,7 +1890,7 @@ async def batch_write( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.BatchWriteResponse: r"""Applies a batch of write operations. @@ -1805,8 +1937,10 @@ async def sample_batch_write(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.BatchWriteResponse: @@ -1852,7 +1986,7 @@ async def create_document( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> document.Document: r"""Creates a new document. @@ -1890,8 +2024,10 @@ async def sample_create_document(): retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.Document: @@ -1943,7 +2079,7 @@ async def list_operations( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> operations_pb2.ListOperationsResponse: r"""Lists operations that match the specified filter in the request. @@ -1954,8 +2090,10 @@ async def list_operations( retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: ~.operations_pb2.ListOperationsResponse: Response message for ``ListOperations`` method. @@ -1968,11 +2106,7 @@ async def list_operations( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
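The hunk below swaps per-call `wrap_method(...)` for a lookup into the transport's pre-built `_wrapped_methods` table, so retry/timeout wrapping happens once at transport construction instead of on every invocation. A toy sketch of the pattern (stand-in classes, not the real transports):

.. code-block:: python

    class FakeTransport:
        def list_operations(self, request):
            return {"operations": []}  # stand-in for the real RPC

        def __init__(self):
            # _prep_wrapped_messages would register retry/timeout-wrapped
            # callables keyed by the underlying transport method.
            self._wrapped_methods = {self.list_operations: self.list_operations}

    transport = FakeTransport()
    rpc = transport._wrapped_methods[transport.list_operations]  # one dict lookup
    print(rpc(request=None))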
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_operations, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self.transport._wrapped_methods[self._client._transport.list_operations] # Certain fields should be provided within the metadata header; # add these here. @@ -2000,7 +2134,7 @@ async def get_operation( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> operations_pb2.Operation: r"""Gets the latest state of a long-running operation. @@ -2011,8 +2145,10 @@ async def get_operation( retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: ~.operations_pb2.Operation: An ``Operation`` object. @@ -2025,11 +2161,7 @@ async def get_operation( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_operation, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self.transport._wrapped_methods[self._client._transport.get_operation] # Certain fields should be provided within the metadata header; # add these here. @@ -2057,7 +2189,7 @@ async def delete_operation( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> None: r"""Deletes a long-running operation. @@ -2073,8 +2205,10 @@ async def delete_operation( retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: None """ @@ -2086,11 +2220,7 @@ async def delete_operation( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_operation, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self.transport._wrapped_methods[self._client._transport.delete_operation] # Certain fields should be provided within the metadata header; # add these here. @@ -2115,7 +2245,7 @@ async def cancel_operation( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> None: r"""Starts asynchronous cancellation on a long-running operation. 
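These operations methods attach a routing header built from request fields; `gapic_v1.routing_header.to_grpc_metadata` packs the field/value pairs into a single `x-goog-request-params` metadata entry. A small sketch with an illustrative operation name:

.. code-block:: python

    from google.api_core import gapic_v1

    # "operations/op-123" is an illustrative value, not a real operation.
    entry = gapic_v1.routing_header.to_grpc_metadata((("name", "operations/op-123"),))
    metadata = () + (entry,)
    print(metadata)  # one ("x-goog-request-params", ...) pair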
@@ -2130,8 +2260,10 @@ async def cancel_operation( retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: None """ @@ -2143,11 +2275,7 @@ async def cancel_operation( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_operation, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self.transport._wrapped_methods[self._client._transport.cancel_operation] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py index c8974eaa4..340cd5ef2 100644 --- a/google/cloud/firestore_v1/services/firestore/client.py +++ b/google/cloud/firestore_v1/services/firestore/client.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,9 @@ # limitations under the License. # from collections import OrderedDict +from http import HTTPStatus +import json +import logging as std_logging import os import re from typing import ( @@ -50,6 +53,15 @@ except AttributeError: # pragma: NO COVER OptionalRetry = Union[retries.Retry, object, None] # type: ignore +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + from google.cloud.firestore_v1.services.firestore import pagers from google.cloud.firestore_v1.types import aggregation_result from google.cloud.firestore_v1.types import common @@ -459,52 +471,45 @@ def _get_universe_domain( raise ValueError("Universe Domain cannot be an empty string.") return universe_domain - @staticmethod - def _compare_universes( - client_universe: str, credentials: ga_credentials.Credentials - ) -> bool: - """Returns True iff the universe domains used by the client and credentials match. - - Args: - client_universe (str): The universe domain configured via the client options. - credentials (ga_credentials.Credentials): The credentials being used in the client. + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. Returns: - bool: True iff client_universe matches the universe in credentials. + bool: True iff the configured universe domain is valid. Raises: - ValueError: when client_universe does not match the universe in credentials. + ValueError: If the configured universe domain is not valid. 
""" - default_universe = FirestoreClient._DEFAULT_UNIVERSE - credentials_universe = getattr(credentials, "universe_domain", default_universe) - - if client_universe != credentials_universe: - raise ValueError( - "The configured universe domain " - f"({client_universe}) does not match the universe domain " - f"found in the credentials ({credentials_universe}). " - "If you haven't configured the universe domain explicitly, " - f"`{default_universe}` is the default." - ) + # NOTE (b/349488459): universe validation is disabled until further notice. return True - def _validate_universe_domain(self): - """Validates client's and credentials' universe domains are consistent. - - Returns: - bool: True iff the configured universe domain is valid. + def _add_cred_info_for_auth_errors( + self, error: core_exceptions.GoogleAPICallError + ) -> None: + """Adds credential info string to error details for 401/403/404 errors. - Raises: - ValueError: If the configured universe domain is not valid. + Args: + error (google.api_core.exceptions.GoogleAPICallError): The error to add the cred info. """ - self._is_universe_domain_valid = ( - self._is_universe_domain_valid - or FirestoreClient._compare_universes( - self.universe_domain, self.transport._credentials - ) - ) - return self._is_universe_domain_valid + if error.code not in [ + HTTPStatus.UNAUTHORIZED, + HTTPStatus.FORBIDDEN, + HTTPStatus.NOT_FOUND, + ]: + return + + cred = self._transport._credentials + + # get_cred_info is only available in google-auth>=2.35.0 + if not hasattr(cred, "get_cred_info"): + return + + # ignore the type check since pypy test fails when get_cred_info + # is not available + cred_info = cred.get_cred_info() # type: ignore + if cred_info and hasattr(error._details, "append"): + error._details.append(json.dumps(cred_info)) @property def api_endpoint(self): @@ -547,6 +552,9 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. + NOTE: "rest" transport functionality is currently in a + beta state (preview). We welcome your feedback via an + issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. @@ -610,6 +618,10 @@ def __init__( # Initialize the universe domain validation. self._is_universe_domain_valid = False + if CLIENT_LOGGING_SUPPORTED: # pragma: NO COVER + # Setup logging. 
+ client_logging.initialize_logging() + api_key_value = getattr(self._client_options, "api_key", None) if api_key_value and credentials: raise ValueError( @@ -672,13 +684,36 @@ def __init__( api_audience=self._client_options.api_audience, ) + if "async" not in str(self._transport): + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.firestore_v1.FirestoreClient`.", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "universeDomain": getattr( + self._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._transport._credentials).__module__}.{type(self._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._transport, "_credentials") + else { + "serviceName": "google.firestore.v1.Firestore", + "credentialsType": None, + }, + ) + def get_document( self, request: Optional[Union[firestore.GetDocumentRequest, dict]] = None, *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> document.Document: r"""Gets a single document. @@ -716,8 +751,10 @@ def sample_get_document(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.Document: @@ -762,7 +799,7 @@ def list_documents( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> pagers.ListDocumentsPager: r"""Lists documents. @@ -801,8 +838,10 @@ def sample_list_documents(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.services.firestore.pagers.ListDocumentsPager: @@ -867,7 +906,7 @@ def update_document( update_mask: Optional[common.DocumentMask] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> gf_document.Document: r"""Updates or inserts a document. @@ -926,8 +965,10 @@ def sample_update_document(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. 
- metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.Document: @@ -939,7 +980,10 @@ def sample_update_document(): # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([document, update_mask]) + flattened_params = [document, update_mask] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -990,7 +1034,7 @@ def delete_document( name: Optional[str] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> None: r"""Deletes a document. @@ -1032,13 +1076,18 @@ def sample_delete_document(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. """ # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) + flattened_params = [name] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -1081,7 +1130,7 @@ def batch_get_documents( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Iterable[firestore.BatchGetDocumentsResponse]: r"""Gets multiple documents. @@ -1123,8 +1172,10 @@ def sample_batch_get_documents(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
Returns: Iterable[google.cloud.firestore_v1.types.BatchGetDocumentsResponse]: @@ -1169,7 +1220,7 @@ def begin_transaction( database: Optional[str] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.BeginTransactionResponse: r"""Starts a new transaction. @@ -1213,8 +1264,10 @@ def sample_begin_transaction(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.BeginTransactionResponse: @@ -1225,7 +1278,10 @@ def sample_begin_transaction(): # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([database]) + flattened_params = [database] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -1273,7 +1329,7 @@ def commit( writes: Optional[MutableSequence[gf_write.Write]] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.CommitResponse: r"""Commits a transaction, while optionally updating documents. @@ -1326,8 +1382,10 @@ def sample_commit(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.CommitResponse: @@ -1338,7 +1396,10 @@ def sample_commit(): # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([database, writes]) + flattened_params = [database, writes] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -1388,7 +1449,7 @@ def rollback( transaction: Optional[bytes] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> None: r"""Rolls back a transaction. 
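`begin_transaction`, `commit`, and `rollback` expose the flattened arguments these hunks annotate; the opaque `transaction` token returned by `begin_transaction` is what `rollback` accepts as bytes. A hedged round-trip sketch with a placeholder database path:

.. code-block:: python

    from google.cloud import firestore_v1

    client = firestore_v1.FirestoreClient()
    database = "projects/my-project/databases/(default)"  # placeholder path

    # The response's .transaction field is the opaque bytes token that
    # rollback() takes through its flattened argument.
    txn = client.begin_transaction(database=database)
    client.rollback(database=database, transaction=txn.transaction)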
@@ -1437,13 +1498,18 @@ def sample_rollback(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. """ # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([database, transaction]) + flattened_params = [database, transaction] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -1488,7 +1554,7 @@ def run_query( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Iterable[firestore.RunQueryResponse]: r"""Runs a query. @@ -1527,8 +1593,10 @@ def sample_run_query(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: Iterable[google.cloud.firestore_v1.types.RunQueryResponse]: @@ -1572,10 +1640,42 @@ def execute_pipeline( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Iterable[firestore.ExecutePipelineResponse]: r"""Executes a pipeline query. + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_v1 + + def sample_execute_pipeline(): + # Create a client + client = firestore_v1.FirestoreClient() + + # Initialize request argument(s) + structured_pipeline = firestore_v1.StructuredPipeline() + structured_pipeline.pipeline.stages.name = "name_value" + + request = firestore_v1.ExecutePipelineRequest( + structured_pipeline=structured_pipeline, + transaction=b'transaction_blob', + database="database_value", + ) + + # Make the request + stream = client.execute_pipeline(request=request) + + # Handle the response + for response in stream: + print(response) + Args: request (Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]): The request object. 
The request for @@ -1583,18 +1683,18 @@ def execute_pipeline( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: Iterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]: The response for [Firestore.Execute][]. """ # Create or coerce a protobuf request object. - # Minor optimization to avoid making a copy if the user passes - # in a firestore.ExecutePipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. if not isinstance(request, firestore.ExecutePipelineRequest): request = firestore.ExecutePipelineRequest(request) @@ -1621,6 +1721,9 @@ def execute_pipeline( gapic_v1.routing_header.to_grpc_metadata(header_params), ) + # Validate the universe domain. + self._validate_universe_domain() + # Send the request. response = rpc( request, @@ -1638,7 +1741,7 @@ def run_aggregation_query( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Iterable[firestore.RunAggregationQueryResponse]: r"""Runs an aggregation query. @@ -1691,8 +1794,10 @@ def sample_run_aggregation_query(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: Iterable[google.cloud.firestore_v1.types.RunAggregationQueryResponse]: @@ -1736,7 +1841,7 @@ def partition_query( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> pagers.PartitionQueryPager: r"""Partitions a query by returning partition cursors that can be used to run the query in parallel. The @@ -1778,8 +1883,10 @@ def sample_partition_query(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
Returns: google.cloud.firestore_v1.services.firestore.pagers.PartitionQueryPager: @@ -1837,7 +1944,7 @@ def write( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Iterable[firestore.WriteResponse]: r"""Streams batches of document updates and deletes, in order. This method is only available via gRPC or @@ -1899,8 +2006,10 @@ def request_generator(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: Iterable[google.cloud.firestore_v1.types.WriteResponse]: @@ -1937,7 +2046,7 @@ def listen( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> Iterable[firestore.ListenResponse]: r"""Listens to changes. This method is only available via gRPC or WebChannel (not REST). @@ -1990,8 +2099,10 @@ def request_generator(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: Iterable[google.cloud.firestore_v1.types.ListenResponse]: @@ -2029,7 +2140,7 @@ def list_collection_ids( parent: Optional[str] = None, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> pagers.ListCollectionIdsPager: r"""Lists all the collection IDs underneath a document. @@ -2076,8 +2187,10 @@ def sample_list_collection_ids(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.services.firestore.pagers.ListCollectionIdsPager: @@ -2091,7 +2204,10 @@ def sample_list_collection_ids(): # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
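`list_collection_ids` returns a pager that resolves page tokens transparently, so iterating it yields collection IDs directly. A short sketch with an illustrative parent document path:

.. code-block:: python

    from google.cloud import firestore_v1

    client = firestore_v1.FirestoreClient()
    parent = "projects/my-project/databases/(default)/documents/rooms/room1"

    # The ListCollectionIdsPager fetches additional pages on demand.
    for collection_id in client.list_collection_ids(parent=parent):
        print(collection_id)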
- has_flattened_params = any([parent]) + flattened_params = [parent] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " @@ -2148,7 +2264,7 @@ def batch_write( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.BatchWriteResponse: r"""Applies a batch of write operations. @@ -2195,8 +2311,10 @@ def sample_batch_write(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.BatchWriteResponse: @@ -2240,7 +2358,7 @@ def create_document( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> document.Document: r"""Creates a new document. @@ -2278,8 +2396,10 @@ def sample_create_document(): retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: google.cloud.firestore_v1.types.Document: @@ -2342,7 +2462,7 @@ def list_operations( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> operations_pb2.ListOperationsResponse: r"""Lists operations that match the specified filter in the request. @@ -2353,8 +2473,10 @@ def list_operations( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: ~.operations_pb2.ListOperationsResponse: Response message for ``ListOperations`` method. @@ -2367,11 +2489,7 @@ def list_operations( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
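Below, the operations methods gain a try/except that routes failures through `_add_cred_info_for_auth_errors`, which appends a JSON credential summary to the error's details for 401/403/404 responses (when google-auth >= 2.35.0 exposes `get_cred_info`). A hedged sketch of what a caller might observe:

.. code-block:: python

    from google.api_core import exceptions as core_exceptions
    from google.cloud import firestore_v1

    client = firestore_v1.FirestoreClient()
    try:
        client.get_operation({"name": "operations/op-123"})  # illustrative name
    except core_exceptions.GoogleAPICallError as e:
        # For 401/403/404, the details may now end with a JSON string
        # describing the credential, added by _add_cred_info_for_auth_errors.
        print(e.details)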
- rpc = gapic_v1.method.wrap_method( - self._transport.list_operations, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._transport._wrapped_methods[self._transport.list_operations] # Certain fields should be provided within the metadata header; # add these here. @@ -2382,16 +2500,20 @@ def list_operations( # Validate the universe domain. self._validate_universe_domain() - # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + try: + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) - # Done; return the response. - return response + # Done; return the response. + return response + except core_exceptions.GoogleAPICallError as e: + self._add_cred_info_for_auth_errors(e) + raise e def get_operation( self, @@ -2399,7 +2521,7 @@ def get_operation( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> operations_pb2.Operation: r"""Gets the latest state of a long-running operation. @@ -2410,8 +2532,10 @@ def get_operation( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: ~.operations_pb2.Operation: An ``Operation`` object. @@ -2424,11 +2548,7 @@ def get_operation( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_operation, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._transport._wrapped_methods[self._transport.get_operation] # Certain fields should be provided within the metadata header; # add these here. @@ -2439,16 +2559,20 @@ def get_operation( # Validate the universe domain. self._validate_universe_domain() - # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + try: + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) - # Done; return the response. - return response + # Done; return the response. + return response + except core_exceptions.GoogleAPICallError as e: + self._add_cred_info_for_auth_errors(e) + raise e def delete_operation( self, @@ -2456,7 +2580,7 @@ def delete_operation( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> None: r"""Deletes a long-running operation. @@ -2472,8 +2596,10 @@ def delete_operation( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. 
Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: None """ @@ -2485,11 +2611,7 @@ def delete_operation( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_operation, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._transport._wrapped_methods[self._transport.delete_operation] # Certain fields should be provided within the metadata header; # add these here. @@ -2514,7 +2636,7 @@ def cancel_operation( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> None: r"""Starts asynchronous cancellation on a long-running operation. @@ -2529,8 +2651,10 @@ def cancel_operation( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: None """ @@ -2542,11 +2666,7 @@ def cancel_operation( # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.cancel_operation, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) + rpc = self._transport._wrapped_methods[self._transport.cancel_operation] # Certain fields should be provided within the metadata header; # add these here. diff --git a/google/cloud/firestore_v1/services/firestore/pagers.py b/google/cloud/firestore_v1/services/firestore/pagers.py index 71ebf18fb..be9e4b714 100644 --- a/google/cloud/firestore_v1/services/firestore/pagers.py +++ b/google/cloud/firestore_v1/services/firestore/pagers.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -68,7 +68,7 @@ def __init__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = () + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () ): """Instantiate the pager. @@ -82,8 +82,10 @@ def __init__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
""" self._method = method self._request = firestore.ListDocumentsRequest(request) @@ -142,7 +144,7 @@ def __init__( *, retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = () + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () ): """Instantiates the pager. @@ -156,8 +158,10 @@ def __init__( retry (google.api_core.retry.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. """ self._method = method self._request = firestore.ListDocumentsRequest(request) @@ -220,7 +224,7 @@ def __init__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = () + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () ): """Instantiate the pager. @@ -234,8 +238,10 @@ def __init__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. """ self._method = method self._request = firestore.PartitionQueryRequest(request) @@ -294,7 +300,7 @@ def __init__( *, retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = () + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () ): """Instantiates the pager. @@ -308,8 +314,10 @@ def __init__( retry (google.api_core.retry.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. """ self._method = method self._request = firestore.PartitionQueryRequest(request) @@ -372,7 +380,7 @@ def __init__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = () + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () ): """Instantiate the pager. @@ -386,8 +394,10 @@ def __init__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. 
Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. """ self._method = method self._request = firestore.ListCollectionIdsRequest(request) @@ -446,7 +456,7 @@ def __init__( *, retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = () + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () ): """Instantiates the pager. @@ -460,8 +470,10 @@ def __init__( retry (google.api_core.retry.AsyncRetry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. """ self._method = method self._request = firestore.ListCollectionIdsRequest(request) diff --git a/google/cloud/firestore_v1/services/firestore/transports/README.rst b/google/cloud/firestore_v1/services/firestore/transports/README.rst new file mode 100644 index 000000000..1823b6773 --- /dev/null +++ b/google/cloud/firestore_v1/services/firestore/transports/README.rst @@ -0,0 +1,9 @@ + +transport inheritance structure +_______________________________ + +`FirestoreTransport` is the ABC for all transports. +- public child `FirestoreGrpcTransport` for sync gRPC transport (defined in `grpc.py`). +- public child `FirestoreGrpcAsyncIOTransport` for async gRPC transport (defined in `grpc_asyncio.py`). +- private child `_BaseFirestoreRestTransport` for base REST transport with inner classes `_BaseMETHOD` (defined in `rest_base.py`). +- public child `FirestoreRestTransport` for sync REST transport with inner classes `METHOD` derived from the parent's corresponding `_BaseMETHOD` classes (defined in `rest.py`). diff --git a/google/cloud/firestore_v1/services/firestore/transports/__init__.py b/google/cloud/firestore_v1/services/firestore/transports/__init__.py index f32c361e0..f3ca95f79 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/__init__.py +++ b/google/cloud/firestore_v1/services/firestore/transports/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/services/firestore/transports/base.py b/google/cloud/firestore_v1/services/firestore/transports/base.py index a5f1f52ed..50e0b6dd3 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/base.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
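An aside on the metadata annotations introduced throughout this patch: plain keys carry `str` values, while keys ending in the `-bin` suffix carry `bytes`. A minimal sketch of what that looks like at a call site (the document path and metadata keys below are illustrative, not part of the library):

    from google.cloud.firestore_v1.services.firestore.client import FirestoreClient
    from google.cloud.firestore_v1.types import firestore

    client = FirestoreClient()
    request = firestore.GetDocumentRequest(
        name="projects/my-project/databases/(default)/documents/users/alice",
    )
    response = client.get_document(
        request=request,
        metadata=[
            ("x-custom-note", "plain string value"),  # str value for a plain key
            ("x-custom-trace-bin", b"\x0a\x0b\x0c"),  # "-bin" suffix => bytes value
        ],
    )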
@@ -395,6 +395,26 @@ def _prep_wrapped_messages(self, client_info): default_timeout=60.0, client_info=client_info, ), + self.cancel_operation: gapic_v1.method.wrap_method( + self.cancel_operation, + default_timeout=None, + client_info=client_info, + ), + self.delete_operation: gapic_v1.method.wrap_method( + self.delete_operation, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: gapic_v1.method.wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: gapic_v1.method.wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), } def close(self): @@ -494,6 +514,18 @@ def run_query( ]: raise NotImplementedError() + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], + Union[ + firestore.ExecutePipelineResponse, + Awaitable[firestore.ExecutePipelineResponse], + ], + ]: + raise NotImplementedError() + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc.py b/google/cloud/firestore_v1/services/firestore/transports/grpc.py index 508fa93db..2a8f4caf9 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import json +import logging as std_logging +import pickle import warnings from typing import Callable, Dict, Optional, Sequence, Tuple, Union @@ -21,8 +24,11 @@ import google.auth # type: ignore from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message import grpc # type: ignore +import proto # type: ignore from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document @@ -32,6 +38,80 @@ from google.protobuf import empty_pb2 # type: ignore from .base import FirestoreTransport, DEFAULT_CLIENT_INFO +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO COVER + def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + 
_LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": str(client_call_details.method), + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + response = continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = response.result() + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response for {client_call_details.method}.", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": client_call_details.method, + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + class FirestoreGrpcTransport(FirestoreTransport): """gRPC backend transport for Firestore. @@ -193,7 +273,12 @@ def __init__( ], ) - # Wrap messages. This must be done after self._grpc_channel exists + self._interceptor = _LoggingClientInterceptor() + self._logged_channel = grpc.intercept_channel( + self._grpc_channel, self._interceptor + ) + + # Wrap messages. This must be done after self._logged_channel exists self._prep_wrapped_messages(client_info) @classmethod @@ -267,7 +352,7 @@ def get_document( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "get_document" not in self._stubs: - self._stubs["get_document"] = self.grpc_channel.unary_unary( + self._stubs["get_document"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/GetDocument", request_serializer=firestore.GetDocumentRequest.serialize, response_deserializer=document.Document.deserialize, @@ -293,7 +378,7 @@ def list_documents( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "list_documents" not in self._stubs: - self._stubs["list_documents"] = self.grpc_channel.unary_unary( + self._stubs["list_documents"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/ListDocuments", request_serializer=firestore.ListDocumentsRequest.serialize, response_deserializer=firestore.ListDocumentsResponse.deserialize, @@ -319,7 +404,7 @@ def update_document( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "update_document" not in self._stubs: - self._stubs["update_document"] = self.grpc_channel.unary_unary( + self._stubs["update_document"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/UpdateDocument", request_serializer=firestore.UpdateDocumentRequest.serialize, response_deserializer=gf_document.Document.deserialize, @@ -345,7 +430,7 @@ def delete_document( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
if "delete_document" not in self._stubs: - self._stubs["delete_document"] = self.grpc_channel.unary_unary( + self._stubs["delete_document"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/DeleteDocument", request_serializer=firestore.DeleteDocumentRequest.serialize, response_deserializer=empty_pb2.Empty.FromString, @@ -376,7 +461,7 @@ def batch_get_documents( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "batch_get_documents" not in self._stubs: - self._stubs["batch_get_documents"] = self.grpc_channel.unary_stream( + self._stubs["batch_get_documents"] = self._logged_channel.unary_stream( "/google.firestore.v1.Firestore/BatchGetDocuments", request_serializer=firestore.BatchGetDocumentsRequest.serialize, response_deserializer=firestore.BatchGetDocumentsResponse.deserialize, @@ -404,7 +489,7 @@ def begin_transaction( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "begin_transaction" not in self._stubs: - self._stubs["begin_transaction"] = self.grpc_channel.unary_unary( + self._stubs["begin_transaction"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/BeginTransaction", request_serializer=firestore.BeginTransactionRequest.serialize, response_deserializer=firestore.BeginTransactionResponse.deserialize, @@ -429,7 +514,7 @@ def commit(self) -> Callable[[firestore.CommitRequest], firestore.CommitResponse # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "commit" not in self._stubs: - self._stubs["commit"] = self.grpc_channel.unary_unary( + self._stubs["commit"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/Commit", request_serializer=firestore.CommitRequest.serialize, response_deserializer=firestore.CommitResponse.deserialize, @@ -453,7 +538,7 @@ def rollback(self) -> Callable[[firestore.RollbackRequest], empty_pb2.Empty]: # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "rollback" not in self._stubs: - self._stubs["rollback"] = self.grpc_channel.unary_unary( + self._stubs["rollback"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/Rollback", request_serializer=firestore.RollbackRequest.serialize, response_deserializer=empty_pb2.Empty.FromString, @@ -479,7 +564,7 @@ def run_query( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "run_query" not in self._stubs: - self._stubs["run_query"] = self.grpc_channel.unary_stream( + self._stubs["run_query"] = self._logged_channel.unary_stream( "/google.firestore.v1.Firestore/RunQuery", request_serializer=firestore.RunQueryRequest.serialize, response_deserializer=firestore.RunQueryResponse.deserialize, @@ -507,7 +592,7 @@ def execute_pipeline( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "execute_pipeline" not in self._stubs: - self._stubs["execute_pipeline"] = self.grpc_channel.unary_stream( + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( "/google.firestore.v1.Firestore/ExecutePipeline", request_serializer=firestore.ExecutePipelineRequest.serialize, response_deserializer=firestore.ExecutePipelineResponse.deserialize, @@ -549,7 +634,7 @@ def run_aggregation_query( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
if "run_aggregation_query" not in self._stubs: - self._stubs["run_aggregation_query"] = self.grpc_channel.unary_stream( + self._stubs["run_aggregation_query"] = self._logged_channel.unary_stream( "/google.firestore.v1.Firestore/RunAggregationQuery", request_serializer=firestore.RunAggregationQueryRequest.serialize, response_deserializer=firestore.RunAggregationQueryResponse.deserialize, @@ -579,7 +664,7 @@ def partition_query( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "partition_query" not in self._stubs: - self._stubs["partition_query"] = self.grpc_channel.unary_unary( + self._stubs["partition_query"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/PartitionQuery", request_serializer=firestore.PartitionQueryRequest.serialize, response_deserializer=firestore.PartitionQueryResponse.deserialize, @@ -605,7 +690,7 @@ def write(self) -> Callable[[firestore.WriteRequest], firestore.WriteResponse]: # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "write" not in self._stubs: - self._stubs["write"] = self.grpc_channel.stream_stream( + self._stubs["write"] = self._logged_channel.stream_stream( "/google.firestore.v1.Firestore/Write", request_serializer=firestore.WriteRequest.serialize, response_deserializer=firestore.WriteResponse.deserialize, @@ -630,7 +715,7 @@ def listen(self) -> Callable[[firestore.ListenRequest], firestore.ListenResponse # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "listen" not in self._stubs: - self._stubs["listen"] = self.grpc_channel.stream_stream( + self._stubs["listen"] = self._logged_channel.stream_stream( "/google.firestore.v1.Firestore/Listen", request_serializer=firestore.ListenRequest.serialize, response_deserializer=firestore.ListenResponse.deserialize, @@ -658,7 +743,7 @@ def list_collection_ids( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "list_collection_ids" not in self._stubs: - self._stubs["list_collection_ids"] = self.grpc_channel.unary_unary( + self._stubs["list_collection_ids"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/ListCollectionIds", request_serializer=firestore.ListCollectionIdsRequest.serialize, response_deserializer=firestore.ListCollectionIdsResponse.deserialize, @@ -694,7 +779,7 @@ def batch_write( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "batch_write" not in self._stubs: - self._stubs["batch_write"] = self.grpc_channel.unary_unary( + self._stubs["batch_write"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/BatchWrite", request_serializer=firestore.BatchWriteRequest.serialize, response_deserializer=firestore.BatchWriteResponse.deserialize, @@ -720,7 +805,7 @@ def create_document( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
if "create_document" not in self._stubs: - self._stubs["create_document"] = self.grpc_channel.unary_unary( + self._stubs["create_document"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/CreateDocument", request_serializer=firestore.CreateDocumentRequest.serialize, response_deserializer=document.Document.deserialize, @@ -728,7 +813,7 @@ def create_document( return self._stubs["create_document"] def close(self): - self.grpc_channel.close() + self._logged_channel.close() @property def delete_operation( @@ -740,7 +825,7 @@ def delete_operation( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "delete_operation" not in self._stubs: - self._stubs["delete_operation"] = self.grpc_channel.unary_unary( + self._stubs["delete_operation"] = self._logged_channel.unary_unary( "/google.longrunning.Operations/DeleteOperation", request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString, response_deserializer=None, @@ -757,7 +842,7 @@ def cancel_operation( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "cancel_operation" not in self._stubs: - self._stubs["cancel_operation"] = self.grpc_channel.unary_unary( + self._stubs["cancel_operation"] = self._logged_channel.unary_unary( "/google.longrunning.Operations/CancelOperation", request_serializer=operations_pb2.CancelOperationRequest.SerializeToString, response_deserializer=None, @@ -774,7 +859,7 @@ def get_operation( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "get_operation" not in self._stubs: - self._stubs["get_operation"] = self.grpc_channel.unary_unary( + self._stubs["get_operation"] = self._logged_channel.unary_unary( "/google.longrunning.Operations/GetOperation", request_serializer=operations_pb2.GetOperationRequest.SerializeToString, response_deserializer=operations_pb2.Operation.FromString, @@ -793,7 +878,7 @@ def list_operations( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "list_operations" not in self._stubs: - self._stubs["list_operations"] = self.grpc_channel.unary_unary( + self._stubs["list_operations"] = self._logged_channel.unary_unary( "/google.longrunning.Operations/ListOperations", request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, response_deserializer=operations_pb2.ListOperationsResponse.FromString, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py index ae0dc1c04..8801dc45a 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import inspect +import json +import pickle +import logging as std_logging import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union @@ -22,8 +26,11 @@ from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message import grpc # type: ignore +import proto # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.firestore_v1.types import document @@ -35,6 +42,82 @@ from .base import FirestoreTransport, DEFAULT_CLIENT_INFO from .grpc import FirestoreGrpcTransport +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientAIOInterceptor( + grpc.aio.UnaryUnaryClientInterceptor +): # pragma: NO COVER + async def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": str(client_call_details.method), + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + response = await continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = await response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = await response + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response to rpc {client_call_details.method}.", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": str(client_call_details.method), + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + class FirestoreGrpcAsyncIOTransport(FirestoreTransport): """gRPC AsyncIO backend transport for Firestore. @@ -239,7 +322,13 @@ def __init__( ], ) - # Wrap messages. 
This must be done after self._grpc_channel exists + self._interceptor = _LoggingClientAIOInterceptor() + self._grpc_channel._unary_unary_interceptors.append(self._interceptor) + self._logged_channel = self._grpc_channel + self._wrap_with_kind = ( + "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters + ) + # Wrap messages. This must be done after self._logged_channel exists self._prep_wrapped_messages(client_info) @property @@ -271,7 +360,7 @@ def get_document( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "get_document" not in self._stubs: - self._stubs["get_document"] = self.grpc_channel.unary_unary( + self._stubs["get_document"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/GetDocument", request_serializer=firestore.GetDocumentRequest.serialize, response_deserializer=document.Document.deserialize, @@ -299,7 +388,7 @@ def list_documents( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "list_documents" not in self._stubs: - self._stubs["list_documents"] = self.grpc_channel.unary_unary( + self._stubs["list_documents"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/ListDocuments", request_serializer=firestore.ListDocumentsRequest.serialize, response_deserializer=firestore.ListDocumentsResponse.deserialize, @@ -325,7 +414,7 @@ def update_document( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "update_document" not in self._stubs: - self._stubs["update_document"] = self.grpc_channel.unary_unary( + self._stubs["update_document"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/UpdateDocument", request_serializer=firestore.UpdateDocumentRequest.serialize, response_deserializer=gf_document.Document.deserialize, @@ -351,7 +440,7 @@ def delete_document( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "delete_document" not in self._stubs: - self._stubs["delete_document"] = self.grpc_channel.unary_unary( + self._stubs["delete_document"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/DeleteDocument", request_serializer=firestore.DeleteDocumentRequest.serialize, response_deserializer=empty_pb2.Empty.FromString, @@ -383,7 +472,7 @@ def batch_get_documents( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "batch_get_documents" not in self._stubs: - self._stubs["batch_get_documents"] = self.grpc_channel.unary_stream( + self._stubs["batch_get_documents"] = self._logged_channel.unary_stream( "/google.firestore.v1.Firestore/BatchGetDocuments", request_serializer=firestore.BatchGetDocumentsRequest.serialize, response_deserializer=firestore.BatchGetDocumentsResponse.deserialize, @@ -412,7 +501,7 @@ def begin_transaction( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "begin_transaction" not in self._stubs: - self._stubs["begin_transaction"] = self.grpc_channel.unary_unary( + self._stubs["begin_transaction"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/BeginTransaction", request_serializer=firestore.BeginTransactionRequest.serialize, response_deserializer=firestore.BeginTransactionResponse.deserialize, @@ -439,7 +528,7 @@ def commit( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
if "commit" not in self._stubs: - self._stubs["commit"] = self.grpc_channel.unary_unary( + self._stubs["commit"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/Commit", request_serializer=firestore.CommitRequest.serialize, response_deserializer=firestore.CommitResponse.deserialize, @@ -465,7 +554,7 @@ def rollback( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "rollback" not in self._stubs: - self._stubs["rollback"] = self.grpc_channel.unary_unary( + self._stubs["rollback"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/Rollback", request_serializer=firestore.RollbackRequest.serialize, response_deserializer=empty_pb2.Empty.FromString, @@ -491,7 +580,7 @@ def run_query( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "run_query" not in self._stubs: - self._stubs["run_query"] = self.grpc_channel.unary_stream( + self._stubs["run_query"] = self._logged_channel.unary_stream( "/google.firestore.v1.Firestore/RunQuery", request_serializer=firestore.RunQueryRequest.serialize, response_deserializer=firestore.RunQueryResponse.deserialize, @@ -519,7 +608,7 @@ def execute_pipeline( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "execute_pipeline" not in self._stubs: - self._stubs["execute_pipeline"] = self.grpc_channel.unary_stream( + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( "/google.firestore.v1.Firestore/ExecutePipeline", request_serializer=firestore.ExecutePipelineRequest.serialize, response_deserializer=firestore.ExecutePipelineResponse.deserialize, @@ -562,7 +651,7 @@ def run_aggregation_query( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "run_aggregation_query" not in self._stubs: - self._stubs["run_aggregation_query"] = self.grpc_channel.unary_stream( + self._stubs["run_aggregation_query"] = self._logged_channel.unary_stream( "/google.firestore.v1.Firestore/RunAggregationQuery", request_serializer=firestore.RunAggregationQueryRequest.serialize, response_deserializer=firestore.RunAggregationQueryResponse.deserialize, @@ -594,7 +683,7 @@ def partition_query( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "partition_query" not in self._stubs: - self._stubs["partition_query"] = self.grpc_channel.unary_unary( + self._stubs["partition_query"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/PartitionQuery", request_serializer=firestore.PartitionQueryRequest.serialize, response_deserializer=firestore.PartitionQueryResponse.deserialize, @@ -622,7 +711,7 @@ def write( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "write" not in self._stubs: - self._stubs["write"] = self.grpc_channel.stream_stream( + self._stubs["write"] = self._logged_channel.stream_stream( "/google.firestore.v1.Firestore/Write", request_serializer=firestore.WriteRequest.serialize, response_deserializer=firestore.WriteResponse.deserialize, @@ -649,7 +738,7 @@ def listen( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
if "listen" not in self._stubs: - self._stubs["listen"] = self.grpc_channel.stream_stream( + self._stubs["listen"] = self._logged_channel.stream_stream( "/google.firestore.v1.Firestore/Listen", request_serializer=firestore.ListenRequest.serialize, response_deserializer=firestore.ListenResponse.deserialize, @@ -678,7 +767,7 @@ def list_collection_ids( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "list_collection_ids" not in self._stubs: - self._stubs["list_collection_ids"] = self.grpc_channel.unary_unary( + self._stubs["list_collection_ids"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/ListCollectionIds", request_serializer=firestore.ListCollectionIdsRequest.serialize, response_deserializer=firestore.ListCollectionIdsResponse.deserialize, @@ -716,7 +805,7 @@ def batch_write( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "batch_write" not in self._stubs: - self._stubs["batch_write"] = self.grpc_channel.unary_unary( + self._stubs["batch_write"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/BatchWrite", request_serializer=firestore.BatchWriteRequest.serialize, response_deserializer=firestore.BatchWriteResponse.deserialize, @@ -742,7 +831,7 @@ def create_document( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "create_document" not in self._stubs: - self._stubs["create_document"] = self.grpc_channel.unary_unary( + self._stubs["create_document"] = self._logged_channel.unary_unary( "/google.firestore.v1.Firestore/CreateDocument", request_serializer=firestore.CreateDocumentRequest.serialize, response_deserializer=document.Document.deserialize, @@ -752,7 +841,7 @@ def create_document( def _prep_wrapped_messages(self, client_info): """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" self._wrapped_methods = { - self.get_document: gapic_v1.method_async.wrap_method( + self.get_document: self._wrap_method( self.get_document, default_retry=retries.AsyncRetry( initial=0.1, @@ -769,7 +858,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=60.0, client_info=client_info, ), - self.list_documents: gapic_v1.method_async.wrap_method( + self.list_documents: self._wrap_method( self.list_documents, default_retry=retries.AsyncRetry( initial=0.1, @@ -786,7 +875,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=60.0, client_info=client_info, ), - self.update_document: gapic_v1.method_async.wrap_method( + self.update_document: self._wrap_method( self.update_document, default_retry=retries.AsyncRetry( initial=0.1, @@ -801,7 +890,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=60.0, client_info=client_info, ), - self.delete_document: gapic_v1.method_async.wrap_method( + self.delete_document: self._wrap_method( self.delete_document, default_retry=retries.AsyncRetry( initial=0.1, @@ -818,7 +907,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=60.0, client_info=client_info, ), - self.batch_get_documents: gapic_v1.method_async.wrap_method( + self.batch_get_documents: self._wrap_method( self.batch_get_documents, default_retry=retries.AsyncRetry( initial=0.1, @@ -835,7 +924,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), - self.begin_transaction: gapic_v1.method_async.wrap_method( + self.begin_transaction: self._wrap_method( 
self.begin_transaction, default_retry=retries.AsyncRetry( initial=0.1, @@ -852,7 +941,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=60.0, client_info=client_info, ), - self.commit: gapic_v1.method_async.wrap_method( + self.commit: self._wrap_method( self.commit, default_retry=retries.AsyncRetry( initial=0.1, @@ -867,7 +956,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=60.0, client_info=client_info, ), - self.rollback: gapic_v1.method_async.wrap_method( + self.rollback: self._wrap_method( self.rollback, default_retry=retries.AsyncRetry( initial=0.1, @@ -884,7 +973,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=60.0, client_info=client_info, ), - self.run_query: gapic_v1.method_async.wrap_method( + self.run_query: self._wrap_method( self.run_query, default_retry=retries.AsyncRetry( initial=0.1, @@ -901,7 +990,12 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), - self.run_aggregation_query: gapic_v1.method_async.wrap_method( + self.execute_pipeline: self._wrap_method( + self.execute_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.run_aggregation_query: self._wrap_method( self.run_aggregation_query, default_retry=retries.AsyncRetry( initial=0.1, @@ -918,7 +1012,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), - self.partition_query: gapic_v1.method_async.wrap_method( + self.partition_query: self._wrap_method( self.partition_query, default_retry=retries.AsyncRetry( initial=0.1, @@ -935,12 +1029,12 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), - self.write: gapic_v1.method_async.wrap_method( + self.write: self._wrap_method( self.write, default_timeout=86400.0, client_info=client_info, ), - self.listen: gapic_v1.method_async.wrap_method( + self.listen: self._wrap_method( self.listen, default_retry=retries.AsyncRetry( initial=0.1, @@ -957,7 +1051,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=86400.0, client_info=client_info, ), - self.list_collection_ids: gapic_v1.method_async.wrap_method( + self.list_collection_ids: self._wrap_method( self.list_collection_ids, default_retry=retries.AsyncRetry( initial=0.1, @@ -974,7 +1068,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=60.0, client_info=client_info, ), - self.batch_write: gapic_v1.method_async.wrap_method( + self.batch_write: self._wrap_method( self.batch_write, default_retry=retries.AsyncRetry( initial=0.1, @@ -990,7 +1084,7 @@ def _prep_wrapped_messages(self, client_info): default_timeout=60.0, client_info=client_info, ), - self.create_document: gapic_v1.method_async.wrap_method( + self.create_document: self._wrap_method( self.create_document, default_retry=retries.AsyncRetry( initial=0.1, @@ -1005,10 +1099,39 @@ def _prep_wrapped_messages(self, client_info): default_timeout=60.0, client_info=client_info, ), + self.cancel_operation: self._wrap_method( + self.cancel_operation, + default_timeout=None, + client_info=client_info, + ), + self.delete_operation: self._wrap_method( + self.delete_operation, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: self._wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: self._wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), } + def _wrap_method(self, func, *args, **kwargs): + if 
self._wrap_with_kind: # pragma: NO COVER + kwargs["kind"] = self.kind + return gapic_v1.method_async.wrap_method(func, *args, **kwargs) + def close(self): - return self.grpc_channel.close() + return self._logged_channel.close() + + @property + def kind(self) -> str: + return "grpc_asyncio" @property def delete_operation( @@ -1020,7 +1143,7 @@ def delete_operation( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "delete_operation" not in self._stubs: - self._stubs["delete_operation"] = self.grpc_channel.unary_unary( + self._stubs["delete_operation"] = self._logged_channel.unary_unary( "/google.longrunning.Operations/DeleteOperation", request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString, response_deserializer=None, @@ -1037,7 +1160,7 @@ def cancel_operation( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "cancel_operation" not in self._stubs: - self._stubs["cancel_operation"] = self.grpc_channel.unary_unary( + self._stubs["cancel_operation"] = self._logged_channel.unary_unary( "/google.longrunning.Operations/CancelOperation", request_serializer=operations_pb2.CancelOperationRequest.SerializeToString, response_deserializer=None, @@ -1054,7 +1177,7 @@ def get_operation( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "get_operation" not in self._stubs: - self._stubs["get_operation"] = self.grpc_channel.unary_unary( + self._stubs["get_operation"] = self._logged_channel.unary_unary( "/google.longrunning.Operations/GetOperation", request_serializer=operations_pb2.GetOperationRequest.SerializeToString, response_deserializer=operations_pb2.Operation.FromString, @@ -1073,7 +1196,7 @@ def list_operations( # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. if "list_operations" not in self._stubs: - self._stubs["list_operations"] = self.grpc_channel.unary_unary( + self._stubs["list_operations"] = self._logged_channel.unary_unary( "/google.longrunning.Operations/ListOperations", request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, response_deserializer=operations_pb2.ListOperationsResponse.FromString, diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest.py b/google/cloud/firestore_v1/services/firestore/transports/rest.py index 6a4ae16b4..4bd282fe6 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,32 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
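The `_wrap_method` helper above probes `gapic_v1.method_async.wrap_method` for a `kind` parameter and forwards it only when present, which keeps the async transport compatible with older google-api-core releases. A generic sketch of that signature-probing idiom (the helper and API names here are illustrative):

    import inspect
    from typing import Any, Callable

    def call_with_optional_kwarg(
        func: Callable[..., Any], name: str, value: Any, *args: Any, **kwargs: Any
    ) -> Any:
        # Forward the keyword only when the callee actually declares it.
        if name in inspect.signature(func).parameters:
            kwargs[name] = value
        return func(*args, **kwargs)

    def old_api(x):
        return x

    def new_api(x, kind="grpc_asyncio"):
        return (x, kind)

    assert call_with_optional_kwarg(old_api, "kind", "grpc_asyncio", 1) == 1
    assert call_with_optional_kwarg(new_api, "kind", "grpc_asyncio", 1) == (1, "grpc_asyncio")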
# +import logging +import json # type: ignore from google.auth.transport.requests import AuthorizedSession # type: ignore -import json # type: ignore -import grpc # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore from google.auth import credentials as ga_credentials # type: ignore from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries from google.api_core import rest_helpers from google.api_core import rest_streaming -from google.api_core import path_template from google.api_core import gapic_v1 from google.protobuf import json_format from google.cloud.location import locations_pb2 # type: ignore + from requests import __version__ as requests_version import dataclasses -import re from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import warnings -try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] -except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object, None] # type: ignore - from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document @@ -46,13 +39,28 @@ from google.protobuf import empty_pb2 # type: ignore from google.longrunning import operations_pb2 # type: ignore -from .base import FirestoreTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .rest_base import _BaseFirestoreRestTransport +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = logging.getLogger(__name__) DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, grpc_version=None, - rest_version=requests_version, + rest_version=f"requests@{requests_version}", ) @@ -192,8 +200,10 @@ def post_update_document(self, response): def pre_batch_get_documents( self, request: firestore.BatchGetDocumentsRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[firestore.BatchGetDocumentsRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.BatchGetDocumentsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for batch_get_documents Override in a subclass to manipulate the request or metadata @@ -206,15 +216,42 @@ def post_batch_get_documents( ) -> rest_streaming.ResponseIterator: """Post-rpc interceptor for batch_get_documents - Override in a subclass to manipulate the response + DEPRECATED. Please use the `post_batch_get_documents_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_batch_get_documents` interceptor runs + before the `post_batch_get_documents_with_metadata` interceptor. 
""" return response + def post_batch_get_documents_with_metadata( + self, + response: rest_streaming.ResponseIterator, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + rest_streaming.ResponseIterator, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for batch_get_documents + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_batch_get_documents_with_metadata` + interceptor in new development instead of the `post_batch_get_documents` interceptor. + When both interceptors are used, this `post_batch_get_documents_with_metadata` interceptor runs after the + `post_batch_get_documents` interceptor. The (possibly modified) response returned by + `post_batch_get_documents` will be passed to + `post_batch_get_documents_with_metadata`. + """ + return response, metadata + def pre_batch_write( - self, request: firestore.BatchWriteRequest, metadata: Sequence[Tuple[str, str]] - ) -> Tuple[firestore.BatchWriteRequest, Sequence[Tuple[str, str]]]: + self, + request: firestore.BatchWriteRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[firestore.BatchWriteRequest, Sequence[Tuple[str, Union[str, bytes]]]]: """Pre-rpc interceptor for batch_write Override in a subclass to manipulate the request or metadata @@ -227,17 +264,42 @@ def post_batch_write( ) -> firestore.BatchWriteResponse: """Post-rpc interceptor for batch_write - Override in a subclass to manipulate the response + DEPRECATED. Please use the `post_batch_write_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_batch_write` interceptor runs + before the `post_batch_write_with_metadata` interceptor. """ return response + def post_batch_write_with_metadata( + self, + response: firestore.BatchWriteResponse, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[firestore.BatchWriteResponse, Sequence[Tuple[str, Union[str, bytes]]]]: + """Post-rpc interceptor for batch_write + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_batch_write_with_metadata` + interceptor in new development instead of the `post_batch_write` interceptor. + When both interceptors are used, this `post_batch_write_with_metadata` interceptor runs after the + `post_batch_write` interceptor. The (possibly modified) response returned by + `post_batch_write` will be passed to + `post_batch_write_with_metadata`. + """ + return response, metadata + def pre_begin_transaction( self, request: firestore.BeginTransactionRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[firestore.BeginTransactionRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.BeginTransactionRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for begin_transaction Override in a subclass to manipulate the request or metadata @@ -250,15 +312,42 @@ def post_begin_transaction( ) -> firestore.BeginTransactionResponse: """Post-rpc interceptor for begin_transaction - Override in a subclass to manipulate the response + DEPRECATED. 
Please use the `post_begin_transaction_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_begin_transaction` interceptor runs + before the `post_begin_transaction_with_metadata` interceptor. """ return response + def post_begin_transaction_with_metadata( + self, + response: firestore.BeginTransactionResponse, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.BeginTransactionResponse, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for begin_transaction + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_begin_transaction_with_metadata` + interceptor in new development instead of the `post_begin_transaction` interceptor. + When both interceptors are used, this `post_begin_transaction_with_metadata` interceptor runs after the + `post_begin_transaction` interceptor. The (possibly modified) response returned by + `post_begin_transaction` will be passed to + `post_begin_transaction_with_metadata`. + """ + return response, metadata + def pre_commit( - self, request: firestore.CommitRequest, metadata: Sequence[Tuple[str, str]] - ) -> Tuple[firestore.CommitRequest, Sequence[Tuple[str, str]]]: + self, + request: firestore.CommitRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[firestore.CommitRequest, Sequence[Tuple[str, Union[str, bytes]]]]: """Pre-rpc interceptor for commit Override in a subclass to manipulate the request or metadata @@ -271,17 +360,42 @@ def post_commit( ) -> firestore.CommitResponse: """Post-rpc interceptor for commit - Override in a subclass to manipulate the response + DEPRECATED. Please use the `post_commit_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_commit` interceptor runs + before the `post_commit_with_metadata` interceptor. """ return response + def post_commit_with_metadata( + self, + response: firestore.CommitResponse, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[firestore.CommitResponse, Sequence[Tuple[str, Union[str, bytes]]]]: + """Post-rpc interceptor for commit + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_commit_with_metadata` + interceptor in new development instead of the `post_commit` interceptor. + When both interceptors are used, this `post_commit_with_metadata` interceptor runs after the + `post_commit` interceptor. The (possibly modified) response returned by + `post_commit` will be passed to + `post_commit_with_metadata`. 
+ """ + return response, metadata + def pre_create_document( self, request: firestore.CreateDocumentRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[firestore.CreateDocumentRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.CreateDocumentRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for create_document Override in a subclass to manipulate the request or metadata @@ -292,17 +406,42 @@ def pre_create_document( def post_create_document(self, response: document.Document) -> document.Document: """Post-rpc interceptor for create_document - Override in a subclass to manipulate the response + DEPRECATED. Please use the `post_create_document_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_create_document` interceptor runs + before the `post_create_document_with_metadata` interceptor. """ return response + def post_create_document_with_metadata( + self, + response: document.Document, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[document.Document, Sequence[Tuple[str, Union[str, bytes]]]]: + """Post-rpc interceptor for create_document + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_create_document_with_metadata` + interceptor in new development instead of the `post_create_document` interceptor. + When both interceptors are used, this `post_create_document_with_metadata` interceptor runs after the + `post_create_document` interceptor. The (possibly modified) response returned by + `post_create_document` will be passed to + `post_create_document_with_metadata`. + """ + return response, metadata + def pre_delete_document( self, request: firestore.DeleteDocumentRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[firestore.DeleteDocumentRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.DeleteDocumentRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for delete_document Override in a subclass to manipulate the request or metadata @@ -313,8 +452,10 @@ def pre_delete_document( def pre_execute_pipeline( self, request: firestore.ExecutePipelineRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[firestore.ExecutePipelineRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.ExecutePipelineRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for execute_pipeline Override in a subclass to manipulate the request or metadata @@ -327,15 +468,42 @@ def post_execute_pipeline( ) -> rest_streaming.ResponseIterator: """Post-rpc interceptor for execute_pipeline - Override in a subclass to manipulate the response + DEPRECATED. Please use the `post_execute_pipeline_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_execute_pipeline` interceptor runs + before the `post_execute_pipeline_with_metadata` interceptor. 
""" return response + def post_execute_pipeline_with_metadata( + self, + response: rest_streaming.ResponseIterator, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + rest_streaming.ResponseIterator, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for execute_pipeline + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_execute_pipeline_with_metadata` + interceptor in new development instead of the `post_execute_pipeline` interceptor. + When both interceptors are used, this `post_execute_pipeline_with_metadata` interceptor runs after the + `post_execute_pipeline` interceptor. The (possibly modified) response returned by + `post_execute_pipeline` will be passed to + `post_execute_pipeline_with_metadata`. + """ + return response, metadata + def pre_get_document( - self, request: firestore.GetDocumentRequest, metadata: Sequence[Tuple[str, str]] - ) -> Tuple[firestore.GetDocumentRequest, Sequence[Tuple[str, str]]]: + self, + request: firestore.GetDocumentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[firestore.GetDocumentRequest, Sequence[Tuple[str, Union[str, bytes]]]]: """Pre-rpc interceptor for get_document Override in a subclass to manipulate the request or metadata @@ -346,17 +514,42 @@ def pre_get_document( def post_get_document(self, response: document.Document) -> document.Document: """Post-rpc interceptor for get_document - Override in a subclass to manipulate the response + DEPRECATED. Please use the `post_get_document_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_get_document` interceptor runs + before the `post_get_document_with_metadata` interceptor. """ return response + def post_get_document_with_metadata( + self, + response: document.Document, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[document.Document, Sequence[Tuple[str, Union[str, bytes]]]]: + """Post-rpc interceptor for get_document + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_get_document_with_metadata` + interceptor in new development instead of the `post_get_document` interceptor. + When both interceptors are used, this `post_get_document_with_metadata` interceptor runs after the + `post_get_document` interceptor. The (possibly modified) response returned by + `post_get_document` will be passed to + `post_get_document_with_metadata`. + """ + return response, metadata + def pre_list_collection_ids( self, request: firestore.ListCollectionIdsRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[firestore.ListCollectionIdsRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.ListCollectionIdsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for list_collection_ids Override in a subclass to manipulate the request or metadata @@ -369,17 +562,42 @@ def post_list_collection_ids( ) -> firestore.ListCollectionIdsResponse: """Post-rpc interceptor for list_collection_ids - Override in a subclass to manipulate the response + DEPRECATED. 
Please use the `post_list_collection_ids_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_list_collection_ids` interceptor runs + before the `post_list_collection_ids_with_metadata` interceptor. """ return response + def post_list_collection_ids_with_metadata( + self, + response: firestore.ListCollectionIdsResponse, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.ListCollectionIdsResponse, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for list_collection_ids + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_list_collection_ids_with_metadata` + interceptor in new development instead of the `post_list_collection_ids` interceptor. + When both interceptors are used, this `post_list_collection_ids_with_metadata` interceptor runs after the + `post_list_collection_ids` interceptor. The (possibly modified) response returned by + `post_list_collection_ids` will be passed to + `post_list_collection_ids_with_metadata`. + """ + return response, metadata + def pre_list_documents( self, request: firestore.ListDocumentsRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[firestore.ListDocumentsRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[firestore.ListDocumentsRequest, Sequence[Tuple[str, Union[str, bytes]]]]: """Pre-rpc interceptor for list_documents Override in a subclass to manipulate the request or metadata @@ -392,17 +610,44 @@ def post_list_documents( ) -> firestore.ListDocumentsResponse: """Post-rpc interceptor for list_documents - Override in a subclass to manipulate the response + DEPRECATED. Please use the `post_list_documents_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_list_documents` interceptor runs + before the `post_list_documents_with_metadata` interceptor. """ return response + def post_list_documents_with_metadata( + self, + response: firestore.ListDocumentsResponse, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.ListDocumentsResponse, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for list_documents + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_list_documents_with_metadata` + interceptor in new development instead of the `post_list_documents` interceptor. + When both interceptors are used, this `post_list_documents_with_metadata` interceptor runs after the + `post_list_documents` interceptor. The (possibly modified) response returned by + `post_list_documents` will be passed to + `post_list_documents_with_metadata`. 
+ """ + return response, metadata + def pre_partition_query( self, request: firestore.PartitionQueryRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[firestore.PartitionQueryRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.PartitionQueryRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for partition_query Override in a subclass to manipulate the request or metadata @@ -415,15 +660,42 @@ def post_partition_query( ) -> firestore.PartitionQueryResponse: """Post-rpc interceptor for partition_query - Override in a subclass to manipulate the response + DEPRECATED. Please use the `post_partition_query_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_partition_query` interceptor runs + before the `post_partition_query_with_metadata` interceptor. """ return response + def post_partition_query_with_metadata( + self, + response: firestore.PartitionQueryResponse, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.PartitionQueryResponse, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for partition_query + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_partition_query_with_metadata` + interceptor in new development instead of the `post_partition_query` interceptor. + When both interceptors are used, this `post_partition_query_with_metadata` interceptor runs after the + `post_partition_query` interceptor. The (possibly modified) response returned by + `post_partition_query` will be passed to + `post_partition_query_with_metadata`. + """ + return response, metadata + def pre_rollback( - self, request: firestore.RollbackRequest, metadata: Sequence[Tuple[str, str]] - ) -> Tuple[firestore.RollbackRequest, Sequence[Tuple[str, str]]]: + self, + request: firestore.RollbackRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[firestore.RollbackRequest, Sequence[Tuple[str, Union[str, bytes]]]]: """Pre-rpc interceptor for rollback Override in a subclass to manipulate the request or metadata @@ -434,8 +706,10 @@ def pre_rollback( def pre_run_aggregation_query( self, request: firestore.RunAggregationQueryRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[firestore.RunAggregationQueryRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.RunAggregationQueryRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for run_aggregation_query Override in a subclass to manipulate the request or metadata @@ -448,15 +722,42 @@ def post_run_aggregation_query( ) -> rest_streaming.ResponseIterator: """Post-rpc interceptor for run_aggregation_query - Override in a subclass to manipulate the response + DEPRECATED. Please use the `post_run_aggregation_query_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_run_aggregation_query` interceptor runs + before the `post_run_aggregation_query_with_metadata` interceptor. 
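# A hedged sketch of the ordering contract spelled out above: when both hooks
# are overridden, the deprecated hook runs first and its return value is what
# the *_with_metadata hook receives. The class and flag below are illustrative
# only.

class OrderingDemoInterceptor(FirestoreRestInterceptor):
    saw_legacy_hook = False

    def post_run_aggregation_query(self, response):
        # Runs first; whatever is returned here ...
        self.saw_legacy_hook = True
        return response

    def post_run_aggregation_query_with_metadata(self, response, metadata):
        # ... arrives here as `response`, together with the response headers.
        assert self.saw_legacy_hook
        return response, metadata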
""" return response + def post_run_aggregation_query_with_metadata( + self, + response: rest_streaming.ResponseIterator, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + rest_streaming.ResponseIterator, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for run_aggregation_query + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_run_aggregation_query_with_metadata` + interceptor in new development instead of the `post_run_aggregation_query` interceptor. + When both interceptors are used, this `post_run_aggregation_query_with_metadata` interceptor runs after the + `post_run_aggregation_query` interceptor. The (possibly modified) response returned by + `post_run_aggregation_query` will be passed to + `post_run_aggregation_query_with_metadata`. + """ + return response, metadata + def pre_run_query( - self, request: firestore.RunQueryRequest, metadata: Sequence[Tuple[str, str]] - ) -> Tuple[firestore.RunQueryRequest, Sequence[Tuple[str, str]]]: + self, + request: firestore.RunQueryRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[firestore.RunQueryRequest, Sequence[Tuple[str, Union[str, bytes]]]]: """Pre-rpc interceptor for run_query Override in a subclass to manipulate the request or metadata @@ -469,17 +770,44 @@ def post_run_query( ) -> rest_streaming.ResponseIterator: """Post-rpc interceptor for run_query - Override in a subclass to manipulate the response + DEPRECATED. Please use the `post_run_query_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_run_query` interceptor runs + before the `post_run_query_with_metadata` interceptor. """ return response + def post_run_query_with_metadata( + self, + response: rest_streaming.ResponseIterator, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + rest_streaming.ResponseIterator, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for run_query + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_run_query_with_metadata` + interceptor in new development instead of the `post_run_query` interceptor. + When both interceptors are used, this `post_run_query_with_metadata` interceptor runs after the + `post_run_query` interceptor. The (possibly modified) response returned by + `post_run_query` will be passed to + `post_run_query_with_metadata`. + """ + return response, metadata + def pre_update_document( self, request: firestore.UpdateDocumentRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[firestore.UpdateDocumentRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.UpdateDocumentRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for update_document Override in a subclass to manipulate the request or metadata @@ -492,17 +820,42 @@ def post_update_document( ) -> gf_document.Document: """Post-rpc interceptor for update_document - Override in a subclass to manipulate the response + DEPRECATED. Please use the `post_update_document_with_metadata` + interceptor instead. 
+ + Override in a subclass to read or manipulate the response after it is returned by the Firestore server but before - it is returned to user code. + it is returned to user code. This `post_update_document` interceptor runs + before the `post_update_document_with_metadata` interceptor. """ return response + def post_update_document_with_metadata( + self, + response: gf_document.Document, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[gf_document.Document, Sequence[Tuple[str, Union[str, bytes]]]]: + """Post-rpc interceptor for update_document + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_update_document_with_metadata` + interceptor in new development instead of the `post_update_document` interceptor. + When both interceptors are used, this `post_update_document_with_metadata` interceptor runs after the + `post_update_document` interceptor. The (possibly modified) response returned by + `post_update_document` will be passed to + `post_update_document_with_metadata`. + """ + return response, metadata + def pre_cancel_operation( self, request: operations_pb2.CancelOperationRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[operations_pb2.CancelOperationRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.CancelOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for cancel_operation Override in a subclass to manipulate the request or metadata @@ -522,8 +875,10 @@ def post_cancel_operation(self, response: None) -> None: def pre_delete_operation( self, request: operations_pb2.DeleteOperationRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[operations_pb2.DeleteOperationRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.DeleteOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for delete_operation Override in a subclass to manipulate the request or metadata @@ -543,8 +898,10 @@ def post_delete_operation(self, response: None) -> None: def pre_get_operation( self, request: operations_pb2.GetOperationRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[operations_pb2.GetOperationRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.GetOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for get_operation Override in a subclass to manipulate the request or metadata @@ -566,8 +923,10 @@ def post_get_operation( def pre_list_operations( self, request: operations_pb2.ListOperationsRequest, - metadata: Sequence[Tuple[str, str]], - ) -> Tuple[operations_pb2.ListOperationsRequest, Sequence[Tuple[str, str]]]: + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.ListOperationsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: """Pre-rpc interceptor for list_operations Override in a subclass to manipulate the request or metadata @@ -594,8 +953,8 @@ class FirestoreRestStub: _interceptor: FirestoreRestInterceptor -class FirestoreRestTransport(FirestoreTransport): - """REST backend transport for Firestore. +class FirestoreRestTransport(_BaseFirestoreRestTransport): + """REST backend synchronous transport for Firestore. The Cloud Firestore service. 
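# A hedged construction sketch for the synchronous REST transport described
# above. The import paths assume the usual generated layout under
# google/cloud/firestore_v1/services/firestore/; the emulator host and "http"
# scheme follow the constructor docstring below rather than anything this
# patch demonstrates.

from google.auth.credentials import AnonymousCredentials
from google.cloud.firestore_v1.services.firestore import FirestoreClient
from google.cloud.firestore_v1.services.firestore.transports.rest import (
    FirestoreRestTransport,
)

transport = FirestoreRestTransport(
    host="localhost:8080",            # e.g. a local Firestore emulator
    url_scheme="http",                # production traffic defaults to "https"
    credentials=AnonymousCredentials(),
    interceptor=HeaderCapturingInterceptor(),  # sketched earlier
)
client = FirestoreClient(transport=transport)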
@@ -612,7 +971,6 @@ class FirestoreRestTransport(FirestoreTransport): and call it. It sends JSON representations of protocol buffers over HTTP/1.1 - """ def __init__( @@ -632,55 +990,50 @@ def __init__( ) -> None: """Instantiate the transport. - Args: - host (Optional[str]): - The hostname to connect to (default: 'firestore.googleapis.com'). - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client - certificate to configure mutual TLS HTTP channel. It is ignored - if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you are developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - url_scheme: the protocol scheme for the API endpoint. Normally - "https", but for testing or local servers, - "http" can be specified. + NOTE: This REST transport functionality is currently in a beta + state (preview). We welcome your feedback via a GitHub issue in + this library's repository. Thank you! + + Args: + host (Optional[str]): + The hostname to connect to (default: 'firestore.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. """ # Run the base constructor # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. 
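# Host/scheme normalization moves into the base transport in this patch (the
# regex below is deleted). Assuming the base class keeps the old logic, the
# observable behavior is: a bare host gains the url_scheme prefix, while a
# host that already carries a scheme is used verbatim. For example:
#
#   host="firestore.googleapis.com", url_scheme="https"
#       -> "https://firestore.googleapis.com"
#   host="http://localhost:8080"
#       -> "http://localhost:8080"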
# TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the # credentials object - maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) - if maybe_url_match is None: - raise ValueError( - f"Unexpected hostname structure: {host}" - ) # pragma: NO COVER - - url_match_items = maybe_url_match.groupdict() - - host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host - super().__init__( host=host, credentials=credentials, client_info=client_info, always_use_jwt_access=always_use_jwt_access, + url_scheme=url_scheme, api_audience=api_audience, ) self._session = AuthorizedSession( @@ -691,19 +1044,35 @@ def __init__( self._interceptor = interceptor or FirestoreRestInterceptor() self._prep_wrapped_messages(client_info) - class _BatchGetDocuments(FirestoreRestStub): + class _BatchGetDocuments( + _BaseFirestoreRestTransport._BaseBatchGetDocuments, FirestoreRestStub + ): def __hash__(self): - return hash("BatchGetDocuments") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.BatchGetDocuments") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + stream=True, + ) + return response def __call__( self, @@ -711,7 +1080,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> rest_streaming.ResponseIterator: r"""Call the batch get documents method over HTTP. @@ -722,8 +1091,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
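# A short illustration of the metadata rule stated in the docstring above:
# plain keys carry str values, while keys ending in "-bin" carry bytes. Both
# key names here are invented for the example.

metadata = [
    ("x-goog-request-params", "database=projects/p/databases/d"),  # str value
    ("x-debug-trace-bin", b"\x0a\x02id"),                          # bytes value
]
# The sequence can then be supplied as the `metadata` argument of the call.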
Returns: ~.firestore.BatchGetDocumentsResponse: @@ -732,47 +1103,62 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{database=projects/*/databases/*}/documents:batchGet", - "body": "*", - }, - ] + http_options = ( + _BaseFirestoreRestTransport._BaseBatchGetDocuments._get_http_options() + ) + request, metadata = self._interceptor.pre_batch_get_documents( request, metadata ) - pb_request = firestore.BatchGetDocumentsRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - # Jsonify the request body + transcoded_request = _BaseFirestoreRestTransport._BaseBatchGetDocuments._get_transcoded_request( + http_options, request + ) - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + body = _BaseFirestoreRestTransport._BaseBatchGetDocuments._get_request_body_json( + transcoded_request ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, - ) + query_params = _BaseFirestoreRestTransport._BaseBatchGetDocuments._get_query_params_json( + transcoded_request ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.BatchGetDocuments", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "BatchGetDocuments", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._BatchGetDocuments._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -784,22 +1170,40 @@ def __call__( resp = rest_streaming.ResponseIterator( response, firestore.BatchGetDocumentsResponse ) + resp = self._interceptor.post_batch_get_documents(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_batch_get_documents_with_metadata( + resp, response_metadata + ) return resp - class _BatchWrite(FirestoreRestStub): + class _BatchWrite(_BaseFirestoreRestTransport._BaseBatchWrite, FirestoreRestStub): def __hash__(self): - return hash("BatchWrite") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.BatchWrite") + + @staticmethod + def _get_response( + 
host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response def __call__( self, @@ -807,7 +1211,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.BatchWriteResponse: r"""Call the batch write method over HTTP. @@ -818,8 +1222,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: ~.firestore.BatchWriteResponse: @@ -828,45 +1234,64 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{database=projects/*/databases/*}/documents:batchWrite", - "body": "*", - }, - ] - request, metadata = self._interceptor.pre_batch_write(request, metadata) - pb_request = firestore.BatchWriteRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) + http_options = ( + _BaseFirestoreRestTransport._BaseBatchWrite._get_http_options() + ) - # Jsonify the request body + request, metadata = self._interceptor.pre_batch_write(request, metadata) + transcoded_request = ( + _BaseFirestoreRestTransport._BaseBatchWrite._get_transcoded_request( + http_options, request + ) + ) - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + body = _BaseFirestoreRestTransport._BaseBatchWrite._get_request_body_json( + transcoded_request ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, + query_params = ( + _BaseFirestoreRestTransport._BaseBatchWrite._get_query_params_json( + transcoded_request ) ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.BatchWrite", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "BatchWrite", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - 
headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._BatchWrite._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -879,22 +1304,63 @@ def __call__( pb_resp = firestore.BatchWriteResponse.pb(resp) json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_batch_write(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_batch_write_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = firestore.BatchWriteResponse.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.batch_write", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "BatchWrite", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp - class _BeginTransaction(FirestoreRestStub): + class _BeginTransaction( + _BaseFirestoreRestTransport._BaseBeginTransaction, FirestoreRestStub + ): def __hash__(self): - return hash("BeginTransaction") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.BeginTransaction") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response def __call__( self, @@ -902,7 +1368,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.BeginTransactionResponse: r"""Call the begin transaction method over HTTP. @@ -913,8 +1379,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
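# A hedged end-to-end sketch of the RPC this stub backs, driven through the
# public client (see the construction sketch earlier). Field names (database,
# transaction, writes) follow the public Firestore v1 protos, the retry and
# timeout kwargs mirror the __call__ signature above, and all values are
# placeholders.

from google.api_core import retry as retries
from google.cloud.firestore_v1.types import firestore

db = "projects/my-project/databases/(default)"
txn = client.begin_transaction(
    request=firestore.BeginTransactionRequest(database=db),
    timeout=30.0,
    retry=retries.Retry(initial=0.1, maximum=5.0),
)
# The opaque transaction id is then attached to the eventual commit.
client.commit(
    request=firestore.CommitRequest(
        database=db, writes=[], transaction=txn.transaction
    )
)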
Returns: ~.firestore.BeginTransactionResponse: @@ -923,47 +1391,62 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{database=projects/*/databases/*}/documents:beginTransaction", - "body": "*", - }, - ] + http_options = ( + _BaseFirestoreRestTransport._BaseBeginTransaction._get_http_options() + ) + request, metadata = self._interceptor.pre_begin_transaction( request, metadata ) - pb_request = firestore.BeginTransactionRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - # Jsonify the request body + transcoded_request = _BaseFirestoreRestTransport._BaseBeginTransaction._get_transcoded_request( + http_options, request + ) - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + body = _BaseFirestoreRestTransport._BaseBeginTransaction._get_request_body_json( + transcoded_request ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, - ) + query_params = _BaseFirestoreRestTransport._BaseBeginTransaction._get_query_params_json( + transcoded_request ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.BeginTransaction", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "BeginTransaction", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._BeginTransaction._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -976,30 +1459,71 @@ def __call__( pb_resp = firestore.BeginTransactionResponse.pb(resp) json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_begin_transaction(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_begin_transaction_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = firestore.BeginTransactionResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.begin_transaction", + extra={ + "serviceName": 
"google.firestore.v1.Firestore", + "rpcName": "BeginTransaction", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp - class _Commit(FirestoreRestStub): + class _Commit(_BaseFirestoreRestTransport._BaseCommit, FirestoreRestStub): def __hash__(self): - return hash("Commit") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } - + return hash("FirestoreRestTransport.Commit") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + def __call__( self, request: firestore.CommitRequest, *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.CommitResponse: r"""Call the commit method over HTTP. @@ -1010,8 +1534,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
Returns: ~.firestore.CommitResponse: @@ -1020,45 +1546,62 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{database=projects/*/databases/*}/documents:commit", - "body": "*", - }, - ] - request, metadata = self._interceptor.pre_commit(request, metadata) - pb_request = firestore.CommitRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) + http_options = _BaseFirestoreRestTransport._BaseCommit._get_http_options() - # Jsonify the request body + request, metadata = self._interceptor.pre_commit(request, metadata) + transcoded_request = ( + _BaseFirestoreRestTransport._BaseCommit._get_transcoded_request( + http_options, request + ) + ) - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + body = _BaseFirestoreRestTransport._BaseCommit._get_request_body_json( + transcoded_request ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, + query_params = ( + _BaseFirestoreRestTransport._BaseCommit._get_query_params_json( + transcoded_request ) ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.Commit", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "Commit", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._Commit._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -1071,22 +1614,63 @@ def __call__( pb_resp = firestore.CommitResponse.pb(resp) json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_commit(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_commit_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = firestore.CommitResponse.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.commit", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "Commit", + "metadata": http_response["headers"], + 
"httpResponse": http_response, + }, + ) return resp - class _CreateDocument(FirestoreRestStub): + class _CreateDocument( + _BaseFirestoreRestTransport._BaseCreateDocument, FirestoreRestStub + ): def __hash__(self): - return hash("CreateDocument") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.CreateDocument") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response def __call__( self, @@ -1094,7 +1678,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> document.Document: r"""Call the create document method over HTTP. @@ -1105,8 +1689,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
Returns: ~.document.Document: @@ -1116,47 +1702,66 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{parent=projects/*/databases/*/documents/**}/{collection_id}", - "body": "document", - }, - ] - request, metadata = self._interceptor.pre_create_document(request, metadata) - pb_request = firestore.CreateDocumentRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) + http_options = ( + _BaseFirestoreRestTransport._BaseCreateDocument._get_http_options() + ) - # Jsonify the request body + request, metadata = self._interceptor.pre_create_document(request, metadata) + transcoded_request = ( + _BaseFirestoreRestTransport._BaseCreateDocument._get_transcoded_request( + http_options, request + ) + ) - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + body = ( + _BaseFirestoreRestTransport._BaseCreateDocument._get_request_body_json( + transcoded_request + ) ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, + query_params = ( + _BaseFirestoreRestTransport._BaseCreateDocument._get_query_params_json( + transcoded_request ) ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.CreateDocument", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "CreateDocument", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) - class _ExecutePipeline(FirestoreRestStub): - def __hash__(self): # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._CreateDocument._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -1169,22 +1774,62 @@ def __hash__(self): pb_resp = document.Document.pb(resp) json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_create_document(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_create_document_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = document.Document.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response 
for google.firestore_v1.FirestoreClient.create_document", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "CreateDocument", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp - class _DeleteDocument(FirestoreRestStub): + class _DeleteDocument( + _BaseFirestoreRestTransport._BaseDeleteDocument, FirestoreRestStub + ): def __hash__(self): - return hash("DeleteDocument") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.DeleteDocument") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response def __call__( self, @@ -1192,7 +1837,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ): r"""Call the delete document method over HTTP. @@ -1203,47 +1848,101 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
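# delete_document sends no request body and returns no payload, so the
# pre-RPC hook is the natural observation point. A hypothetical audit
# interceptor, matching the pre_delete_document call made below:

class AuditingInterceptor(FirestoreRestInterceptor):
    def pre_delete_document(self, request, metadata):
        # DeleteDocumentRequest.name is the full document resource path.
        print(f"deleting {request.name}")
        return request, metadata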
""" - http_options: List[Dict[str, str]] = [ - { - "method": "delete", - "uri": "/v1/{name=projects/*/databases/*/documents/*/**}", - }, - ] - request, metadata = self._interceptor.pre_delete_document(request, metadata) - pb_request = firestore.DeleteDocumentRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) + http_options = ( + _BaseFirestoreRestTransport._BaseDeleteDocument._get_http_options() + ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] + request, metadata = self._interceptor.pre_delete_document(request, metadata) + transcoded_request = ( + _BaseFirestoreRestTransport._BaseDeleteDocument._get_transcoded_request( + http_options, request + ) + ) # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, + query_params = ( + _BaseFirestoreRestTransport._BaseDeleteDocument._get_query_params_json( + transcoded_request ) ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.DeleteDocument", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "DeleteDocument", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) - class _ExecutePipeline(FirestoreRestStub): - def __hash__(self): - return hash("ExecutePipeline") + # Send the request + response = FirestoreRestTransport._DeleteDocument._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + class _ExecutePipeline( + _BaseFirestoreRestTransport._BaseExecutePipeline, FirestoreRestStub + ): + def __hash__(self): + return hash("FirestoreRestTransport.ExecutePipeline") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + stream=True, + ) + return response def __call__( self, @@ -1251,7 +1950,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> rest_streaming.ResponseIterator: r"""Call the execute pipeline method over HTTP. @@ -1262,56 +1961,76 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: ~.firestore.ExecutePipelineResponse: The response for [Firestore.Execute][]. 
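# execute_pipeline is a server-streaming RPC; over REST it is surfaced as a
# rest_streaming.ResponseIterator of ExecutePipelineResponse messages. A
# hedged consumption sketch — the request construction itself is defined by
# the pipeline protos elsewhere in this series:

execute_pipeline_request = firestore.ExecutePipelineRequest()  # placeholder
stream = client.execute_pipeline(request=execute_pipeline_request)
for response in stream:
    # Each item is one firestore.ExecutePipelineResponse message.
    print(response)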
""" - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline", - "body": "*", - }, - ] + http_options = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_http_options() + ) + request, metadata = self._interceptor.pre_execute_pipeline( request, metadata ) - pb_request = firestore.ExecutePipelineRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - # Jsonify the request body + transcoded_request = _BaseFirestoreRestTransport._BaseExecutePipeline._get_transcoded_request( + http_options, request + ) - body = json_format.MessageToJson( - transcoded_request["body"], - including_default_value_fields=False, - use_integers_for_enums=False, + body = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_request_body_json( + transcoded_request + ) ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - including_default_value_fields=False, - use_integers_for_enums=False, + query_params = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_query_params_json( + transcoded_request ) ) - query_params.update(self._get_unset_required_fields(query_params)) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.ExecutePipeline", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ExecutePipeline", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._ExecutePipeline._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -1323,22 +2042,39 @@ def __call__( resp = rest_streaming.ResponseIterator( response, firestore.ExecutePipelineResponse ) + resp = self._interceptor.post_execute_pipeline(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_execute_pipeline_with_metadata( + resp, response_metadata + ) return resp - class _GetDocument(FirestoreRestStub): + class _GetDocument(_BaseFirestoreRestTransport._BaseGetDocument, FirestoreRestStub): def __hash__(self): - return hash("GetDocument") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.GetDocument") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + 
timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response def __call__( self, @@ -1346,7 +2082,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> document.Document: r"""Call the get document method over HTTP. @@ -1357,8 +2093,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: ~.document.Document: @@ -1368,38 +2106,59 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "get", - "uri": "/v1/{name=projects/*/databases/*/documents/*/**}", - }, - ] - request, metadata = self._interceptor.pre_get_document(request, metadata) - pb_request = firestore.GetDocumentRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) + http_options = ( + _BaseFirestoreRestTransport._BaseGetDocument._get_http_options() + ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] + request, metadata = self._interceptor.pre_get_document(request, metadata) + transcoded_request = ( + _BaseFirestoreRestTransport._BaseGetDocument._get_transcoded_request( + http_options, request + ) + ) # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, + query_params = ( + _BaseFirestoreRestTransport._BaseGetDocument._get_query_params_json( + transcoded_request ) ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.GetDocument", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "GetDocument", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), + response = FirestoreRestTransport._GetDocument._get_response( + self._host, + metadata, 
+ query_params, + self._session, + timeout, + transcoded_request, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -1412,22 +2171,63 @@ def __call__( pb_resp = document.Document.pb(resp) json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_get_document(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_get_document_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = document.Document.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.get_document", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "GetDocument", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp - class _ListCollectionIds(FirestoreRestStub): + class _ListCollectionIds( + _BaseFirestoreRestTransport._BaseListCollectionIds, FirestoreRestStub + ): def __hash__(self): - return hash("ListCollectionIds") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.ListCollectionIds") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response def __call__( self, @@ -1435,7 +2235,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.ListCollectionIdsResponse: r"""Call the list collection ids method over HTTP. @@ -1446,8 +2246,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
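# list_collection_ids is a paged RPC. At the client layer the generated pager
# is assumed to follow next_page_token transparently, so user code can simply
# iterate; the parent path below is a placeholder.

from google.cloud.firestore_v1.types import firestore

request = firestore.ListCollectionIdsRequest(
    parent="projects/my-project/databases/(default)/documents"
)
for collection_id in client.list_collection_ids(request=request):
    print(collection_id)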
Returns: ~.firestore.ListCollectionIdsResponse: @@ -1456,52 +2258,62 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{parent=projects/*/databases/*/documents}:listCollectionIds", - "body": "*", - }, - { - "method": "post", - "uri": "/v1/{parent=projects/*/databases/*/documents/*/**}:listCollectionIds", - "body": "*", - }, - ] + http_options = ( + _BaseFirestoreRestTransport._BaseListCollectionIds._get_http_options() + ) + request, metadata = self._interceptor.pre_list_collection_ids( request, metadata ) - pb_request = firestore.ListCollectionIdsRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - # Jsonify the request body + transcoded_request = _BaseFirestoreRestTransport._BaseListCollectionIds._get_transcoded_request( + http_options, request + ) - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + body = _BaseFirestoreRestTransport._BaseListCollectionIds._get_request_body_json( + transcoded_request ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, - ) + query_params = _BaseFirestoreRestTransport._BaseListCollectionIds._get_query_params_json( + transcoded_request ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.ListCollectionIds", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ListCollectionIds", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._ListCollectionIds._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -1514,22 +2326,64 @@ def __call__( pb_resp = firestore.ListCollectionIdsResponse.pb(resp) json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_list_collection_ids(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_list_collection_ids_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = firestore.ListCollectionIdsResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": 
response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.list_collection_ids", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ListCollectionIds", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp - class _ListDocuments(FirestoreRestStub): + class _ListDocuments( + _BaseFirestoreRestTransport._BaseListDocuments, FirestoreRestStub + ): def __hash__(self): - return hash("ListDocuments") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.ListDocuments") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response def __call__( self, @@ -1537,7 +2391,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.ListDocumentsResponse: r"""Call the list documents method over HTTP. @@ -1548,8 +2402,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
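
# Each _get_response helper above serializes query params through
# rest_helpers.flatten_query_params(..., strict=True). A hedged
# illustration of the flattening, assuming google-api-core's documented
# behavior: nested messages become dotted keys and repeated fields become
# repeated keys.
from google.api_core import rest_helpers

params = {"mask": {"fieldPaths": ["a", "b"]}, "pageSize": 10}
flat = rest_helpers.flatten_query_params(params, strict=True)
# Expected shape (assumption):
#   [("mask.fieldPaths", "a"), ("mask.fieldPaths", "b"), ("pageSize", "10")]
print(flat)
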
Returns: ~.firestore.ListDocumentsResponse: @@ -1558,42 +2414,59 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "get", - "uri": "/v1/{parent=projects/*/databases/*/documents/*/**}/{collection_id}", - }, - { - "method": "get", - "uri": "/v1/{parent=projects/*/databases/*/documents}/{collection_id}", - }, - ] - request, metadata = self._interceptor.pre_list_documents(request, metadata) - pb_request = firestore.ListDocumentsRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) + http_options = ( + _BaseFirestoreRestTransport._BaseListDocuments._get_http_options() + ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] + request, metadata = self._interceptor.pre_list_documents(request, metadata) + transcoded_request = ( + _BaseFirestoreRestTransport._BaseListDocuments._get_transcoded_request( + http_options, request + ) + ) # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, + query_params = ( + _BaseFirestoreRestTransport._BaseListDocuments._get_query_params_json( + transcoded_request ) ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.ListDocuments", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ListDocuments", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), + response = FirestoreRestTransport._ListDocuments._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -1606,12 +2479,38 @@ def __call__( pb_resp = firestore.ListDocumentsResponse.pb(resp) json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_list_documents(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_list_documents_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = firestore.ListDocumentsResponse.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.list_documents", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ListDocuments", + "metadata": http_response["headers"], + "httpResponse": http_response, 
+ }, + ) return resp - class _Listen(FirestoreRestStub): + class _Listen(_BaseFirestoreRestTransport._BaseListen, FirestoreRestStub): def __hash__(self): - return hash("Listen") + return hash("FirestoreRestTransport.Listen") def __call__( self, @@ -1619,25 +2518,40 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> rest_streaming.ResponseIterator: raise NotImplementedError( "Method Listen is not available over REST transport" ) - class _PartitionQuery(FirestoreRestStub): + class _PartitionQuery( + _BaseFirestoreRestTransport._BasePartitionQuery, FirestoreRestStub + ): def __hash__(self): - return hash("PartitionQuery") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.PartitionQuery") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response def __call__( self, @@ -1645,7 +2559,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> firestore.PartitionQueryResponse: r"""Call the partition query method over HTTP. @@ -1656,8 +2570,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
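
# The *_with_metadata interceptor hooks invoked after each response above
# give callers a chance to inspect or rewrite both the parsed message and
# the raw response headers. A hedged sketch of a custom interceptor,
# assuming FirestoreRestInterceptor is the interceptor base class this
# transport is constructed with:
from google.cloud.firestore_v1.services.firestore.transports.rest import (
    FirestoreRestInterceptor,
)

class HeaderLoggingInterceptor(FirestoreRestInterceptor):
    def post_list_documents_with_metadata(self, response, metadata):
        # metadata is the [(header, value), ...] list built from
        # response.headers; return both, modified or not.
        print("server headers:", dict(metadata))
        return response, metadata
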
Returns: ~.firestore.PartitionQueryResponse: @@ -1666,50 +2582,66 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{parent=projects/*/databases/*/documents}:partitionQuery", - "body": "*", - }, - { - "method": "post", - "uri": "/v1/{parent=projects/*/databases/*/documents/*/**}:partitionQuery", - "body": "*", - }, - ] - request, metadata = self._interceptor.pre_partition_query(request, metadata) - pb_request = firestore.PartitionQueryRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) + http_options = ( + _BaseFirestoreRestTransport._BasePartitionQuery._get_http_options() + ) - # Jsonify the request body + request, metadata = self._interceptor.pre_partition_query(request, metadata) + transcoded_request = ( + _BaseFirestoreRestTransport._BasePartitionQuery._get_transcoded_request( + http_options, request + ) + ) - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + body = ( + _BaseFirestoreRestTransport._BasePartitionQuery._get_request_body_json( + transcoded_request + ) ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, + query_params = ( + _BaseFirestoreRestTransport._BasePartitionQuery._get_query_params_json( + transcoded_request ) ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.PartitionQuery", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "PartitionQuery", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._PartitionQuery._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -1722,22 +2654,63 @@ def __call__( pb_resp = firestore.PartitionQueryResponse.pb(resp) json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_partition_query(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_partition_query_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = firestore.PartitionQueryResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": 
dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.partition_query", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "PartitionQuery", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp - class _Rollback(FirestoreRestStub): + class _Rollback(_BaseFirestoreRestTransport._BaseRollback, FirestoreRestStub): def __hash__(self): - return hash("Rollback") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.Rollback") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response def __call__( self, @@ -1745,7 +2718,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ): r"""Call the rollback method over HTTP. @@ -1756,49 +2729,68 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
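
# The request/response logging blocks above are gated on
# CLIENT_LOGGING_SUPPORTED and on DEBUG being enabled for the module
# logger. One way to switch them on with the standard library; the exact
# logger name is an assumption derived from the module path.
import logging

logging.basicConfig()
logging.getLogger(
    "google.cloud.firestore_v1.services.firestore.transports.rest"
).setLevel(logging.DEBUG)
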
""" - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{database=projects/*/databases/*}/documents:rollback", - "body": "*", - }, - ] - request, metadata = self._interceptor.pre_rollback(request, metadata) - pb_request = firestore.RollbackRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) + http_options = _BaseFirestoreRestTransport._BaseRollback._get_http_options() - # Jsonify the request body + request, metadata = self._interceptor.pre_rollback(request, metadata) + transcoded_request = ( + _BaseFirestoreRestTransport._BaseRollback._get_transcoded_request( + http_options, request + ) + ) - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + body = _BaseFirestoreRestTransport._BaseRollback._get_request_body_json( + transcoded_request ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, + query_params = ( + _BaseFirestoreRestTransport._BaseRollback._get_query_params_json( + transcoded_request ) ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.Rollback", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "Rollback", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._Rollback._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -1806,19 +2798,35 @@ def __call__( if response.status_code >= 400: raise core_exceptions.from_http_response(response) - class _RunAggregationQuery(FirestoreRestStub): + class _RunAggregationQuery( + _BaseFirestoreRestTransport._BaseRunAggregationQuery, FirestoreRestStub + ): def __hash__(self): - return hash("RunAggregationQuery") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.RunAggregationQuery") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + 
"{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + stream=True, + ) + return response def __call__( self, @@ -1826,7 +2834,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> rest_streaming.ResponseIterator: r"""Call the run aggregation query method over HTTP. @@ -1837,8 +2845,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: ~.firestore.RunAggregationQueryResponse: @@ -1847,52 +2857,62 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{parent=projects/*/databases/*/documents}:runAggregationQuery", - "body": "*", - }, - { - "method": "post", - "uri": "/v1/{parent=projects/*/databases/*/documents/*/**}:runAggregationQuery", - "body": "*", - }, - ] + http_options = ( + _BaseFirestoreRestTransport._BaseRunAggregationQuery._get_http_options() + ) + request, metadata = self._interceptor.pre_run_aggregation_query( request, metadata ) - pb_request = firestore.RunAggregationQueryRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - # Jsonify the request body + transcoded_request = _BaseFirestoreRestTransport._BaseRunAggregationQuery._get_transcoded_request( + http_options, request + ) - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + body = _BaseFirestoreRestTransport._BaseRunAggregationQuery._get_request_body_json( + transcoded_request ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, - ) + query_params = _BaseFirestoreRestTransport._BaseRunAggregationQuery._get_query_params_json( + transcoded_request ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.RunAggregationQuery", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "RunAggregationQuery", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, 
uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._RunAggregationQuery._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -1904,22 +2924,41 @@ def __call__( resp = rest_streaming.ResponseIterator( response, firestore.RunAggregationQueryResponse ) + resp = self._interceptor.post_run_aggregation_query(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_run_aggregation_query_with_metadata( + resp, response_metadata + ) return resp - class _RunQuery(FirestoreRestStub): + class _RunQuery(_BaseFirestoreRestTransport._BaseRunQuery, FirestoreRestStub): def __hash__(self): - return hash("RunQuery") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } + return hash("FirestoreRestTransport.RunQuery") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + stream=True, + ) + return response def __call__( self, @@ -1927,7 +2966,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> rest_streaming.ResponseIterator: r"""Call the run query method over HTTP. @@ -1938,8 +2977,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
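
# Unlike the unary stubs, RunQuery and RunAggregationQuery open the HTTP
# response with stream=True and wrap it in rest_streaming.ResponseIterator,
# which yields parsed messages one by one. A hedged consumption sketch;
# the project and database names are placeholders, and a structured_query
# would normally be set on the request as well.
from google.cloud.firestore_v1.services.firestore import FirestoreClient
from google.cloud.firestore_v1.types import firestore

client = FirestoreClient(transport="rest")
request = firestore.RunQueryRequest(
    parent="projects/my-project/databases/(default)/documents",
)
for response in client.run_query(request=request):
    # Each iteration is one firestore.RunQueryResponse from the stream.
    if response.document.name:
        print(response.document.name)
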
Returns: ~.firestore.RunQueryResponse: @@ -1948,50 +2989,62 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{parent=projects/*/databases/*/documents}:runQuery", - "body": "*", - }, - { - "method": "post", - "uri": "/v1/{parent=projects/*/databases/*/documents/*/**}:runQuery", - "body": "*", - }, - ] - request, metadata = self._interceptor.pre_run_query(request, metadata) - pb_request = firestore.RunQueryRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) + http_options = _BaseFirestoreRestTransport._BaseRunQuery._get_http_options() - # Jsonify the request body + request, metadata = self._interceptor.pre_run_query(request, metadata) + transcoded_request = ( + _BaseFirestoreRestTransport._BaseRunQuery._get_transcoded_request( + http_options, request + ) + ) - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + body = _BaseFirestoreRestTransport._BaseRunQuery._get_request_body_json( + transcoded_request ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, + query_params = ( + _BaseFirestoreRestTransport._BaseRunQuery._get_query_params_json( + transcoded_request ) ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.RunQuery", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "RunQuery", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._RunQuery._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -2001,22 +3054,42 @@ def __call__( # Return the response resp = rest_streaming.ResponseIterator(response, firestore.RunQueryResponse) + resp = self._interceptor.post_run_query(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_run_query_with_metadata( + resp, response_metadata + ) return resp - class _UpdateDocument(FirestoreRestStub): + class _UpdateDocument( + _BaseFirestoreRestTransport._BaseUpdateDocument, FirestoreRestStub + ): def __hash__(self): - return hash("UpdateDocument") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in 
message_dict - } + return hash("FirestoreRestTransport.UpdateDocument") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response def __call__( self, @@ -2024,7 +3097,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> gf_document.Document: r"""Call the update document method over HTTP. @@ -2035,8 +3108,10 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: ~.gf_document.Document: @@ -2046,45 +3121,66 @@ def __call__( """ - http_options: List[Dict[str, str]] = [ - { - "method": "patch", - "uri": "/v1/{document.name=projects/*/databases/*/documents/*/**}", - "body": "document", - }, - ] - request, metadata = self._interceptor.pre_update_document(request, metadata) - pb_request = firestore.UpdateDocumentRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) + http_options = ( + _BaseFirestoreRestTransport._BaseUpdateDocument._get_http_options() + ) - # Jsonify the request body + request, metadata = self._interceptor.pre_update_document(request, metadata) + transcoded_request = ( + _BaseFirestoreRestTransport._BaseUpdateDocument._get_transcoded_request( + http_options, request + ) + ) - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + body = ( + _BaseFirestoreRestTransport._BaseUpdateDocument._get_request_body_json( + transcoded_request + ) ) - uri = transcoded_request["uri"] - method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, + query_params = ( + _BaseFirestoreRestTransport._BaseUpdateDocument._get_query_params_json( + transcoded_request ) ) - query_params.update(self._get_unset_required_fields(query_params)) - query_params["$alt"] = "json;enum-encoding=int" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.UpdateDocument", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": 
"UpdateDocument", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, + response = FirestoreRestTransport._UpdateDocument._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -2097,12 +3193,38 @@ def __call__( pb_resp = gf_document.Document.pb(resp) json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_update_document(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_update_document_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = gf_document.Document.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.update_document", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "UpdateDocument", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp - class _Write(FirestoreRestStub): + class _Write(_BaseFirestoreRestTransport._BaseWrite, FirestoreRestStub): def __hash__(self): - return hash("Write") + return hash("FirestoreRestTransport.Write") def __call__( self, @@ -2110,7 +3232,7 @@ def __call__( *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> rest_streaming.ResponseIterator: raise NotImplementedError( "Method Write is not available over REST transport" @@ -2258,14 +3380,42 @@ def write(self) -> Callable[[firestore.WriteRequest], firestore.WriteResponse]: def cancel_operation(self): return self._CancelOperation(self._session, self._host, self._interceptor) # type: ignore - class _CancelOperation(FirestoreRestStub): + class _CancelOperation( + _BaseFirestoreRestTransport._BaseCancelOperation, FirestoreRestStub + ): + def __hash__(self): + return hash("FirestoreRestTransport.CancelOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + def __call__( self, request: operations_pb2.CancelOperationRequest, *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> None: r"""Call the cancel operation method over HTTP. 
@@ -2275,41 +3425,72 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. """ - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{name=projects/*/databases/*/operations/*}:cancel", - "body": "*", - }, - ] + http_options = ( + _BaseFirestoreRestTransport._BaseCancelOperation._get_http_options() + ) request, metadata = self._interceptor.pre_cancel_operation( request, metadata ) - request_kwargs = json_format.MessageToDict(request) - transcoded_request = path_template.transcode(http_options, **request_kwargs) + transcoded_request = _BaseFirestoreRestTransport._BaseCancelOperation._get_transcoded_request( + http_options, request + ) - body = json.dumps(transcoded_request["body"]) - uri = transcoded_request["uri"] - method = transcoded_request["method"] + body = ( + _BaseFirestoreRestTransport._BaseCancelOperation._get_request_body_json( + transcoded_request + ) + ) # Jsonify the query params - query_params = json.loads(json.dumps(transcoded_request["query_params"])) + query_params = ( + _BaseFirestoreRestTransport._BaseCancelOperation._get_query_params_json( + transcoded_request + ) + ) - # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.CancelOperation", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "CancelOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params), - data=body, + # Send the request + response = FirestoreRestTransport._CancelOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -2323,14 +3504,41 @@ def __call__( def delete_operation(self): return self._DeleteOperation(self._session, self._host, self._interceptor) # type: ignore - class _DeleteOperation(FirestoreRestStub): + class _DeleteOperation( + _BaseFirestoreRestTransport._BaseDeleteOperation, FirestoreRestStub + ): + def __hash__(self): + return hash("FirestoreRestTransport.DeleteOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + 
headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + def __call__( self, request: operations_pb2.DeleteOperationRequest, *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> None: r"""Call the delete operation method over HTTP. @@ -2340,38 +3548,65 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. """ - http_options: List[Dict[str, str]] = [ - { - "method": "delete", - "uri": "/v1/{name=projects/*/databases/*/operations/*}", - }, - ] + http_options = ( + _BaseFirestoreRestTransport._BaseDeleteOperation._get_http_options() + ) request, metadata = self._interceptor.pre_delete_operation( request, metadata ) - request_kwargs = json_format.MessageToDict(request) - transcoded_request = path_template.transcode(http_options, **request_kwargs) - - uri = transcoded_request["uri"] - method = transcoded_request["method"] + transcoded_request = _BaseFirestoreRestTransport._BaseDeleteOperation._get_transcoded_request( + http_options, request + ) # Jsonify the query params - query_params = json.loads(json.dumps(transcoded_request["query_params"])) + query_params = ( + _BaseFirestoreRestTransport._BaseDeleteOperation._get_query_params_json( + transcoded_request + ) + ) - # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.DeleteOperation", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "DeleteOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params), + # Send the request + response = FirestoreRestTransport._DeleteOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -2385,14 +3620,41 @@ def __call__( def get_operation(self): return self._GetOperation(self._session, self._host, self._interceptor) # type: ignore - class _GetOperation(FirestoreRestStub): + class _GetOperation( + 
_BaseFirestoreRestTransport._BaseGetOperation, FirestoreRestStub + ): + def __hash__(self): + return hash("FirestoreRestTransport.GetOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + def __call__( self, request: operations_pb2.GetOperationRequest, *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> operations_pb2.Operation: r"""Call the get operation method over HTTP. @@ -2402,39 +3664,68 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: operations_pb2.Operation: Response from GetOperation method. """ - http_options: List[Dict[str, str]] = [ - { - "method": "get", - "uri": "/v1/{name=projects/*/databases/*/operations/*}", - }, - ] + http_options = ( + _BaseFirestoreRestTransport._BaseGetOperation._get_http_options() + ) request, metadata = self._interceptor.pre_get_operation(request, metadata) - request_kwargs = json_format.MessageToDict(request) - transcoded_request = path_template.transcode(http_options, **request_kwargs) - - uri = transcoded_request["uri"] - method = transcoded_request["method"] + transcoded_request = ( + _BaseFirestoreRestTransport._BaseGetOperation._get_transcoded_request( + http_options, request + ) + ) # Jsonify the query params - query_params = json.loads(json.dumps(transcoded_request["query_params"])) + query_params = ( + _BaseFirestoreRestTransport._BaseGetOperation._get_query_params_json( + transcoded_request + ) + ) - # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.GetOperation", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "GetOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params), + # Send the request + response = 
FirestoreRestTransport._GetOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -2442,23 +3733,72 @@ def __call__( if response.status_code >= 400: raise core_exceptions.from_http_response(response) + content = response.content.decode("utf-8") resp = operations_pb2.Operation() - resp = json_format.Parse(response.content.decode("utf-8"), resp) + resp = json_format.Parse(content, resp) resp = self._interceptor.post_get_operation(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreAsyncClient.GetOperation", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "GetOperation", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) return resp @property def list_operations(self): return self._ListOperations(self._session, self._host, self._interceptor) # type: ignore - class _ListOperations(FirestoreRestStub): + class _ListOperations( + _BaseFirestoreRestTransport._BaseListOperations, FirestoreRestStub + ): + def __hash__(self): + return hash("FirestoreRestTransport.ListOperations") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + def __call__( self, request: operations_pb2.ListOperationsRequest, *, retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, - metadata: Sequence[Tuple[str, str]] = (), + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> operations_pb2.ListOperationsResponse: r"""Call the list operations method over HTTP. @@ -2468,39 +3808,68 @@ def __call__( retry (google.api_core.retry.Retry): Designation of what errors, if any, should be retried. timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. Returns: operations_pb2.ListOperationsResponse: Response from ListOperations method. 
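
# The operations mixins above parse the decoded JSON body straight into a
# protobuf via json_format.Parse. A self-contained illustration with a
# hand-written payload:
from google.longrunning import operations_pb2
from google.protobuf import json_format

content = '{"operations": [{"name": "projects/p/databases/d/operations/op1"}]}'
resp = operations_pb2.ListOperationsResponse()
json_format.Parse(content, resp)
assert resp.operations[0].name.endswith("op1")
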
""" - http_options: List[Dict[str, str]] = [ - { - "method": "get", - "uri": "/v1/{name=projects/*/databases/*}/operations", - }, - ] + http_options = ( + _BaseFirestoreRestTransport._BaseListOperations._get_http_options() + ) request, metadata = self._interceptor.pre_list_operations(request, metadata) - request_kwargs = json_format.MessageToDict(request) - transcoded_request = path_template.transcode(http_options, **request_kwargs) - - uri = transcoded_request["uri"] - method = transcoded_request["method"] + transcoded_request = ( + _BaseFirestoreRestTransport._BaseListOperations._get_transcoded_request( + http_options, request + ) + ) # Jsonify the query params - query_params = json.loads(json.dumps(transcoded_request["query_params"])) + query_params = ( + _BaseFirestoreRestTransport._BaseListOperations._get_query_params_json( + transcoded_request + ) + ) - # Send the request - headers = dict(metadata) - headers["Content-Type"] = "application/json" + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.ListOperations", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ListOperations", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params), + # Send the request + response = FirestoreRestTransport._ListOperations._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception @@ -2508,9 +3877,31 @@ def __call__( if response.status_code >= 400: raise core_exceptions.from_http_response(response) + content = response.content.decode("utf-8") resp = operations_pb2.ListOperationsResponse() - resp = json_format.Parse(response.content.decode("utf-8"), resp) + resp = json_format.Parse(content, resp) resp = self._interceptor.post_list_operations(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreAsyncClient.ListOperations", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ListOperations", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) return resp @property diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py new file mode 100644 index 000000000..721f0792f --- /dev/null +++ b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py @@ -0,0 +1,1046 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 
(the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json # type: ignore +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from google.cloud.location import locations_pb2 # type: ignore +from .base import FirestoreTransport, DEFAULT_CLIENT_INFO + +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + + +from google.cloud.firestore_v1.types import document +from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import firestore +from google.protobuf import empty_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore + + +class _BaseFirestoreRestTransport(FirestoreTransport): + """Base REST backend transport for Firestore. + + Note: This class is not meant to be used directly. Use its sync and + async sub-classes instead. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "firestore.googleapis.com", + credentials: Optional[Any] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + Args: + host (Optional[str]): + The hostname to connect to (default: 'firestore.googleapis.com'). + credentials (Optional[Any]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. 
+ """ + # Run the base constructor + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + class _BaseBatchGetDocuments: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:batchGet", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.BatchGetDocumentsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseBatchGetDocuments._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseBatchWrite: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:batchWrite", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.BatchWriteRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseBatchWrite._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseBeginTransaction: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, 
message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:beginTransaction", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.BeginTransactionRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseBeginTransaction._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseCommit: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:commit", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.CommitRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseCommit._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseCreateDocument: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*/databases/*/documents/**}/{collection_id}", + "body": "document", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.CreateDocumentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], 
use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseCreateDocument._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseDeleteDocument: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1/{name=projects/*/databases/*/documents/*/**}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.DeleteDocumentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseDeleteDocument._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseExecutePipeline: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.ExecutePipelineRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseGetDocument: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{name=projects/*/databases/*/documents/*/**}", + }, + ] + return http_options + + @staticmethod + def 
_get_transcoded_request(http_options, request): + pb_request = firestore.GetDocumentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseGetDocument._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseListCollectionIds: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*/databases/*/documents}:listCollectionIds", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{parent=projects/*/databases/*/documents/*/**}:listCollectionIds", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.ListCollectionIdsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseListCollectionIds._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseListDocuments: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{parent=projects/*/databases/*/documents/*/**}/{collection_id}", + }, + { + "method": "get", + "uri": "/v1/{parent=projects/*/databases/*/documents}/{collection_id}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.ListDocumentsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseListDocuments._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseListen: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + class _BasePartitionQuery: + def __hash__(self): # pragma: NO 
COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*/databases/*/documents}:partitionQuery", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{parent=projects/*/databases/*/documents/*/**}:partitionQuery", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.PartitionQueryRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BasePartitionQuery._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseRollback: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:rollback", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.RollbackRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseRollback._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseRunAggregationQuery: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*/databases/*/documents}:runAggregationQuery", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{parent=projects/*/databases/*/documents/*/**}:runAggregationQuery", + "body": "*", + }, + ] + return 
http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.RunAggregationQueryRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseRunAggregationQuery._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseRunQuery: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*/databases/*/documents}:runQuery", + "body": "*", + }, + { + "method": "post", + "uri": "/v1/{parent=projects/*/databases/*/documents/*/**}:runQuery", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.RunQueryRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseRunQuery._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseUpdateDocument: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1/{document.name=projects/*/databases/*/documents/*/**}", + "body": "document", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.UpdateDocumentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) 
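+                # (editorial note) MessageToJson plus json.loads round-trips
+                # the transcoded query params into a plain dict, which is
+                # then merged with defaults for any required fields the
+                # caller left unset; the same helper pattern repeats for
+                # each RPC in this file.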
+ ) + query_params.update( + _BaseFirestoreRestTransport._BaseUpdateDocument._get_unset_required_fields( + query_params + ) + ) + + return query_params + + class _BaseWrite: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + class _BaseCancelOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{name=projects/*/databases/*/operations/*}:cancel", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + body = json.dumps(transcoded_request["body"]) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseDeleteOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1/{name=projects/*/databases/*/operations/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseGetOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{name=projects/*/databases/*/operations/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseListOperations: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{name=projects/*/databases/*}/operations", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + +__all__ = ("_BaseFirestoreRestTransport",) diff --git a/google/cloud/firestore_v1/types/__init__.py b/google/cloud/firestore_v1/types/__init__.py index 1e6b0c729..ed1965d7f 100644 --- 
a/google/cloud/firestore_v1/types/__init__.py +++ b/google/cloud/firestore_v1/types/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/types/aggregation_result.py b/google/cloud/firestore_v1/types/aggregation_result.py index 1fbe2988d..3c649dc8a 100644 --- a/google/cloud/firestore_v1/types/aggregation_result.py +++ b/google/cloud/firestore_v1/types/aggregation_result.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/types/bloom_filter.py b/google/cloud/firestore_v1/types/bloom_filter.py index 3c92b2173..f38386cbe 100644 --- a/google/cloud/firestore_v1/types/bloom_filter.py +++ b/google/cloud/firestore_v1/types/bloom_filter.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/types/common.py b/google/cloud/firestore_v1/types/common.py index cecb1b610..01fb3d263 100644 --- a/google/cloud/firestore_v1/types/common.py +++ b/google/cloud/firestore_v1/types/common.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/types/document.py b/google/cloud/firestore_v1/types/document.py index 432e043df..1757571b1 100644 --- a/google/cloud/firestore_v1/types/document.py +++ b/google/cloud/firestore_v1/types/document.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -185,6 +185,37 @@ class Value(proto.Message): map_value (google.cloud.firestore_v1.types.MapValue): A map value. + This field is a member of `oneof`_ ``value_type``. + field_reference_value (str): + Value which references a field. + + This is considered relative (vs absolute) since it only + refers to a field and not a field within a particular + document. + + **Requires:** + + - Must follow [field reference][FieldReference.field_path] + limitations. + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + function_value (google.cloud.firestore_v1.types.Function): + A value that represents an unevaluated expression. + + **Requires:** + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + pipeline_value (google.cloud.firestore_v1.types.Pipeline): + A value that represents an unevaluated pipeline. + + **Requires:** + + - Not allowed to be used when writing documents. + This field is a member of `oneof`_ ``value_type``. 
""" diff --git a/google/cloud/firestore_v1/types/explain_stats.py b/google/cloud/firestore_v1/types/explain_stats.py index 9d12e8c31..1fda228b6 100644 --- a/google/cloud/firestore_v1/types/explain_stats.py +++ b/google/cloud/firestore_v1/types/explain_stats.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/types/firestore.py b/google/cloud/firestore_v1/types/firestore.py index c2d9f8475..f1753c92f 100644 --- a/google/cloud/firestore_v1/types/firestore.py +++ b/google/cloud/firestore_v1/types/firestore.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -867,6 +867,7 @@ class ExecutePipelineRequest(proto.Message): This field is a member of `oneof`_ ``consistency_selector``. new_transaction (google.cloud.firestore_v1.types.TransactionOptions): Execute the pipeline in a new transaction. + The identifier of the newly created transaction will be returned in the first response on the stream. This defaults to a read-only @@ -952,6 +953,7 @@ class ExecutePipelineResponse(proto.Message): represents the time at which the operation was run. explain_stats (google.cloud.firestore_v1.types.ExplainStats): Query explain stats. + Contains all metadata related to pipeline planning and execution, specific contents depend on the supplied pipeline options. diff --git a/google/cloud/firestore_v1/types/pipeline.py b/google/cloud/firestore_v1/types/pipeline.py index 0aed187cf..29fbe884b 100644 --- a/google/cloud/firestore_v1/types/pipeline.py +++ b/google/cloud/firestore_v1/types/pipeline.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/firestore_v1/types/query.py b/google/cloud/firestore_v1/types/query.py index 3e53208aa..9aa8977dd 100644 --- a/google/cloud/firestore_v1/types/query.py +++ b/google/cloud/firestore_v1/types/query.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -555,8 +555,9 @@ class FindNearest(proto.Message): when the vectors are more similar, the comparison is inverted. - For EUCLIDEAN, COSINE: WHERE distance <= distance_threshold - For DOT_PRODUCT: WHERE distance >= distance_threshold + - For EUCLIDEAN, COSINE: WHERE distance <= + distance_threshold + - For DOT_PRODUCT: WHERE distance >= distance_threshold """ class DistanceMeasure(proto.Enum): diff --git a/google/cloud/firestore_v1/types/query_profile.py b/google/cloud/firestore_v1/types/query_profile.py index 0b26236cf..f93184ae3 100644 --- a/google/cloud/firestore_v1/types/query_profile.py +++ b/google/cloud/firestore_v1/types/query_profile.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/google/cloud/firestore_v1/types/write.py b/google/cloud/firestore_v1/types/write.py index 8b12cced2..e393b9148 100644 --- a/google/cloud/firestore_v1/types/write.py +++ b/google/cloud/firestore_v1/types/write.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 93003c096c8babd38942b0ad247b4925e8c336c7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 8 May 2025 14:23:55 -0700 Subject: [PATCH 109/131] removed unneeded code --- google/cloud/firestore_v1/base_query.py | 5 --- google/cloud/firestore_v1/pipeline_source.py | 35 +------------------- noxfile.py | 27 ++++++++++----- 3 files changed, 19 insertions(+), 48 deletions(-) diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 1aaf6d081..5a9efaf78 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -57,8 +57,6 @@ query, ) from google.cloud.firestore_v1.vector import Vector -from google.cloud.firestore_v1 import pipeline_expressions -from google.cloud.firestore_v1 import pipeline_stages if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -1121,9 +1119,6 @@ def recursive(self: QueryType) -> QueryType: return copied - def pipeline(self): - raise NotImplementedError - def _comparator(self, doc1, doc2) -> int: _orders = self._orders diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index 97044e471..0a336dfa5 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -51,37 +51,4 @@ def collection(self, path: str) -> PipelineType: Returns: a new pipeline instance targeting the specified collection """ - return self.client._pipeline_cls(self.client, stages.Collection(path)) - - def collection_group(self, collection_id: str) -> PipelineType: - """ - Creates a new Pipeline that that operates on all documents in a collection group. - - Args: - collection_id: The ID of the collection group - Returns: - a new pipeline instance targeting the specified collection group - """ - return self.client._pipeline_cls( - self.client, stages.CollectionGroup(collection_id) - ) - - def database(self) -> PipelineType: - """ - Creates a new Pipeline that operates on all documents in the Firestore database. - - Returns: - a new pipeline instance targeting the specified collection - """ - return self.client._pipeline_cls(self.client, stages.Database()) - - def documents(self, *docs: "BaseDocumentReference") -> PipelineType: - """ - Creates a new Pipeline that operates on a specific set of Firestore documents. - - Args: - docs: The DocumentReference instances representing the documents to include in the pipeline. 
- Returns: - a new pipeline instance targeting the specified documents - """ - return self.client._pipeline_cls(self.client, stages.Documents.of(*docs)) + return self.client._pipeline_cls(self.client, stages.Collection(path)) \ No newline at end of file diff --git a/noxfile.py b/noxfile.py index 8ebe6dbd2..1b89f9b9b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -158,15 +158,24 @@ def mypy(session): session.install("-e", ".") session.install("mypy", "types-setuptools", "types-protobuf") # TODO: also verify types on tests, all of google package - session.run("mypy", - "-p", "google.cloud.firestore_v1.pipeline_expressions", - "-p", "google.cloud.firestore_v1.pipeline_stages", - "-p", "google.cloud.firestore_v1.pipeline_source", - "-p", "google.cloud.firestore_v1.pipeline_result", - "-p", "google.cloud.firestore_v1.base_pipeline", - "-p", "google.cloud.firestore_v1.async_pipeline", - "-p", "google.cloud.firestore_v1.pipeline", - "--no-incremental") + session.run( + "mypy", + "-p", + "google.cloud.firestore_v1.pipeline_expressions", + "-p", + "google.cloud.firestore_v1.pipeline_stages", + "-p", + "google.cloud.firestore_v1.pipeline_source", + "-p", + "google.cloud.firestore_v1.pipeline_result", + "-p", + "google.cloud.firestore_v1.base_pipeline", + "-p", + "google.cloud.firestore_v1.async_pipeline", + "-p", + "google.cloud.firestore_v1.pipeline", + "--no-incremental", + ) @nox.session(python=DEFAULT_PYTHON_VERSION) From 79d016a8bb0e6a64183b130c3903e0b56038d265 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 8 May 2025 14:56:57 -0700 Subject: [PATCH 110/131] fixed lint --- google/cloud/firestore_v1/base_client.py | 2 ++ google/cloud/firestore_v1/base_pipeline.py | 3 --- google/cloud/firestore_v1/pipeline.py | 3 +-- google/cloud/firestore_v1/pipeline_source.py | 3 +-- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index ca815e52c..415ddaf44 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -61,6 +61,8 @@ from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions from google.cloud.firestore_v1.field_path import render_field_path from google.cloud.firestore_v1.services.firestore import client as firestore_client +from google.cloud.firestore_v1.pipeline_source import PipelineSource +from google.cloud.firestore_v1.base_pipeline import _BasePipeline DEFAULT_DATABASE = "(default)" """str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`.""" diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 91da9ac2a..b7e4566d6 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -15,12 +15,9 @@ from __future__ import annotations from typing import TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages -from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.types.pipeline import ( StructuredPipeline as StructuredPipeline_pb, ) -from google.cloud.firestore_v1.pipeline_result import PipelineResult -from google.cloud.firestore_v1 import _helpers, document if TYPE_CHECKING: from google.cloud.firestore_v1.client import Client diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 8ddde7ec9..264b05c18 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -13,8 
+13,7 @@ # limitations under the License. from __future__ import annotations -import datetime -from typing import AsyncIterable, Iterable, TYPE_CHECKING +from typing import Iterable, TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest from google.cloud.firestore_v1.document import DocumentReference diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index 0a336dfa5..d072c328b 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient - from google.cloud.firestore_v1.base_document import BaseDocumentReference PipelineType = TypeVar("PipelineType", bound=_BasePipeline) @@ -51,4 +50,4 @@ def collection(self, path: str) -> PipelineType: Returns: a new pipeline instance targeting the specified collection """ - return self.client._pipeline_cls(self.client, stages.Collection(path)) \ No newline at end of file + return self.client._pipeline_cls(self.client, stages.Collection(path)) From 907a551dea0ca2448c0c73b01ab83a0dd948edb0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 8 May 2025 16:29:57 -0700 Subject: [PATCH 111/131] added client tests --- google/cloud/firestore_v1/async_client.py | 2 +- google/cloud/firestore_v1/client.py | 2 +- tests/unit/v1/test_async_client.py | 8 ++++++++ tests/unit/v1/test_client.py | 9 +++++++++ 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 21790194f..3f37604dc 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -417,4 +417,4 @@ def transaction(self, **kwargs) -> AsyncTransaction: @property def _pipeline_cls(self): - raise AsyncPipeline + return AsyncPipeline diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 1371d18ea..167830dcb 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -399,4 +399,4 @@ def transaction(self, **kwargs) -> Transaction: @property def _pipeline_cls(self): - raise Pipeline + return Pipeline diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index ee624d382..cd29e8efc 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -532,6 +532,14 @@ def test_asyncclient_transaction(): assert transaction._read_only assert transaction._id is None +def test_asyncclient_pipeline(database): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.pipeline_source import PipelineSource + client = _make_default_async_client(database=database) + ppl = client.pipeline() + assert client._pipeline_cls == AsyncPipeline + assert isinstance(ppl, PipelineSource) + assert ppl.client == client def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index edb411c9f..8c09ec274 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -621,6 +621,15 @@ def test_client_transaction(database): assert transaction._read_only assert transaction._id is None +@pytest.mark.parametrize("database", [None, DEFAULT_DATABASE, "somedb"]) +def test_client_pipeline(database): + from 
google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1.pipeline_source import PipelineSource + client = _make_default_client(database=database) + ppl = client.pipeline() + assert client._pipeline_cls == Pipeline + assert isinstance(ppl, PipelineSource) + assert ppl.client == client def _make_batch_response(**kwargs): from google.cloud.firestore_v1.types import firestore From 8e20c118b2ae7206ded739499b43f4c9f818fde7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 8 May 2025 17:31:24 -0700 Subject: [PATCH 112/131] added pipeline tests --- google/cloud/firestore_v1/base_pipeline.py | 7 +- tests/unit/v1/test_async_pipeline.py | 92 ++++++++++++++++++++++ tests/unit/v1/test_pipeline.py | 92 ++++++++++++++++++++++ 3 files changed, 188 insertions(+), 3 deletions(-) create mode 100644 tests/unit/v1/test_async_pipeline.py create mode 100644 tests/unit/v1/test_pipeline.py diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index b7e4566d6..b9ac6f084 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -46,13 +46,14 @@ def __init__(self, client: Client | AsyncClient, *stages: stages.Stage): self.stages = tuple(stages) def __repr__(self): + cls_str = type(self).__name__ if not self.stages: - return "Pipeline()" + return f"{cls_str}()" elif len(self.stages) == 1: - return f"Pipeline({self.stages[0]!r})" + return f"{cls_str}({self.stages[0]!r})" else: stages_str = ",\n ".join([repr(s) for s in self.stages]) - return f"Pipeline(\n {stages_str}\n)" + return f"{cls_str}(\n {stages_str}\n)" def _to_pb(self) -> StructuredPipeline_pb: return StructuredPipeline_pb( diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py new file mode 100644 index 000000000..72431a952 --- /dev/null +++ b/tests/unit/v1/test_async_pipeline.py @@ -0,0 +1,92 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
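+# (editorial note, not part of the original patch) The base_pipeline.py
+# hunk above derives the repr prefix from type(self).__name__, so each
+# subclass reports its own name and the tests below can assert
+# class-specific strings, for example:
+#
+#     repr(AsyncPipeline(client))  ->  "AsyncPipeline()"
+#     repr(Pipeline(client))       ->  "Pipeline()"
+#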
+# See the License for the specific language governing permissions and +# limitations under the License + +import mock + +def _make_async_pipeline(*args, client=mock.Mock()): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + return AsyncPipeline(client, *args) + +def test_ctor(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + client = object() + stages = [object() for i in range(10)] + instance = AsyncPipeline(client, *stages) + assert instance._client == client + assert len(instance.stages) == 10 + assert instance.stages[0] == stages[0] + assert instance.stages[-1] == stages[-1] + +def test_async_pipeline_repr_empty(): + ppl = _make_async_pipeline() + repr_str = repr(ppl) + assert repr_str == "AsyncPipeline()" + +def test_async_pipeline_repr_single_stage(): + stage = mock.Mock() + stage.__repr__ = lambda x: "SingleStage" + ppl = _make_async_pipeline(stage) + repr_str = repr(ppl) + assert repr_str == 'AsyncPipeline(SingleStage)' + +def test_async_pipeline_repr_multiple_stage(): + from google.cloud.firestore_v1.pipeline_stages import GenericStage, Collection + stage_1 = Collection("path") + stage_2 = GenericStage("second", 2) + stage_3 = GenericStage("third", 3) + ppl = _make_async_pipeline(stage_1, stage_2, stage_3) + repr_str = repr(ppl) + assert repr_str == ( + "AsyncPipeline(\n" + " Collection(path='/path'),\n" + " GenericStage(params=[2]),\n" + " GenericStage(params=[3])\n" + ")" + ) + +def test_async_pipeline_repr_long(): + from google.cloud.firestore_v1.pipeline_stages import GenericStage + num_stages = 100 + stage_list = [GenericStage("custom", i) for i in range(num_stages)] + ppl = _make_async_pipeline(*stage_list) + repr_str = repr(ppl) + assert repr_str.count("GenericStage") == num_stages + assert repr_str.count('\n') == num_stages+1 + +def test_async_pipeline__to_pb(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + from google.cloud.firestore_v1.pipeline_stages import GenericStage + stage_1 = GenericStage("first") + stage_2 = GenericStage("second") + ppl = _make_async_pipeline(stage_1, stage_2) + pb = ppl._to_pb() + assert isinstance(pb, StructuredPipeline) + assert pb.pipeline.stages[0] == stage_1._to_pb() + assert pb.pipeline.stages[1] == stage_2._to_pb() + +def test_async_pipeline_append(): + """append should create a new pipeline with the additional stage""" + from google.cloud.firestore_v1.pipeline_stages import GenericStage + stage_1 = GenericStage("first") + ppl_1 = _make_async_pipeline(stage_1, client=object()) + stage_2 = GenericStage("second") + ppl_2 = ppl_1._append(stage_2) + assert ppl_1 != ppl_2 + assert len(ppl_1.stages) == 1 + assert len(ppl_2.stages) == 2 + assert ppl_2.stages[0] == stage_1 + assert ppl_2.stages[1] == stage_2 + assert ppl_1._client == ppl_2._client + assert isinstance(ppl_2, type(ppl_1)) \ No newline at end of file diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py new file mode 100644 index 000000000..f5e62b453 --- /dev/null +++ b/tests/unit/v1/test_pipeline.py @@ -0,0 +1,92 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
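+# (editorial note, not part of the original patch) These tests mirror
+# tests/unit/v1/test_async_pipeline.py for the synchronous Pipeline. A
+# minimal usage sketch, assuming a hypothetical "books" collection:
+#
+#     client = firestore.Client()
+#     pipeline = client.pipeline().collection("books")
+#     for result in pipeline.execute():
+#         print(result.data())
+#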
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mock + +def _make_pipeline(*args, client=mock.Mock()): + from google.cloud.firestore_v1.pipeline import Pipeline + + return Pipeline(client, *args) + +def test_ctor(): + from google.cloud.firestore_v1.pipeline import Pipeline + client = object() + stages = [object() for i in range(10)] + instance = Pipeline(client, *stages) + assert instance._client == client + assert len(instance.stages) == 10 + assert instance.stages[0] == stages[0] + assert instance.stages[-1] == stages[-1] + +def test_pipeline_repr_empty(): + ppl = _make_pipeline() + repr_str = repr(ppl) + assert repr_str == "Pipeline()" + +def test_pipeline_repr_single_stage(): + stage = mock.Mock() + stage.__repr__ = lambda x: "SingleStage" + ppl = _make_pipeline(stage) + repr_str = repr(ppl) + assert repr_str == 'Pipeline(SingleStage)' + +def test_pipeline_repr_multiple_stage(): + from google.cloud.firestore_v1.pipeline_stages import GenericStage, Collection + stage_1 = Collection("path") + stage_2 = GenericStage("second", 2) + stage_3 = GenericStage("third", 3) + ppl = _make_pipeline(stage_1, stage_2, stage_3) + repr_str = repr(ppl) + assert repr_str == ( + "Pipeline(\n" + " Collection(path='/path'),\n" + " GenericStage(params=[2]),\n" + " GenericStage(params=[3])\n" + ")" + ) + +def test_pipeline_repr_long(): + from google.cloud.firestore_v1.pipeline_stages import GenericStage + num_stages = 100 + stage_list = [GenericStage("custom", i) for i in range(num_stages)] + ppl = _make_pipeline(*stage_list) + repr_str = repr(ppl) + assert repr_str.count("GenericStage") == num_stages + assert repr_str.count('\n') == num_stages+1 + +def test_pipeline__to_pb(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + from google.cloud.firestore_v1.pipeline_stages import GenericStage + stage_1 = GenericStage("first") + stage_2 = GenericStage("second") + ppl = _make_pipeline(stage_1, stage_2) + pb = ppl._to_pb() + assert isinstance(pb, StructuredPipeline) + assert pb.pipeline.stages[0] == stage_1._to_pb() + assert pb.pipeline.stages[1] == stage_2._to_pb() + +def test_pipeline_append(): + """append should create a new pipeline with the additional stage""" + from google.cloud.firestore_v1.pipeline_stages import GenericStage + stage_1 = GenericStage("first") + ppl_1 = _make_pipeline(stage_1, client=object()) + stage_2 = GenericStage("second") + ppl_2 = ppl_1._append(stage_2) + assert ppl_1 != ppl_2 + assert len(ppl_1.stages) == 1 + assert len(ppl_2.stages) == 2 + assert ppl_2.stages[0] == stage_1 + assert ppl_2.stages[1] == stage_2 + assert ppl_1._client == ppl_2._client + assert isinstance(ppl_2, type(ppl_1)) \ No newline at end of file From 98c7ea57dc01e9642ea925613cf233fcc8f37b3b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 9 May 2025 16:24:43 -0700 Subject: [PATCH 113/131] added tests for execute --- google/cloud/firestore_v1/async_pipeline.py | 13 +- google/cloud/firestore_v1/pipeline.py | 12 +- google/cloud/firestore_v1/pipeline_result.py | 3 + tests/unit/v1/test_async_client.py | 3 + tests/unit/v1/test_async_pipeline.py | 200 ++++++++++++++++++- tests/unit/v1/test_client.py | 3 + 
tests/unit/v1/test_pipeline.py | 182 ++++++++++++++++- 7 files changed, 393 insertions(+), 23 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 5e4fc05e7..9f513362a 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -65,22 +65,17 @@ async def execute(self) -> AsyncIterable[PipelineResult]: request = ExecutePipelineRequest( database=database_name, structured_pipeline=self._to_pb(), - read_time=datetime.datetime.now(), ) async for response in await self._client._firestore_api.execute_pipeline( request ): for doc in response.results: - doc_ref = ( - AsyncDocumentReference(doc.name, client=self._client) - if doc.name - else None - ) + ref = self._client.document(doc.name) if doc.name else None yield PipelineResult( self._client, doc.fields, - doc_ref, + ref, response._pb.execution_time, - doc.create_time, - doc.update_tiem, + doc.create_time.timestamp_pb() if doc.create_time else None, + doc.update_time.timestamp_pb() if doc.update_time else None, ) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 264b05c18..48d508fad 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -64,16 +64,12 @@ def execute(self) -> Iterable[PipelineResult]: ) for response in self._client._firestore_api.execute_pipeline(request): for doc in response.results: - doc_ref = ( - DocumentReference(doc.name, client=self._client) - if doc.name - else None - ) + ref = self._client.document(doc.name) if doc.name else None yield PipelineResult( self._client, doc.fields, - doc_ref, + ref, response._pb.execution_time, - doc.create_time, - doc.update_tiem, + doc.create_time.timestamp_pb() if doc.create_time else None, + doc.update_time.timestamp_pb() if doc.update_time else None, ) diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py index 61341db4d..98d711b71 100644 --- a/google/cloud/firestore_v1/pipeline_result.py +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -61,6 +61,9 @@ def __init__( self._create_time = create_time self._update_time = update_time + def __repr__(self): + return f"{type(self).__name__}(data={self.data()})" + @property def ref(self) -> BaseDocumentReference | None: """ diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index cd29e8efc..b4490fb69 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -532,15 +532,18 @@ def test_asyncclient_transaction(): assert transaction._read_only assert transaction._id is None + def test_asyncclient_pipeline(database): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline from google.cloud.firestore_v1.pipeline_source import PipelineSource + client = _make_default_async_client(database=database) ppl = client.pipeline() assert client._pipeline_cls == AsyncPipeline assert isinstance(ppl, PipelineSource) assert ppl.client == client + def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index 72431a952..ee6f79994 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -13,14 +13,23 @@ # limitations under the License import mock +import pytest + def _make_async_pipeline(*args, client=mock.Mock()): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline return 
AsyncPipeline(client, *args) + +async def _async_it(list): + for value in list: + yield value + + def test_ctor(): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + client = object() stages = [object() for i in range(10)] instance = AsyncPipeline(client, *stages) @@ -29,20 +38,24 @@ def test_ctor(): assert instance.stages[0] == stages[0] assert instance.stages[-1] == stages[-1] + def test_async_pipeline_repr_empty(): ppl = _make_async_pipeline() repr_str = repr(ppl) assert repr_str == "AsyncPipeline()" + def test_async_pipeline_repr_single_stage(): stage = mock.Mock() stage.__repr__ = lambda x: "SingleStage" ppl = _make_async_pipeline(stage) repr_str = repr(ppl) - assert repr_str == 'AsyncPipeline(SingleStage)' + assert repr_str == "AsyncPipeline(SingleStage)" + def test_async_pipeline_repr_multiple_stage(): from google.cloud.firestore_v1.pipeline_stages import GenericStage, Collection + stage_1 = Collection("path") stage_2 = GenericStage("second", 2) stage_3 = GenericStage("third", 3) @@ -56,18 +69,22 @@ def test_async_pipeline_repr_multiple_stage(): ")" ) + def test_async_pipeline_repr_long(): from google.cloud.firestore_v1.pipeline_stages import GenericStage + num_stages = 100 stage_list = [GenericStage("custom", i) for i in range(num_stages)] ppl = _make_async_pipeline(*stage_list) repr_str = repr(ppl) assert repr_str.count("GenericStage") == num_stages - assert repr_str.count('\n') == num_stages+1 + assert repr_str.count("\n") == num_stages + 1 + def test_async_pipeline__to_pb(): from google.cloud.firestore_v1.types.pipeline import StructuredPipeline from google.cloud.firestore_v1.pipeline_stages import GenericStage + stage_1 = GenericStage("first") stage_2 = GenericStage("second") ppl = _make_async_pipeline(stage_1, stage_2) @@ -76,9 +93,11 @@ def test_async_pipeline__to_pb(): assert pb.pipeline.stages[0] == stage_1._to_pb() assert pb.pipeline.stages[1] == stage_2._to_pb() + def test_async_pipeline_append(): """append should create a new pipeline with the additional stage""" from google.cloud.firestore_v1.pipeline_stages import GenericStage + stage_1 = GenericStage("first") ppl_1 = _make_async_pipeline(stage_1, client=object()) stage_2 = GenericStage("second") @@ -89,4 +108,179 @@ def test_async_pipeline_append(): assert ppl_2.stages[0] == stage_1 assert ppl_2.stages[1] == stage_2 assert ppl_1._client == ppl_2._client - assert isinstance(ppl_2, type(ppl_1)) \ No newline at end of file + assert isinstance(ppl_2, type(ppl_1)) + + +@pytest.mark.asyncio +async def test_async_pipeline_execute_empty(): + """ + test execute pipeline with mocked empty response + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_stages import GenericStage + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + mock_rpc.return_value = _async_it([ExecutePipelineResponse()]) + ppl_1 = _make_async_pipeline(GenericStage("s"), client=client) + + results = [r async for r in ppl_1.execute()] + assert results == [] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + +@pytest.mark.asyncio +async def test_async_pipeline_execute_no_doc_ref(): + """ + test execute pipeline with no doc ref + 
""" + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_stages import GenericStage + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + mock_rpc.return_value = _async_it( + [ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9})] + ) + ppl_1 = _make_async_pipeline(GenericStage("s"), client=client) + + results = [r async for r in ppl_1.execute()] + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert isinstance(response, PipelineResult) + assert response.ref is None + assert response.id is None + assert response.create_time is None + assert response.update_time is None + assert response.execution_time.seconds == 9 + assert response.data() == {} + + +@pytest.mark.asyncio +async def test_async_pipeline_execute_populated(): + """ + test execute pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + mock_rpc.return_value = _async_it( + [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + create_time={"seconds": 1}, + update_time={"seconds": 2}, + fields={"key": Value(string_value="str_val")}, + ) + ], + execution_time={"seconds": 9}, + ) + ] + ) + ppl_1 = _make_async_pipeline(client=client) + + results = [r async for r in ppl_1.execute()] + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert isinstance(response, PipelineResult) + assert isinstance(response.ref, DocumentReference) + assert response.ref.path == "test/my_doc" + assert response.id == "my_doc" + assert response.create_time.seconds == 1 + assert response.update_time.seconds == 2 + assert response.execution_time.seconds == 9 + assert response.data() == {"key": "str_val"} + + +@pytest.mark.asyncio +async def test_async_pipeline_execute_multiple(): + """ + test execute pipeline with multiple docs and responses + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from 
google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + mock_rpc.return_value = _async_it( + [ + ExecutePipelineResponse( + results=[ + Document(fields={"key": Value(integer_value=0)}), + Document(fields={"key": Value(integer_value=1)}), + ], + execution_time={"seconds": 0}, + ), + ExecutePipelineResponse( + results=[ + Document(fields={"key": Value(integer_value=2)}), + Document(fields={"key": Value(integer_value=3)}), + ], + execution_time={"seconds": 1}, + ), + ] + ) + ppl_1 = _make_async_pipeline(client=client) + + results = [r async for r in ppl_1.execute()] + assert len(results) == 4 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + for idx, response in enumerate(results): + assert isinstance(response, PipelineResult) + assert response.data() == {"key": idx} diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index 8c09ec274..4e5e4469a 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -621,16 +621,19 @@ def test_client_transaction(database): assert transaction._read_only assert transaction._id is None + @pytest.mark.parametrize("database", [None, DEFAULT_DATABASE, "somedb"]) def test_client_pipeline(database): from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1.pipeline_source import PipelineSource + client = _make_default_client(database=database) ppl = client.pipeline() assert client._pipeline_cls == Pipeline assert isinstance(ppl, PipelineSource) assert ppl.client == client + def _make_batch_response(**kwargs): from google.cloud.firestore_v1.types import firestore diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index f5e62b453..b8933b2c4 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -14,13 +14,16 @@ import mock + def _make_pipeline(*args, client=mock.Mock()): from google.cloud.firestore_v1.pipeline import Pipeline return Pipeline(client, *args) + def test_ctor(): from google.cloud.firestore_v1.pipeline import Pipeline + client = object() stages = [object() for i in range(10)] instance = Pipeline(client, *stages) @@ -29,20 +32,24 @@ def test_ctor(): assert instance.stages[0] == stages[0] assert instance.stages[-1] == stages[-1] + def test_pipeline_repr_empty(): ppl = _make_pipeline() repr_str = repr(ppl) assert repr_str == "Pipeline()" + def test_pipeline_repr_single_stage(): stage = mock.Mock() stage.__repr__ = lambda x: "SingleStage" ppl = _make_pipeline(stage) repr_str = repr(ppl) - assert repr_str == 'Pipeline(SingleStage)' + assert repr_str == "Pipeline(SingleStage)" + def test_pipeline_repr_multiple_stage(): from google.cloud.firestore_v1.pipeline_stages import GenericStage, Collection + stage_1 = Collection("path") stage_2 = GenericStage("second", 2) stage_3 = GenericStage("third", 3) @@ -56,18 +63,22 @@ def test_pipeline_repr_multiple_stage(): ")" ) + def test_pipeline_repr_long(): from google.cloud.firestore_v1.pipeline_stages import GenericStage + num_stages = 100 stage_list = [GenericStage("custom", i) for i in range(num_stages)] ppl = _make_pipeline(*stage_list) repr_str = repr(ppl) assert 
repr_str.count("GenericStage") == num_stages - assert repr_str.count('\n') == num_stages+1 + assert repr_str.count("\n") == num_stages + 1 + def test_pipeline__to_pb(): from google.cloud.firestore_v1.types.pipeline import StructuredPipeline from google.cloud.firestore_v1.pipeline_stages import GenericStage + stage_1 = GenericStage("first") stage_2 = GenericStage("second") ppl = _make_pipeline(stage_1, stage_2) @@ -76,9 +87,11 @@ def test_pipeline__to_pb(): assert pb.pipeline.stages[0] == stage_1._to_pb() assert pb.pipeline.stages[1] == stage_2._to_pb() + def test_pipeline_append(): """append should create a new pipeline with the additional stage""" from google.cloud.firestore_v1.pipeline_stages import GenericStage + stage_1 = GenericStage("first") ppl_1 = _make_pipeline(stage_1, client=object()) stage_2 = GenericStage("second") @@ -89,4 +102,167 @@ def test_pipeline_append(): assert ppl_2.stages[0] == stage_1 assert ppl_2.stages[1] == stage_2 assert ppl_1._client == ppl_2._client - assert isinstance(ppl_2, type(ppl_1)) \ No newline at end of file + assert isinstance(ppl_2, type(ppl_1)) + + +def test_pipeline_execute_empty(): + """ + test execute pipeline with mocked empty response + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_stages import GenericStage + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + mock_rpc.return_value = [ExecutePipelineResponse()] + ppl_1 = _make_pipeline(GenericStage("s"), client=client) + + results = list(ppl_1.execute()) + assert results == [] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + +def test_pipeline_execute_no_doc_ref(): + """ + test execute pipeline with no doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_stages import GenericStage + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + mock_rpc.return_value = [ + ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9}) + ] + ppl_1 = _make_pipeline(GenericStage("s"), client=client) + + results = list(ppl_1.execute()) + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert isinstance(response, PipelineResult) + assert response.ref is None + assert response.id is None + assert response.create_time is None + assert response.update_time is None + assert response.execution_time.seconds == 9 + assert response.data() == {} + + +def test_pipeline_execute_populated(): + """ + test execute pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + 
from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = client._firestore_api.execute_pipeline + + mock_rpc.return_value = [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + create_time={"seconds": 1}, + update_time={"seconds": 2}, + fields={"key": Value(string_value="str_val")}, + ) + ], + execution_time={"seconds": 9}, + ) + ] + ppl_1 = _make_pipeline(client=client) + + results = list(ppl_1.execute()) + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert isinstance(response, PipelineResult) + assert isinstance(response.ref, DocumentReference) + assert response.ref.path == "test/my_doc" + assert response.id == "my_doc" + assert response.create_time.seconds == 1 + assert response.update_time.seconds == 2 + assert response.execution_time.seconds == 9 + assert response.data() == {"key": "str_val"} + + +def test_pipeline_execute_multiple(): + """ + test execute pipeline with multiple docs and responses + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = client._firestore_api.execute_pipeline + + mock_rpc.return_value = [ + ExecutePipelineResponse( + results=[ + Document(fields={"key": Value(integer_value=0)}), + Document(fields={"key": Value(integer_value=1)}), + ], + execution_time={"seconds": 0}, + ), + ExecutePipelineResponse( + results=[ + Document(fields={"key": Value(integer_value=2)}), + Document(fields={"key": Value(integer_value=3)}), + ], + execution_time={"seconds": 1}, + ), + ] + ppl_1 = _make_pipeline(client=client) + + results = list(ppl_1.execute()) + assert len(results) == 4 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + for idx, response in enumerate(results): + assert isinstance(response, PipelineResult) + assert response.data() == {"key": idx} From af4fb20ac604a857742ca28bd31900e2ffb25940 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 9 May 2025 16:37:07 -0700 Subject: [PATCH 114/131] broke out shared logic into base_pipeline --- google/cloud/firestore_v1/async_pipeline.py | 25 +++------------- google/cloud/firestore_v1/base_pipeline.py | 33 ++++++++++++++++++++- google/cloud/firestore_v1/pipeline.py | 24 ++------------- 3 files changed, 39 insertions(+), 43 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 9f513362a..62f729273 100644 --- 
a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -16,13 +16,11 @@ import datetime from typing import AsyncIterable, TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages -from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest -from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.base_pipeline import _BasePipeline -from google.cloud.firestore_v1.pipeline_result import PipelineResult if TYPE_CHECKING: from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.pipeline_result import PipelineResult class AsyncPipeline(_BasePipeline): @@ -59,23 +57,8 @@ def __init__(self, client: AsyncClient, *stages: stages.Stage): super().__init__(client, *stages) async def execute(self) -> AsyncIterable[PipelineResult]: - database_name = ( - f"projects/{self._client.project}/databases/{self._client._database}" - ) - request = ExecutePipelineRequest( - database=database_name, - structured_pipeline=self._to_pb(), - ) async for response in await self._client._firestore_api.execute_pipeline( - request + self._execute_request_helper() ): - for doc in response.results: - ref = self._client.document(doc.name) if doc.name else None - yield PipelineResult( - self._client, - doc.fields, - ref, - response._pb.execution_time, - doc.create_time.timestamp_pb() if doc.create_time else None, - doc.update_time.timestamp_pb() if doc.update_time else None, - ) + for result in self._execute_response_helper(response): + yield result diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index b9ac6f084..8d3957d88 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -13,15 +13,18 @@ # limitations under the License. 
 from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import Iterable, TYPE_CHECKING
 from google.cloud.firestore_v1 import pipeline_stages as stages
 from google.cloud.firestore_v1.types.pipeline import (
     StructuredPipeline as StructuredPipeline_pb,
 )
+from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest
+from google.cloud.firestore_v1.pipeline_result import PipelineResult
 
 if TYPE_CHECKING:
     from google.cloud.firestore_v1.client import Client
     from google.cloud.firestore_v1.async_client import AsyncClient
+    from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse
 
 
 class _BasePipeline:
@@ -65,3 +68,31 @@ def _append(self, new_stage):
         Create a new Pipeline object with a new stage appended
         """
         return self.__class__(self._client, *self.stages, new_stage)
+
+    def _execute_request_helper(self) -> ExecutePipelineRequest:
+        """
+        shared logic for creating an ExecutePipelineRequest
+        """
+        database_name = (
+            f"projects/{self._client.project}/databases/{self._client._database}"
+        )
+        request = ExecutePipelineRequest(
+            database=database_name,
+            structured_pipeline=self._to_pb(),
+        )
+        return request
+
+    def _execute_response_helper(self, response:ExecutePipelineResponse) -> Iterable[PipelineResult]:
+        """
+        shared logic for unpacking an ExecutePipelineResponse into PipelineResults
+        """
+        for doc in response.results:
+            ref = self._client.document(doc.name) if doc.name else None
+            yield PipelineResult(
+                self._client,
+                doc.fields,
+                ref,
+                response._pb.execution_time,
+                doc.create_time.timestamp_pb() if doc.create_time else None,
+                doc.update_time.timestamp_pb() if doc.update_time else None,
+            )
\ No newline at end of file
diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py
index 48d508fad..40945146e 100644
--- a/google/cloud/firestore_v1/pipeline.py
+++ b/google/cloud/firestore_v1/pipeline.py
@@ -15,13 +15,11 @@
 from __future__ import annotations
 from typing import Iterable, TYPE_CHECKING
 from google.cloud.firestore_v1 import pipeline_stages as stages
-from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest
-from google.cloud.firestore_v1.document import DocumentReference
 from google.cloud.firestore_v1.base_pipeline import _BasePipeline
-from google.cloud.firestore_v1.pipeline_result import PipelineResult
 
 if TYPE_CHECKING:
     from google.cloud.firestore_v1.client import Client
+    from google.cloud.firestore_v1.pipeline_result import PipelineResult
 
 
 class Pipeline(_BasePipeline):
@@ -55,21 +53,5 @@ def __init__(self, client: Client, *stages: stages.Stage):
         super().__init__(client, *stages)
 
     def execute(self) -> Iterable[PipelineResult]:
-        database_name = (
-            f"projects/{self._client.project}/databases/{self._client._database}"
-        )
-        request = ExecutePipelineRequest(
-            database=database_name,
-            structured_pipeline=self._to_pb(),
-        )
-        for response in self._client._firestore_api.execute_pipeline(request):
-            for doc in response.results:
-                ref = self._client.document(doc.name) if doc.name else None
-                yield PipelineResult(
-                    self._client,
-                    doc.fields,
-                    ref,
-                    response._pb.execution_time,
-                    doc.create_time.timestamp_pb() if doc.create_time else None,
-                    doc.update_time.timestamp_pb() if doc.update_time else None,
-                )
+        for response in self._client._firestore_api.execute_pipeline(self._execute_request_helper()):
+            yield from self._execute_response_helper(response)
\ No newline at end of file
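
A minimal sanity check of the shared request path factored out above (a
sketch mirroring the mocked-client setup in the unit tests earlier in this
series; the project and database values here are stand-ins, not library
requirements):

    from unittest import mock

    from google.cloud.firestore_v1.pipeline import Pipeline
    from google.cloud.firestore_v1.pipeline_stages import GenericStage

    client = mock.Mock()
    client.project = "demo-project"
    client._database = "(default)"

    ppl = Pipeline(client, GenericStage("sample"))
    request = ppl._execute_request_helper()
    assert request.database == "projects/demo-project/databases/(default)"
    assert request.structured_pipeline == ppl._to_pb()

From cd38fc20b7a12a23fba476de043ca76aa9335640 Mon Sep 17 00:00:00 2001
From: Daniel Sanche 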
Date: Mon, 12 May 2025 13:22:03 -0700 Subject: [PATCH 115/131] added pipeline stages tests --- google/cloud/firestore_v1/pipeline_stages.py | 3 + tests/unit/v1/test_pipeline_stages.py | 99 ++++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 tests/unit/v1/test_pipeline_stages.py diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 0eaefd7d0..7849e6a8e 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -122,3 +122,6 @@ def __init__(self, name: str, *params: Expr | Value): def _pb_args(self): return self.params + + def __repr__(self): + return f"{self.__class__.__name__}(name='{self.name}')" \ No newline at end of file diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py new file mode 100644 index 000000000..49d4816fa --- /dev/null +++ b/tests/unit/v1/test_pipeline_stages.py @@ -0,0 +1,99 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +import google.cloud.firestore_v1.pipeline_stages as stages +from google.cloud.firestore_v1.pipeline_expressions import Constant +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1._helpers import GeoPoint + + +class TestStage: + + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + stages.Stage() + + +class TestCollection: + + def _make_one(self, *args, **kwargs): + return stages.Collection(*args, **kwargs) + + @pytest.mark.parametrize("input_arg,expected", [ + ("test", "Collection(path='/test')"), + ("/test", "Collection(path='/test')"), + ]) + def test_repr(self, input_arg, expected): + instance = self._make_one(input_arg) + repr_str = repr(instance) + assert repr_str == expected + + def test_to_pb(self): + input_arg = "test/col" + instance = self._make_one(input_arg) + result = instance._to_pb() + assert result.name == "collection" + assert len(result.args) == 1 + assert result.args[0].reference_value == "/test/col" + assert len(result.options) == 0 + +class TestGenericStage: + + def _make_one(self, *args, **kwargs): + return stages.GenericStage(*args, **kwargs) + + @pytest.mark.parametrize("input_args,expected_params", [ + (("name",), []), + (("custom",Value(string_value="val")), [Value(string_value="val")]), + (("n",Value(integer_value=1)), [Value(integer_value=1)]), + (("n",Constant.of(1)), [Value(integer_value=1)]), + (("n",Constant.of(True), Constant.of(False)), [Value(boolean_value=True), Value(boolean_value=False)]), + (("n", Constant.of(GeoPoint(1, 2))), [Value(geo_point_value={"latitude": 1, "longitude": 2})]), + (("n", Constant.of(None)), [Value(null_value=0)]), + (("n", Constant.of([0,1,2])), [Value(array_value={"values": [Value(integer_value=n) for n in range(3)]})]), + (("n", Value(reference_value="/projects/p/databases/d/documents/doc")), 
[Value(reference_value="/projects/p/databases/d/documents/doc")]), + (("n", Constant.of({"a": "b"})), [Value(map_value={"fields": {"a": Value(string_value="b")}})]), + ]) + def test_ctor(self, input_args, expected_params): + instance = self._make_one(*input_args) + assert instance.params == expected_params + + @pytest.mark.parametrize("input_args,expected", [ + (("name",), "GenericStage(name='name')"), + (("custom",Value(string_value="val")), "GenericStage(name='custom')"), + ]) + def test_repr(self, input_args, expected): + instance = self._make_one(*input_args) + repr_str = repr(instance) + assert repr_str == expected + + def test_repr_with_reference(self): + """ + reference_value can't properly be displayed without knowing the + """ + + def test_to_pb(self): + instance = self._make_one("name", Constant.of(True), Constant.of("test")) + result = instance._to_pb() + assert result.name == "name" + assert len(result.args) == 2 + assert result.args[0].boolean_value is True + assert result.args[1].string_value == "test" + assert len(result.options) == 0 \ No newline at end of file From 0ac319e97370fc18eadb4687e94114729239aadf Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 12 May 2025 13:23:17 -0700 Subject: [PATCH 116/131] removed unneeded stages --- google/cloud/firestore_v1/pipeline_stages.py | 45 +------------------- 1 file changed, 1 insertion(+), 44 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 7849e6a8e..840feca5e 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -19,9 +19,7 @@ from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb from google.cloud.firestore_v1.types.document import Value -from google.cloud.firestore_v1.pipeline_expressions import ( - Expr, -) +from google.cloud.firestore_v1.pipeline_expressions import Expr if TYPE_CHECKING: from google.cloud.firestore_v1.base_document import BaseDocumentReference @@ -70,47 +68,6 @@ def _pb_args(self): return [Value(reference_value=self.path)] -class CollectionGroup(Stage): - """Specifies a collection group as the initial data source.""" - - def __init__(self, collection_id: str): - super().__init__("collection_group") - self.collection_id = collection_id - - def _pb_args(self): - return [Value(string_value=self.collection_id)] - - -class Database(Stage): - """Specifies the default database as the initial data source.""" - - def __init__(self): - super().__init__() - - def _pb_args(self): - return [] - - -class Documents(Stage): - """Specifies specific documents as the initial data source.""" - - def __init__(self, *paths: str): - super().__init__() - self.paths = paths - - @staticmethod - def of(*documents: "BaseDocumentReference") -> "Documents": - doc_paths = ["/" + doc.path for doc in documents] - return Documents(*doc_paths) - - def _pb_args(self): - return [ - Value( - list_value={"values": [Value(string_value=path) for path in self.paths]} - ) - ] - - class GenericStage(Stage): """Represents a generic, named stage with parameters.""" From 7981a348e9646b1be84d6d22a179dded79da32ed Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 12 May 2025 13:57:00 -0700 Subject: [PATCH 117/131] added tests for pipeline expressions --- google/cloud/firestore_v1/_helpers.py | 3 + tests/unit/v1/test_pipeline_expressions.py | 88 ++++++++++++++++++++++ tests/unit/v1/test_pipeline_stages.py | 81 +++++++++++++------- 3 files changed, 143 insertions(+), 29 deletions(-) create 
mode 100644 tests/unit/v1/test_pipeline_expressions.py diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index 399bdb066..1fbc1a476 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -120,6 +120,9 @@ def __ne__(self, other): else: return not equality_val + def __repr__(self): + return f"{type(self).__name__}(latitude={self.latitude}, longitude={self.longitude})" + def verify_path(path, is_collection) -> None: """Verifies that a ``path`` has the correct form. diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py new file mode 100644 index 000000000..2be7b5eb5 --- /dev/null +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -0,0 +1,88 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest +import datetime + +import google.cloud.firestore_v1.pipeline_expressions as expressions +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1._helpers import GeoPoint + + +class TestExpr: + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + expressions.Expr() + + +class TestConstant: + @pytest.mark.parametrize("input_val, to_pb_val",[ + ("test", Value(string_value="test")), + ("", Value(string_value="")), + (10, Value(integer_value=10)), + (0, Value(integer_value=0)), + (10.0, Value(double_value=10)), + (0.0, Value(double_value=0)), + (True, Value(boolean_value=True)), + (b"test", Value(bytes_value=b"test")), + (None, Value(null_value=0)), + ( + datetime.datetime(2025, 5, 12), + Value(timestamp_value={"seconds": 1747008000}), + ), + (GeoPoint(1, 2), Value(geo_point_value={"latitude": 1, "longitude": 2})), + ( + [0.0, 1.0, 2.0], + Value(array_value={"values": [Value(double_value=i) for i in range(3)]}), + ), + ({"a": "b"}, Value(map_value={"fields": {"a": Value(string_value="b")}})), + ( + Vector([1.0, 2.0]), + Value(map_value={"fields": { + "__type__": Value(string_value="__vector__"), + "value": Value(array_value={ + "values": [Value(double_value=v) for v in [1,2]], + }), + }}), + ), + ]) + def test_to_pb(self, input_val, to_pb_val): + instance = expressions.Constant.of(input_val) + assert instance._to_pb() == to_pb_val + + @pytest.mark.parametrize("input_val,expected", [ + ("test", "Constant.of('test')"), + ("", "Constant.of('')"), + (10, "Constant.of(10)"), + (0, "Constant.of(0)"), + (10.0, "Constant.of(10.0)"), + (0.0, "Constant.of(0.0)"), + (True, "Constant.of(True)"), + (b"test", "Constant.of(b'test')"), + (None, "Constant.of(None)"), + (datetime.datetime(2025, 5, 12), "Constant.of(datetime.datetime(2025, 5, 12, 0, 0))"), + (GeoPoint(1, 2), "Constant.of(GeoPoint(latitude=1, longitude=2))"), + ([1, 2, 3], "Constant.of([1, 2, 3])"), + ({"a": "b"}, "Constant.of({'a': 'b'})"), + (Vector([1.0, 2.0]), "Constant.of(Vector<1.0, 2.0>)"), + ]) + def test_repr(self, 
input_val, expected): + instance = expressions.Constant.of(input_val) + repr_string = repr(instance) + assert repr_string == expected diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index 49d4816fa..d0568dcf5 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -22,7 +22,6 @@ class TestStage: - def test_ctor(self): """ Base class should be abstract @@ -32,14 +31,16 @@ def test_ctor(self): class TestCollection: - def _make_one(self, *args, **kwargs): return stages.Collection(*args, **kwargs) - @pytest.mark.parametrize("input_arg,expected", [ - ("test", "Collection(path='/test')"), - ("/test", "Collection(path='/test')"), - ]) + @pytest.mark.parametrize( + "input_arg,expected", + [ + ("test", "Collection(path='/test')"), + ("/test", "Collection(path='/test')"), + ], + ) def test_repr(self, input_arg, expected): instance = self._make_one(input_arg) repr_str = repr(instance) @@ -54,41 +55,63 @@ def test_to_pb(self): assert result.args[0].reference_value == "/test/col" assert len(result.options) == 0 -class TestGenericStage: +class TestGenericStage: def _make_one(self, *args, **kwargs): return stages.GenericStage(*args, **kwargs) - @pytest.mark.parametrize("input_args,expected_params", [ - (("name",), []), - (("custom",Value(string_value="val")), [Value(string_value="val")]), - (("n",Value(integer_value=1)), [Value(integer_value=1)]), - (("n",Constant.of(1)), [Value(integer_value=1)]), - (("n",Constant.of(True), Constant.of(False)), [Value(boolean_value=True), Value(boolean_value=False)]), - (("n", Constant.of(GeoPoint(1, 2))), [Value(geo_point_value={"latitude": 1, "longitude": 2})]), - (("n", Constant.of(None)), [Value(null_value=0)]), - (("n", Constant.of([0,1,2])), [Value(array_value={"values": [Value(integer_value=n) for n in range(3)]})]), - (("n", Value(reference_value="/projects/p/databases/d/documents/doc")), [Value(reference_value="/projects/p/databases/d/documents/doc")]), - (("n", Constant.of({"a": "b"})), [Value(map_value={"fields": {"a": Value(string_value="b")}})]), - ]) + @pytest.mark.parametrize( + "input_args,expected_params", + [ + (("name",), []), + (("custom", Value(string_value="val")), [Value(string_value="val")]), + (("n", Value(integer_value=1)), [Value(integer_value=1)]), + (("n", Constant.of(1)), [Value(integer_value=1)]), + ( + ("n", Constant.of(True), Constant.of(False)), + [Value(boolean_value=True), Value(boolean_value=False)], + ), + ( + ("n", Constant.of(GeoPoint(1, 2))), + [Value(geo_point_value={"latitude": 1, "longitude": 2})], + ), + (("n", Constant.of(None)), [Value(null_value=0)]), + ( + ("n", Constant.of([0, 1, 2])), + [ + Value( + array_value={ + "values": [Value(integer_value=n) for n in range(3)] + } + ) + ], + ), + ( + ("n", Value(reference_value="/projects/p/databases/d/documents/doc")), + [Value(reference_value="/projects/p/databases/d/documents/doc")], + ), + ( + ("n", Constant.of({"a": "b"})), + [Value(map_value={"fields": {"a": Value(string_value="b")}})], + ), + ], + ) def test_ctor(self, input_args, expected_params): instance = self._make_one(*input_args) assert instance.params == expected_params - @pytest.mark.parametrize("input_args,expected", [ - (("name",), "GenericStage(name='name')"), - (("custom",Value(string_value="val")), "GenericStage(name='custom')"), - ]) + @pytest.mark.parametrize( + "input_args,expected", + [ + (("name",), "GenericStage(name='name')"), + (("custom", Value(string_value="val")), "GenericStage(name='custom')"), + ], + ) def 
test_repr(self, input_args, expected): instance = self._make_one(*input_args) repr_str = repr(instance) assert repr_str == expected - def test_repr_with_reference(self): - """ - reference_value can't properly be displayed without knowing the - """ - def test_to_pb(self): instance = self._make_one("name", Constant.of(True), Constant.of("test")) result = instance._to_pb() @@ -96,4 +119,4 @@ def test_to_pb(self): assert len(result.args) == 2 assert result.args[0].boolean_value is True assert result.args[1].string_value == "test" - assert len(result.options) == 0 \ No newline at end of file + assert len(result.options) == 0 From 62b65107b2ad5d73ab53e982183a6d2b4b684c9d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 12 May 2025 14:57:09 -0700 Subject: [PATCH 118/131] added pipeline_result tests --- google/cloud/firestore_v1/pipeline_result.py | 9 +- tests/unit/v1/test_pipeline_result.py | 162 +++++++++++++++++++ 2 files changed, 166 insertions(+), 5 deletions(-) create mode 100644 tests/unit/v1/test_pipeline_result.py diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py index 98d711b71..54bfc152f 100644 --- a/google/cloud/firestore_v1/pipeline_result.py +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -14,8 +14,7 @@ from __future__ import annotations from typing import Any, TYPE_CHECKING -from google.cloud.firestore_v1._helpers import decode_dict -from google.cloud.firestore_v1._helpers import decode_value +from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.field_path import get_nested_value from google.cloud.firestore_v1.field_path import FieldPath @@ -78,7 +77,7 @@ def id(self) -> str | None: @property def create_time(self) -> Timestamp | None: - """The creation time of the document. `None` if not applicable (e.g., not a document result or document doesn't exist).""" + """The creation time of the document. `None` if not applicable.""" return self._create_time @property @@ -124,7 +123,7 @@ def data(self) -> Any: if self._fields_pb is None: return None - return decode_dict(self._fields_pb, self._client) + return _helpers.decode_dict(self._fields_pb, self._client) def get(self, field_path: str | FieldPath) -> Any: """ @@ -140,4 +139,4 @@ def get(self, field_path: str | FieldPath) -> Any: field_path if isinstance(field_path, str) else field_path.to_api_repr() ) value = get_nested_value(str_path, self._fields_pb) - return decode_value(value, self._client) + return _helpers.decode_value(value, self._client) diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py new file mode 100644 index 000000000..d4bf538a0 --- /dev/null +++ b/tests/unit/v1/test_pipeline_result.py @@ -0,0 +1,162 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import mock
+import pytest
+import datetime
+
+from google.cloud.firestore_v1.pipeline_result import PipelineResult
+
+
+class TestPipelineResult:
+
+    def _make_one(self, *args, **kwargs):
+        if not args:
+            # use defaults if not passed
+            args = [mock.Mock(), {}]
+        return PipelineResult(*args, **kwargs)
+
+    def test_ref(self):
+        expected = object()
+        instance = self._make_one(ref=expected)
+        assert instance.ref == expected
+        # should be None if not set
+        assert self._make_one().ref == None
+
+    def test_id(self):
+        ref = mock.Mock()
+        ref.id = "test"
+        instance = self._make_one(ref=ref)
+        assert instance.id == "test"
+        # should be None if not set
+        assert self._make_one().id == None
+
+    def test_create_time(self):
+        expected = object()
+        instance = self._make_one(create_time=expected)
+        assert instance.create_time == expected
+        # should be None if not set
+        assert self._make_one().create_time == None
+
+    def test_update_time(self):
+        expected = object()
+        instance = self._make_one(update_time=expected)
+        assert instance.update_time == expected
+        # should be None if not set
+        assert self._make_one().update_time== None
+
+    def test_execution_time(self):
+        expected = object()
+        instance = self._make_one(execution_time=expected)
+        assert instance.execution_time == expected
+        # should raise if not set
+        with pytest.raises(ValueError) as e:
+            self._make_one().execution_time
+        assert "execution_time" in e
+
+    @pytest.mark.parametrize("first,second,result", [
+        ((object(),{}), (object(), {}), True),
+        ((object(),{1:1}), (object(), {1:1}), True),
+        ((object(),{1:1}), (object(), {2:2}), False),
+        ((object(),{}, "ref"), (object(), {}, "ref"), True),
+        ((object(),{}, "ref"), (object(), {}, "diff"), False),
+        ((object(),{1:1}, "ref"), (object(), {1:1}, "ref"), True),
+        ((object(),{1:1}, "ref"), (object(), {2:2}, "ref"), False),
+        ((object(),{1:1}, "ref"), (object(), {1:1}, "diff"), False),
+        ((object(),{1:1}, "ref", 1,2,3), (object(), {1:1}, "ref", 4,5,6), True),
+    ])
+    def test_eq(self, first, second, result):
+        first_obj = self._make_one(*first)
+        second_obj = self._make_one(*second)
+        assert (first_obj == second_obj) is result
+
+    def test_data(self):
+        from google.cloud.firestore_v1.types.document import Value
+        client = mock.Mock()
+        data = {"str": Value(string_value="hello world"), "int": Value(integer_value=5)}
+        instance = self._make_one(client, data)
+        got = instance.data()
+        assert len(got) == 2
+        assert got["str"] == "hello world"
+        assert got["int"] == 5
+
+    def test_data_none(self):
+        client = object()
+        data = None
+        instance = self._make_one(client, data)
+        assert instance.data() is None
+
+    def test_data_call(self):
+        """
+        ensure decode_dict is called on .data
+        """
+        client = object()
+        data = {"hello": "world"}
+        instance = self._make_one(client, data)
+        with mock.patch("google.cloud.firestore_v1._helpers.decode_dict") as decode_mock:
+            got = instance.data()
+            decode_mock.assert_called_once_with(data, client)
+            assert got == decode_mock.return_value
+
+    def test_get(self):
+        from google.cloud.firestore_v1.types.document import Value
+        client = object()
+        data = {"key": Value(string_value="hello world")}
+        instance = self._make_one(client, data)
+        got = instance.get("key")
+        assert got == "hello world"
+
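+    # Nested fields can be addressed with dot-delimited paths: get("first.second")
+    # walks the intermediate maps and decodes only the addressed leaf Value.
+
+    def test_get_nested(self):
+        from google.cloud.firestore_v1.types.document import Value
+        client = object()
+        data = {
+            "first": {"second": Value(string_value="hello world")}
+        }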
+ instance = self._make_one(client, data) + got = instance.get("first.second") + assert got == "hello world" + + def test_get_field_path(self): + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.field_path import FieldPath + client = object() + data = { + "first": {"second": Value(string_value="hello world")} + } + path = FieldPath.from_string("first.second") + instance = self._make_one(client, data) + got = instance.get(path) + assert got == "hello world" + + def test_get_failure(self): + """ + test calling get on value not in data + """ + client = object() + data = {} + instance = self._make_one(client, data) + with pytest.raises(KeyError): + instance.get("key") + + def test_get_call(self): + """ + ensure decode_value is called on .get() + """ + client = object() + data = {"key": "value"} + instance = self._make_one(client, data) + with mock.patch("google.cloud.firestore_v1._helpers.decode_value") as decode_mock: + got = instance.get("key") + decode_mock.assert_called_once_with("value", client) + assert got == decode_mock.return_value \ No newline at end of file From 035c6e6baf3a4ecbc024089c3a8c6cad279219bd Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 12 May 2025 15:11:13 -0700 Subject: [PATCH 119/131] added tests for pipeline source --- google/cloud/firestore_v1/pipeline_source.py | 5 +- tests/unit/v1/test_pipeline_source.py | 59 ++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 tests/unit/v1/test_pipeline_source.py diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index d072c328b..ab9cfada7 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -41,6 +41,9 @@ class PipelineSource(Generic[PipelineType]): def __init__(self, client: Client | AsyncClient): self.client = client + def _create_pipeline(self, source_stage): + return self.client._pipeline_cls(self.client, source_stage) + def collection(self, path: str) -> PipelineType: """ Creates a new Pipeline that operates on a specified Firestore collection. @@ -50,4 +53,4 @@ def collection(self, path: str) -> PipelineType: Returns: a new pipeline instance targeting the specified collection """ - return self.client._pipeline_cls(self.client, stages.Collection(path)) + return self._create_pipeline(stages.Collection(path)) diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py new file mode 100644 index 000000000..0794ed0de --- /dev/null +++ b/tests/unit/v1/test_pipeline_source.py @@ -0,0 +1,59 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +from google.cloud.firestore_v1.pipeline_source import PipelineSource +from google.cloud.firestore_v1.pipeline import Pipeline +from google.cloud.firestore_v1.async_pipeline import AsyncPipeline +from google.cloud.firestore_v1.client import Client +from google.cloud.firestore_v1.async_client import AsyncClient +from google.cloud.firestore_v1 import pipeline_stages as stages + + +class TestPipelineSource: + + _expected_pipeline_type = Pipeline + + def _make_client(self): + return Client() + + def test_make_from_client(self): + instance = self._make_client().pipeline() + assert isinstance(instance, PipelineSource) + + def test_create_pipeline(self): + instance = self._make_client().pipeline() + ppl = instance._create_pipeline(None) + assert isinstance(ppl, self._expected_pipeline_type) + + def test_collection(self): + instance = self._make_client().pipeline() + ppl = instance.collection("path") + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Collection) + assert first_stage.path == "/path" + + +class TestPipelineSourceWithAsyncClient(TestPipelineSource): + """ + When an async client is used, it should produce async pipelines + """ + _expected_pipeline_type = AsyncPipeline + + def _make_client(self): + return AsyncClient() \ No newline at end of file From 5dd12469529e2b23301966fc7876ac5f02abfa81 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 12 May 2025 15:47:49 -0700 Subject: [PATCH 120/131] added transaction to execute call --- google/cloud/firestore_v1/async_pipeline.py | 24 ++++++++++--- google/cloud/firestore_v1/base_pipeline.py | 6 +++- google/cloud/firestore_v1/pipeline.py | 20 +++++++++-- tests/unit/v1/test_async_pipeline.py | 37 +++++++++++++++++++-- tests/unit/v1/test_pipeline.py | 34 +++++++++++++++++-- 5 files changed, 109 insertions(+), 12 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 62f729273..321d9033e 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.pipeline_result import PipelineResult + from google.cloud.firestore_v1.async_transaction import AsyncTransaction class AsyncPipeline(_BasePipeline): @@ -56,9 +57,22 @@ def __init__(self, client: AsyncClient, *stages: stages.Stage): """ super().__init__(client, *stages) - async def execute(self) -> AsyncIterable[PipelineResult]: - async for response in await self._client._firestore_api.execute_pipeline( - self._execute_request_helper() - ): + async def execute( + self, + transaction: "AsyncTransaction" | None=None, + ) -> AsyncIterable[PipelineResult]: + """ + Executes this pipeline, providing results through an Iterable + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). 
+ """ + request = self._prep_execute_request(transaction) + async for response in await self._client._firestore_api.execute_pipeline(request): for result in self._execute_response_helper(response): - yield result + yield result \ No newline at end of file diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 8d3957d88..cb80be3b8 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -20,11 +20,13 @@ ) from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest from google.cloud.firestore_v1.pipeline_result import PipelineResult +from google.cloud.firestore_v1 import _helpers if TYPE_CHECKING: from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse + from google.cloud.firestore_v1.transaction import BaseTransaction class _BasePipeline: @@ -69,15 +71,17 @@ def _append(self, new_stage): """ return self.__class__(self._client, *self.stages, new_stage) - def _execute_request_helper(self) -> ExecutePipelineRequest: + def _prep_execute_request(self, transaction: BaseTransaction | None) -> ExecutePipelineRequest: """ shared logic for creating an ExecutePipelineRequest """ database_name = ( f"projects/{self._client.project}/databases/{self._client._database}" ) + transaction_id = _helpers.get_transaction_id(transaction) if transaction is not None else None request = ExecutePipelineRequest( database=database_name, + transaction=transaction_id, structured_pipeline=self._to_pb(), ) return request diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 40945146e..0c85a57f3 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.pipeline_result import PipelineResult + from google.cloud.firestore_v1.transaction import Transaction class Pipeline(_BasePipeline): @@ -52,6 +53,21 @@ def __init__(self, client: Client, *stages: stages.Stage): """ super().__init__(client, *stages) - def execute(self) -> Iterable[PipelineResult]: - for response in self._client._firestore_api.execute_pipeline(self._execute_request_helper()): + def execute( + self, + transaction: "Transaction" | None=None, + ) -> Iterable[PipelineResult]: + """ + Executes this pipeline, providing results through an Iterable + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). 
+ """ + request = self._prep_execute_request(transaction) + for response in self._client._firestore_api.execute_pipeline(request): yield from self._execute_response_helper(response) \ No newline at end of file diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index ee6f79994..628d37ddd 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -64,8 +64,8 @@ def test_async_pipeline_repr_multiple_stage(): assert repr_str == ( "AsyncPipeline(\n" " Collection(path='/path'),\n" - " GenericStage(params=[2]),\n" - " GenericStage(params=[3])\n" + " GenericStage(name='second'),\n" + " GenericStage(name='third')\n" ")" ) @@ -165,6 +165,7 @@ async def test_async_pipeline_execute_no_doc_ref(): assert isinstance(request, ExecutePipelineRequest) assert request.structured_pipeline == ppl_1._to_pb() assert request.database == "projects/A/databases/B" + assert request.transaction == b'' response = results[0] assert isinstance(response, PipelineResult) @@ -284,3 +285,35 @@ async def test_async_pipeline_execute_multiple(): for idx, response in enumerate(results): assert isinstance(response, PipelineResult) assert response.data() == {"key": idx} + + +@pytest.mark.asyncio +async def test_async_pipeline_execute_with_transaction(): + """ + test execute pipeline with transaction context + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + transaction = AsyncTransaction(client) + transaction._id = b"123" + + mock_rpc.return_value = _async_it([ + ExecutePipelineResponse() + ]) + ppl_1 = _make_async_pipeline(client=client) + + [r async for r in ppl_1.execute(transaction=transaction)] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"123" \ No newline at end of file diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index b8933b2c4..8bb4d6330 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -58,8 +58,8 @@ def test_pipeline_repr_multiple_stage(): assert repr_str == ( "Pipeline(\n" " Collection(path='/path'),\n" - " GenericStage(params=[2]),\n" - " GenericStage(params=[3])\n" + " GenericStage(name='second'),\n" + " GenericStage(name='third')\n" ")" ) @@ -207,6 +207,7 @@ def test_pipeline_execute_populated(): assert isinstance(request, ExecutePipelineRequest) assert request.structured_pipeline == ppl_1._to_pb() assert request.database == "projects/A/databases/B" + assert request.transaction == b"" response = results[0] assert isinstance(response, PipelineResult) @@ -266,3 +267,32 @@ def test_pipeline_execute_multiple(): for idx, response in enumerate(results): assert isinstance(response, PipelineResult) assert response.data() == {"key": idx} + +def test_pipeline_execute_with_transaction(): + """ + test execute pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.transaction import Transaction 
+ + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + + transaction = Transaction(client) + transaction._id = b"123" + + mock_rpc.return_value = [ + ExecutePipelineResponse() + ] + ppl_1 = _make_pipeline(client=client) + + list(ppl_1.execute(transaction=transaction)) + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"123" \ No newline at end of file From 6434023e34565ea7f3168114bed81806c004f1a8 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 12 May 2025 15:49:51 -0700 Subject: [PATCH 121/131] ran blacken --- google/cloud/firestore_v1/async_pipeline.py | 8 +- google/cloud/firestore_v1/base_pipeline.py | 16 ++- google/cloud/firestore_v1/pipeline.py | 4 +- google/cloud/firestore_v1/pipeline_stages.py | 2 +- tests/unit/v1/test_async_pipeline.py | 8 +- tests/unit/v1/test_pipeline.py | 7 +- tests/unit/v1/test_pipeline_expressions.py | 109 +++++++++++-------- tests/unit/v1/test_pipeline_result.py | 54 +++++---- tests/unit/v1/test_pipeline_source.py | 4 +- 9 files changed, 123 insertions(+), 89 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 321d9033e..edc41d94d 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -59,7 +59,7 @@ def __init__(self, client: AsyncClient, *stages: stages.Stage): async def execute( self, - transaction: "AsyncTransaction" | None=None, + transaction: "AsyncTransaction" | None = None, ) -> AsyncIterable[PipelineResult]: """ Executes this pipeline, providing results through an Iterable @@ -73,6 +73,8 @@ async def execute( allowed). 
""" request = self._prep_execute_request(transaction) - async for response in await self._client._firestore_api.execute_pipeline(request): + async for response in await self._client._firestore_api.execute_pipeline( + request + ): for result in self._execute_response_helper(response): - yield result \ No newline at end of file + yield result diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index cb80be3b8..7f535caac 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -71,14 +71,20 @@ def _append(self, new_stage): """ return self.__class__(self._client, *self.stages, new_stage) - def _prep_execute_request(self, transaction: BaseTransaction | None) -> ExecutePipelineRequest: + def _prep_execute_request( + self, transaction: BaseTransaction | None + ) -> ExecutePipelineRequest: """ shared logic for creating an ExecutePipelineRequest """ database_name = ( f"projects/{self._client.project}/databases/{self._client._database}" ) - transaction_id = _helpers.get_transaction_id(transaction) if transaction is not None else None + transaction_id = ( + _helpers.get_transaction_id(transaction) + if transaction is not None + else None + ) request = ExecutePipelineRequest( database=database_name, transaction=transaction_id, @@ -86,7 +92,9 @@ def _prep_execute_request(self, transaction: BaseTransaction | None) -> ExecuteP ) return request - def _execute_response_helper(self, response:ExecutePipelineResponse) -> Iterable[PipelineResult]: + def _execute_response_helper( + self, response: ExecutePipelineResponse + ) -> Iterable[PipelineResult]: """ shared logic for unpacking an ExecutePipelineReponse into PipelineResults """ @@ -99,4 +107,4 @@ def _execute_response_helper(self, response:ExecutePipelineResponse) -> Iterable response._pb.execution_time, doc.create_time.timestamp_pb() if doc.create_time else None, doc.update_time.timestamp_pb() if doc.update_time else None, - ) \ No newline at end of file + ) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 0c85a57f3..91207c8c6 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -55,7 +55,7 @@ def __init__(self, client: Client, *stages: stages.Stage): def execute( self, - transaction: "Transaction" | None=None, + transaction: "Transaction" | None = None, ) -> Iterable[PipelineResult]: """ Executes this pipeline, providing results through an Iterable @@ -70,4 +70,4 @@ def execute( """ request = self._prep_execute_request(transaction) for response in self._client._firestore_api.execute_pipeline(request): - yield from self._execute_response_helper(response) \ No newline at end of file + yield from self._execute_response_helper(response) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 840feca5e..eec3cafb5 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -81,4 +81,4 @@ def _pb_args(self): return self.params def __repr__(self): - return f"{self.__class__.__name__}(name='{self.name}')" \ No newline at end of file + return f"{self.__class__.__name__}(name='{self.name}')" diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index 628d37ddd..c40708ea8 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -165,7 +165,7 @@ async def test_async_pipeline_execute_no_doc_ref(): 
assert isinstance(request, ExecutePipelineRequest) assert request.structured_pipeline == ppl_1._to_pb() assert request.database == "projects/A/databases/B" - assert request.transaction == b'' + assert request.transaction == b"" response = results[0] assert isinstance(response, PipelineResult) @@ -305,9 +305,7 @@ async def test_async_pipeline_execute_with_transaction(): transaction = AsyncTransaction(client) transaction._id = b"123" - mock_rpc.return_value = _async_it([ - ExecutePipelineResponse() - ]) + mock_rpc.return_value = _async_it([ExecutePipelineResponse()]) ppl_1 = _make_async_pipeline(client=client) [r async for r in ppl_1.execute(transaction=transaction)] @@ -316,4 +314,4 @@ async def test_async_pipeline_execute_with_transaction(): assert isinstance(request, ExecutePipelineRequest) assert request.structured_pipeline == ppl_1._to_pb() assert request.database == "projects/A/databases/B" - assert request.transaction == b"123" \ No newline at end of file + assert request.transaction == b"123" diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index 8bb4d6330..9a8ccd2a0 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -268,6 +268,7 @@ def test_pipeline_execute_multiple(): assert isinstance(response, PipelineResult) assert response.data() == {"key": idx} + def test_pipeline_execute_with_transaction(): """ test execute pipeline with fully populated doc ref @@ -284,9 +285,7 @@ def test_pipeline_execute_with_transaction(): transaction = Transaction(client) transaction._id = b"123" - mock_rpc.return_value = [ - ExecutePipelineResponse() - ] + mock_rpc.return_value = [ExecutePipelineResponse()] ppl_1 = _make_pipeline(client=client) list(ppl_1.execute(transaction=transaction)) @@ -295,4 +294,4 @@ def test_pipeline_execute_with_transaction(): assert isinstance(request, ExecutePipelineRequest) assert request.structured_pipeline == ppl_1._to_pb() assert request.database == "projects/A/databases/B" - assert request.transaction == b"123" \ No newline at end of file + assert request.transaction == b"123" diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 2be7b5eb5..cba339ef1 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -32,56 +32,73 @@ def test_ctor(self): class TestConstant: - @pytest.mark.parametrize("input_val, to_pb_val",[ - ("test", Value(string_value="test")), - ("", Value(string_value="")), - (10, Value(integer_value=10)), - (0, Value(integer_value=0)), - (10.0, Value(double_value=10)), - (0.0, Value(double_value=0)), - (True, Value(boolean_value=True)), - (b"test", Value(bytes_value=b"test")), - (None, Value(null_value=0)), - ( - datetime.datetime(2025, 5, 12), - Value(timestamp_value={"seconds": 1747008000}), - ), - (GeoPoint(1, 2), Value(geo_point_value={"latitude": 1, "longitude": 2})), - ( - [0.0, 1.0, 2.0], - Value(array_value={"values": [Value(double_value=i) for i in range(3)]}), - ), - ({"a": "b"}, Value(map_value={"fields": {"a": Value(string_value="b")}})), - ( - Vector([1.0, 2.0]), - Value(map_value={"fields": { - "__type__": Value(string_value="__vector__"), - "value": Value(array_value={ - "values": [Value(double_value=v) for v in [1,2]], - }), - }}), - ), - ]) + @pytest.mark.parametrize( + "input_val, to_pb_val", + [ + ("test", Value(string_value="test")), + ("", Value(string_value="")), + (10, Value(integer_value=10)), + (0, Value(integer_value=0)), + (10.0, Value(double_value=10)), + (0.0, 
Value(double_value=0)), + (True, Value(boolean_value=True)), + (b"test", Value(bytes_value=b"test")), + (None, Value(null_value=0)), + ( + datetime.datetime(2025, 5, 12), + Value(timestamp_value={"seconds": 1747008000}), + ), + (GeoPoint(1, 2), Value(geo_point_value={"latitude": 1, "longitude": 2})), + ( + [0.0, 1.0, 2.0], + Value( + array_value={"values": [Value(double_value=i) for i in range(3)]} + ), + ), + ({"a": "b"}, Value(map_value={"fields": {"a": Value(string_value="b")}})), + ( + Vector([1.0, 2.0]), + Value( + map_value={ + "fields": { + "__type__": Value(string_value="__vector__"), + "value": Value( + array_value={ + "values": [Value(double_value=v) for v in [1, 2]], + } + ), + } + } + ), + ), + ], + ) def test_to_pb(self, input_val, to_pb_val): instance = expressions.Constant.of(input_val) assert instance._to_pb() == to_pb_val - @pytest.mark.parametrize("input_val,expected", [ - ("test", "Constant.of('test')"), - ("", "Constant.of('')"), - (10, "Constant.of(10)"), - (0, "Constant.of(0)"), - (10.0, "Constant.of(10.0)"), - (0.0, "Constant.of(0.0)"), - (True, "Constant.of(True)"), - (b"test", "Constant.of(b'test')"), - (None, "Constant.of(None)"), - (datetime.datetime(2025, 5, 12), "Constant.of(datetime.datetime(2025, 5, 12, 0, 0))"), - (GeoPoint(1, 2), "Constant.of(GeoPoint(latitude=1, longitude=2))"), - ([1, 2, 3], "Constant.of([1, 2, 3])"), - ({"a": "b"}, "Constant.of({'a': 'b'})"), - (Vector([1.0, 2.0]), "Constant.of(Vector<1.0, 2.0>)"), - ]) + @pytest.mark.parametrize( + "input_val,expected", + [ + ("test", "Constant.of('test')"), + ("", "Constant.of('')"), + (10, "Constant.of(10)"), + (0, "Constant.of(0)"), + (10.0, "Constant.of(10.0)"), + (0.0, "Constant.of(0.0)"), + (True, "Constant.of(True)"), + (b"test", "Constant.of(b'test')"), + (None, "Constant.of(None)"), + ( + datetime.datetime(2025, 5, 12), + "Constant.of(datetime.datetime(2025, 5, 12, 0, 0))", + ), + (GeoPoint(1, 2), "Constant.of(GeoPoint(latitude=1, longitude=2))"), + ([1, 2, 3], "Constant.of([1, 2, 3])"), + ({"a": "b"}, "Constant.of({'a': 'b'})"), + (Vector([1.0, 2.0]), "Constant.of(Vector<1.0, 2.0>)"), + ], + ) def test_repr(self, input_val, expected): instance = expressions.Constant.of(input_val) repr_string = repr(instance) diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py index d4bf538a0..33fa57ea9 100644 --- a/tests/unit/v1/test_pipeline_result.py +++ b/tests/unit/v1/test_pipeline_result.py @@ -20,7 +20,6 @@ class TestPipelineResult: - def _make_one(self, *args, **kwargs): if not args: # use defaults if not passed @@ -54,7 +53,7 @@ def test_update_time(self): instance = self._make_one(update_time=expected) assert instance.update_time == expected # should be None if not set - assert self._make_one().update_time== None + assert self._make_one().update_time == None def test_exection_time(self): expected = object() @@ -65,17 +64,24 @@ def test_exection_time(self): self._make_one().execution_time assert "execution_time" in e - @pytest.mark.parametrize("first,second,result", [ - ((object(),{}), (object(), {}), True), - ((object(),{1:1}), (object(), {1:1}), True), - ((object(),{1:1}), (object(), {2:2}), False), - ((object(),{}, "ref"), (object(), {}, "ref"), True), - ((object(),{}, "ref"), (object(), {}, "diff"), False), - ((object(),{1:1}, "ref"), (object(), {1:1}, "ref"), True), - ((object(),{1:1}, "ref"), (object(), {2:2}, "ref"), False), - ((object(),{1:1}, "ref"), (object(), {1:1}, "diff"), False), - ((object(),{1:1}, "ref", 1,2,3), (object(), {1:1}, "ref", 4,5,6), 
True), - ]) + @pytest.mark.parametrize( + "first,second,result", + [ + ((object(), {}), (object(), {}), True), + ((object(), {1: 1}), (object(), {1: 1}), True), + ((object(), {1: 1}), (object(), {2: 2}), False), + ((object(), {}, "ref"), (object(), {}, "ref"), True), + ((object(), {}, "ref"), (object(), {}, "diff"), False), + ((object(), {1: 1}, "ref"), (object(), {1: 1}, "ref"), True), + ((object(), {1: 1}, "ref"), (object(), {2: 2}, "ref"), False), + ((object(), {1: 1}, "ref"), (object(), {1: 1}, "diff"), False), + ( + (object(), {1: 1}, "ref", 1, 2, 3), + (object(), {1: 1}, "ref", 4, 5, 6), + True, + ), + ], + ) def test_eq(self, first, second, result): first_obj = self._make_one(*first) second_obj = self._make_one(*second) @@ -83,6 +89,7 @@ def test_eq(self, first, second, result): def test_data(self): from google.cloud.firestore_v1.types.document import Value + client = mock.Mock() data = {"str": Value(string_value="hello world"), "int": Value(integer_value=5)} instance = self._make_one(client, data) @@ -104,13 +111,16 @@ def test_data_call(self): client = object() data = {"hello": "world"} instance = self._make_one(client, data) - with mock.patch("google.cloud.firestore_v1._helpers.decode_dict") as decode_mock: + with mock.patch( + "google.cloud.firestore_v1._helpers.decode_dict" + ) as decode_mock: got = instance.data() decode_mock.assert_called_once_with(data, client) assert got == decode_mock.return_value def test_get(self): from google.cloud.firestore_v1.types.document import Value + client = object() data = {"key": Value(string_value="hello world")} instance = self._make_one(client, data) @@ -119,10 +129,9 @@ def test_get(self): def test_get_nested(self): from google.cloud.firestore_v1.types.document import Value + client = object() - data = { - "first": {"second": Value(string_value="hello world")} - } + data = {"first": {"second": Value(string_value="hello world")}} instance = self._make_one(client, data) got = instance.get("first.second") assert got == "hello world" @@ -130,10 +139,9 @@ def test_get_nested(self): def test_get_field_path(self): from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.field_path import FieldPath + client = object() - data = { - "first": {"second": Value(string_value="hello world")} - } + data = {"first": {"second": Value(string_value="hello world")}} path = FieldPath.from_string("first.second") instance = self._make_one(client, data) got = instance.get(path) @@ -156,7 +164,9 @@ def test_get_call(self): client = object() data = {"key": "value"} instance = self._make_one(client, data) - with mock.patch("google.cloud.firestore_v1._helpers.decode_value") as decode_mock: + with mock.patch( + "google.cloud.firestore_v1._helpers.decode_value" + ) as decode_mock: got = instance.get("key") decode_mock.assert_called_once_with("value", client) - assert got == decode_mock.return_value \ No newline at end of file + assert got == decode_mock.return_value diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py index 0794ed0de..ee892e60c 100644 --- a/tests/unit/v1/test_pipeline_source.py +++ b/tests/unit/v1/test_pipeline_source.py @@ -24,7 +24,6 @@ class TestPipelineSource: - _expected_pipeline_type = Pipeline def _make_client(self): @@ -53,7 +52,8 @@ class TestPipelineSourceWithAsyncClient(TestPipelineSource): """ When an async client is used, it should produce async pipelines """ + _expected_pipeline_type = AsyncPipeline def _make_client(self): - return AsyncClient() \ No newline at end of 
file + return AsyncClient() From a8beea440a0c64e4ec41fb735512cfb42378f985 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 12 May 2025 15:51:52 -0700 Subject: [PATCH 122/131] fixed lint --- google/cloud/firestore_v1/async_pipeline.py | 1 - google/cloud/firestore_v1/pipeline_stages.py | 5 +---- tests/unit/v1/test_pipeline_expressions.py | 1 - tests/unit/v1/test_pipeline_result.py | 9 ++++----- tests/unit/v1/test_pipeline_source.py | 3 --- tests/unit/v1/test_pipeline_stages.py | 1 - 6 files changed, 5 insertions(+), 15 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index edc41d94d..8619e5aad 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -import datetime from typing import AsyncIterable, TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index eec3cafb5..3871a363d 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import Optional, TYPE_CHECKING +from typing import Optional from abc import ABC from abc import abstractmethod @@ -21,9 +21,6 @@ from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.pipeline_expressions import Expr -if TYPE_CHECKING: - from google.cloud.firestore_v1.base_document import BaseDocumentReference - class Stage(ABC): """Base class for all pipeline stages. 
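For orientation between these file diffs: a `Stage` is the unit the pipelines serialize, and `GenericStage` is the catch-all implementation the unit tests exercise. Below is a minimal sketch of that flow, assuming a configured `Client` and an illustrative "books" collection; `GenericStage`, `_append`, and `_to_pb` are the internal names from this patch series (the module is renamed to `_pipeline_stages` in a later commit), and everything else is for illustration only:

    from google.cloud.firestore_v1.client import Client
    from google.cloud.firestore_v1.pipeline_stages import GenericStage

    client = Client()  # assumes ambient project/credentials
    ppl = client.pipeline().collection("books")  # Pipeline holding one Collection stage
    ppl = ppl._append(GenericStage("custom"))    # _append returns a new Pipeline object

    pb = ppl._to_pb()                  # StructuredPipeline proto
    print(pb.pipeline.stages[1].name)  # "custom", taken from the GenericStage name

Each stage contributes its name and `_pb_args` to the serialized form, which is why the `__repr__` above only needs the stage name.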
diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index cba339ef1..19ebed3b5 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License -import mock import pytest import datetime diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py index 33fa57ea9..7a80191df 100644 --- a/tests/unit/v1/test_pipeline_result.py +++ b/tests/unit/v1/test_pipeline_result.py @@ -14,7 +14,6 @@ import mock import pytest -import datetime from google.cloud.firestore_v1.pipeline_result import PipelineResult @@ -31,7 +30,7 @@ def test_ref(self): instance = self._make_one(ref=expected) assert instance.ref == expected # should be None if not set - assert self._make_one().ref == None + assert self._make_one().ref is None def test_id(self): ref = mock.Mock() @@ -39,21 +38,21 @@ def test_id(self): instance = self._make_one(ref=ref) assert instance.id == "test" # should be None if not set - assert self._make_one().id == None + assert self._make_one().id is None def test_create_time(self): expected = object() instance = self._make_one(create_time=expected) assert instance.create_time == expected # should be None if not set - assert self._make_one().create_time == None + assert self._make_one().create_time is None def test_update_time(self): expected = object() instance = self._make_one(update_time=expected) assert instance.update_time == expected # should be None if not set - assert self._make_one().update_time == None + assert self._make_one().update_time is None def test_exection_time(self): expected = object() diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py index ee892e60c..6eedfed36 100644 --- a/tests/unit/v1/test_pipeline_source.py +++ b/tests/unit/v1/test_pipeline_source.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License -import mock -import pytest - from google.cloud.firestore_v1.pipeline_source import PipelineSource from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1.async_pipeline import AsyncPipeline diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index d0568dcf5..549c0f917 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License -import mock import pytest import google.cloud.firestore_v1.pipeline_stages as stages From 2d286bb84edcc04e20afc3302ecc8e37d6bf5d32 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 12 May 2025 16:25:24 -0700 Subject: [PATCH 123/131] fixed mypy --- google/cloud/firestore_v1/base_pipeline.py | 4 ++-- google/cloud/firestore_v1/field_path.py | 4 ++-- google/cloud/firestore_v1/pipeline_result.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 7f535caac..0c46d92d5 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -105,6 +105,6 @@ def _execute_response_helper( doc.fields, ref, response._pb.execution_time, - doc.create_time.timestamp_pb() if doc.create_time else None, - doc.update_time.timestamp_pb() if doc.update_time else 
None, + doc._pb.create_time if doc.create_time else None, + doc._pb.update_time if doc.update_time else None, ) diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py index 048eb64d0..488ef690a 100644 --- a/google/cloud/firestore_v1/field_path.py +++ b/google/cloud/firestore_v1/field_path.py @@ -16,7 +16,7 @@ from __future__ import annotations import re from collections import abc -from typing import Iterable, cast +from typing import Any, Iterable, cast, MutableMapping _FIELD_PATH_MISSING_TOP = "{!r} is not contained in the data" _FIELD_PATH_MISSING_KEY = "{!r} is not contained in the data for the key {!r}" @@ -170,7 +170,7 @@ def render_field_path(field_names: Iterable[str]): get_field_path = render_field_path # backward-compatibility -def get_nested_value(field_path: str, data: dict): +def get_nested_value(field_path: str, data: MutableMapping[str, Any]): """Get a (potentially nested) value from a dictionary. If the data is nested, for example: diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py index 54bfc152f..475106b7a 100644 --- a/google/cloud/firestore_v1/pipeline_result.py +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any, MutableMapping, TYPE_CHECKING from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.field_path import get_nested_value from google.cloud.firestore_v1.field_path import FieldPath @@ -36,7 +36,7 @@ class PipelineResult: def __init__( self, client: BaseClient, - fields_pb: dict[str, ValueProto], + fields_pb: MutableMapping[str, ValueProto], ref: BaseDocumentReference | None = None, execution_time: Timestamp | None = None, create_time: Timestamp | None = None, From b46bdc17d9aea72daa36093d982bb875e1776d22 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 13 May 2025 15:02:46 -0700 Subject: [PATCH 124/131] fixed test issues --- google/cloud/firestore_v1/async_pipeline.py | 2 +- google/cloud/firestore_v1/base_pipeline.py | 2 +- google/cloud/firestore_v1/pipeline.py | 2 +- google/cloud/firestore_v1/pipeline_expressions.py | 4 ---- google/cloud/firestore_v1/pipeline_result.py | 2 +- google/cloud/firestore_v1/pipeline_source.py | 2 +- tests/unit/v1/test_async_client.py | 4 ++-- tests/unit/v1/test_pipeline_result.py | 5 +++++ 8 files changed, 12 insertions(+), 11 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 8619e5aad..57c4f324f 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -17,7 +17,7 @@ from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.async_transaction import AsyncTransaction diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 0c46d92d5..851ad94ad 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -22,7 +22,7 @@ from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1 import _helpers -if TYPE_CHECKING: +if 
TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 91207c8c6..862a52815 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -17,7 +17,7 @@ from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.transaction import Transaction diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 0219a29c6..5e0c775a2 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -66,10 +66,6 @@ def __repr__(self): def _to_pb(self) -> Value: raise NotImplementedError - @staticmethod - def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": - return o if isinstance(o, Expr) else Constant(o) - class Constant(Expr, Generic[CONSTANT_TYPE]): """Represents a constant literal value in an expression.""" diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py index 475106b7a..84eeccecd 100644 --- a/google/cloud/firestore_v1/pipeline_result.py +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -18,7 +18,7 @@ from google.cloud.firestore_v1.field_path import get_nested_value from google.cloud.firestore_v1.field_path import FieldPath -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.base_client import BaseClient from google.cloud.firestore_v1.base_document import BaseDocumentReference from google.protobuf.timestamp_pb2 import Timestamp diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index ab9cfada7..507282881 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -17,7 +17,7 @@ from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index b4490fb69..80a17a763 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -533,11 +533,11 @@ def test_asyncclient_transaction(): assert transaction._id is None -def test_asyncclient_pipeline(database): +def test_asyncclient_pipeline(): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline from google.cloud.firestore_v1.pipeline_source import PipelineSource - client = _make_default_async_client(database=database) + client = _make_default_async_client() ppl = client.pipeline() assert client._pipeline_cls == AsyncPipeline assert isinstance(ppl, PipelineSource) diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py index 7a80191df..2facf7110 100644 --- a/tests/unit/v1/test_pipeline_result.py +++ b/tests/unit/v1/test_pipeline_result.py @@ -86,6 +86,11 @@ def 
test_eq(self, first, second, result): second_obj = self._make_one(*second) assert (first_obj == second_obj) is result + def test_eq_wrong_type(self): + instance = self._make_one() + result = instance == object() + assert result is False + def test_data(self): from google.cloud.firestore_v1.types.document import Value From 343232251f2b895df160340a63ab0e274c7ac054 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 10 Jun 2025 17:03:09 -0700 Subject: [PATCH 125/131] fixed lint --- google/cloud/firestore_v1/async_client.py | 2 +- google/cloud/firestore_v1/client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index b7b703426..3acbedc76 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -435,4 +435,4 @@ def _pipeline_cls(self): return AsyncPipeline def pipeline(self) -> PipelineSource: - return PipelineSource(self) \ No newline at end of file + return PipelineSource(self) diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 2b69980e1..c23943b24 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -416,4 +416,4 @@ def _pipeline_cls(self): return Pipeline def pipeline(self) -> PipelineSource: - return PipelineSource(self) \ No newline at end of file + return PipelineSource(self) From 64cd4fbc18cb3f4a033366134ee83087536af366 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Jun 2025 11:22:25 -0700 Subject: [PATCH 126/131] fixed comment --- google/cloud/firestore_v1/async_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 57c4f324f..ed8f7e897 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -40,7 +40,7 @@ class AsyncPipeline(_BasePipeline): ... .collection("books") ... .where(Field.of("published").gt(1980)) ... .select("title", "author") - ... async for result in pipeline.execute_async(): + ... async for result in pipeline.execute(): ... print(result) Use `client.pipeline()` to create instances of this class. From e74e04dde8e22636c458ee52d894928fbb3002de Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 16 Jun 2025 14:22:13 -0700 Subject: [PATCH 127/131] added separate stream/execute methods --- google/cloud/firestore_v1/async_pipeline.py | 19 ++++- google/cloud/firestore_v1/pipeline.py | 19 ++++- tests/unit/v1/test_async_pipeline.py | 87 +++++++++++++++++---- tests/unit/v1/test_pipeline.py | 82 +++++++++++++++---- 4 files changed, 175 insertions(+), 32 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index ed8f7e897..9fe0c8756 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -59,9 +59,26 @@ def __init__(self, client: AsyncClient, *stages: stages.Stage): async def execute( self, transaction: "AsyncTransaction" | None = None, + ) -> list[PipelineResult]: + """ + Executes this pipeline and returns results as a list + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). 
+ """ + return [result async for result in self.stream(transaction=transaction)] + + async def stream( + self, + transaction: "AsyncTransaction" | None = None, ) -> AsyncIterable[PipelineResult]: """ - Executes this pipeline, providing results through an Iterable + Process this pipeline as a stream, providing results through an Iterable Args: transaction diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 862a52815..f578e00b6 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -56,9 +56,26 @@ def __init__(self, client: Client, *stages: stages.Stage): def execute( self, transaction: "Transaction" | None = None, + ) -> list[PipelineResult]: + """ + Executes this pipeline and returns results as a list + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + """ + return [result for result in self.stream(transaction=transaction)] + + def stream( + self, + transaction: "Transaction" | None = None, ) -> Iterable[PipelineResult]: """ - Executes this pipeline, providing results through an Iterable + Process this pipeline as a stream, providing results through an Iterable Args: transaction diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index c40708ea8..57f125b0e 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -112,9 +112,9 @@ def test_async_pipeline_append(): @pytest.mark.asyncio -async def test_async_pipeline_execute_empty(): +async def test_async_pipeline_stream_empty(): """ - test execute pipeline with mocked empty response + test stream pipeline with mocked empty response """ from google.cloud.firestore_v1.types import ExecutePipelineResponse from google.cloud.firestore_v1.types import ExecutePipelineRequest @@ -128,7 +128,7 @@ async def test_async_pipeline_execute_empty(): mock_rpc.return_value = _async_it([ExecutePipelineResponse()]) ppl_1 = _make_async_pipeline(GenericStage("s"), client=client) - results = [r async for r in ppl_1.execute()] + results = [r async for r in ppl_1.stream()] assert results == [] assert mock_rpc.call_count == 1 request = mock_rpc.call_args[0][0] @@ -138,9 +138,9 @@ async def test_async_pipeline_execute_empty(): @pytest.mark.asyncio -async def test_async_pipeline_execute_no_doc_ref(): +async def test_async_pipeline_stream_no_doc_ref(): """ - test execute pipeline with no doc ref + test stream pipeline with no doc ref """ from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1.types import ExecutePipelineResponse @@ -158,7 +158,7 @@ async def test_async_pipeline_execute_no_doc_ref(): ) ppl_1 = _make_async_pipeline(GenericStage("s"), client=client) - results = [r async for r in ppl_1.execute()] + results = [r async for r in ppl_1.stream()] assert len(results) == 1 assert mock_rpc.call_count == 1 request = mock_rpc.call_args[0][0] @@ -178,9 +178,9 @@ async def test_async_pipeline_execute_no_doc_ref(): @pytest.mark.asyncio -async def test_async_pipeline_execute_populated(): +async def test_async_pipeline_stream_populated(): """ - test execute pipeline with fully populated doc ref + test stream pipeline with fully populated doc ref """ from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1.types 
import ExecutePipelineResponse
@@ -215,7 +215,7 @@ async def test_async_pipeline_stream_populated():
     )
     ppl_1 = _make_async_pipeline(client=client)
 
-    results = [r async for r in ppl_1.execute()]
+    results = [r async for r in ppl_1.stream()]
     assert len(results) == 1
     assert mock_rpc.call_count == 1
     request = mock_rpc.call_args[0][0]
@@ -235,9 +235,9 @@ async def test_async_pipeline_execute_populated():
 
 
 @pytest.mark.asyncio
-async def test_async_pipeline_execute_multiple():
+async def test_async_pipeline_stream_multiple():
     """
-    test execute pipeline with multiple docs and responses
+    test stream pipeline with multiple docs and responses
     """
     from google.cloud.firestore_v1.types import Document
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
@@ -274,7 +274,7 @@ async def test_async_pipeline_stream_multiple():
     )
     ppl_1 = _make_async_pipeline(client=client)
 
-    results = [r async for r in ppl_1.execute()]
+    results = [r async for r in ppl_1.stream()]
     assert len(results) == 4
     assert mock_rpc.call_count == 1
     request = mock_rpc.call_args[0][0]
@@ -288,9 +288,9 @@ async def test_async_pipeline_execute_multiple():
 
 
 @pytest.mark.asyncio
-async def test_async_pipeline_execute_with_transaction():
+async def test_async_pipeline_stream_with_transaction():
     """
-    test execute pipeline with transaction context
+    test stream pipeline with transaction context
     """
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import ExecutePipelineRequest
@@ -308,10 +308,67 @@ async def test_async_pipeline_stream_with_transaction():
 
     mock_rpc.return_value = _async_it([ExecutePipelineResponse()])
     ppl_1 = _make_async_pipeline(client=client)
 
-    [r async for r in ppl_1.execute(transaction=transaction)]
+    [r async for r in ppl_1.stream(transaction=transaction)]
     assert mock_rpc.call_count == 1
     request = mock_rpc.call_args[0][0]
     assert isinstance(request, ExecutePipelineRequest)
     assert request.structured_pipeline == ppl_1._to_pb()
     assert request.database == "projects/A/databases/B"
     assert request.transaction == b"123"
+
+
+@pytest.mark.asyncio
+async def test_async_pipeline_execute_stream_equivalence():
+    """
+    Pipeline.execute should provide the same results as pipeline.stream, as a list
+    """
+    from google.cloud.firestore_v1.types import Document
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import Value
+    from google.cloud.firestore_v1.client import Client
+
+    real_client = Client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = mock.AsyncMock()
+    client._firestore_api.execute_pipeline = mock_rpc
+    mock_response = [
+        ExecutePipelineResponse(
+            results=[
+                Document(
+                    name="test/my_doc",
+                    fields={"key": Value(string_value="str_val")},
+                )
+            ],
+        )
+    ]
+    mock_rpc.return_value = _async_it(mock_response)
+    ppl_1 = _make_async_pipeline(client=client)
+
+    stream_results = [r async for r in ppl_1.stream()]
+    # reset response
+    mock_rpc.return_value = _async_it(mock_response)
+    execute_results = await ppl_1.execute()
+    assert stream_results == execute_results
+    assert stream_results[0].data()["key"] == "str_val"
+    assert execute_results[0].data()["key"] == "str_val"
+
+
+@pytest.mark.asyncio
+async def test_async_pipeline_execute_stream_equivalence_mocked():
+    """
+    pipeline.execute should call pipeline.stream internally
+    """
+    ppl_1 = _make_async_pipeline()
+    expected_data = [object(), object()]
+    expected_arg = object()
+    with 
mock.patch.object(ppl_1, "stream") as mock_stream: + mock_stream.return_value = _async_it(expected_data) + stream_results = await ppl_1.execute(expected_arg) + assert mock_stream.call_count == 1 + assert mock_stream.call_args[0] == () + assert len(mock_stream.call_args[1]) == 1 + assert mock_stream.call_args[1]["transaction"] == expected_arg + assert stream_results == expected_data diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index 9a8ccd2a0..55f0a1145 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -105,9 +105,9 @@ def test_pipeline_append(): assert isinstance(ppl_2, type(ppl_1)) -def test_pipeline_execute_empty(): +def test_pipeline_stream_empty(): """ - test execute pipeline with mocked empty response + test stream pipeline with mocked empty response """ from google.cloud.firestore_v1.types import ExecutePipelineResponse from google.cloud.firestore_v1.types import ExecutePipelineRequest @@ -120,7 +120,7 @@ def test_pipeline_execute_empty(): mock_rpc.return_value = [ExecutePipelineResponse()] ppl_1 = _make_pipeline(GenericStage("s"), client=client) - results = list(ppl_1.execute()) + results = list(ppl_1.stream()) assert results == [] assert mock_rpc.call_count == 1 request = mock_rpc.call_args[0][0] @@ -129,9 +129,9 @@ def test_pipeline_execute_empty(): assert request.database == "projects/A/databases/B" -def test_pipeline_execute_no_doc_ref(): +def test_pipeline_stream_no_doc_ref(): """ - test execute pipeline with no doc ref + test stream pipeline with no doc ref """ from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1.types import ExecutePipelineResponse @@ -148,7 +148,7 @@ def test_pipeline_execute_no_doc_ref(): ] ppl_1 = _make_pipeline(GenericStage("s"), client=client) - results = list(ppl_1.execute()) + results = list(ppl_1.stream()) assert len(results) == 1 assert mock_rpc.call_count == 1 request = mock_rpc.call_args[0][0] @@ -166,9 +166,9 @@ def test_pipeline_execute_no_doc_ref(): assert response.data() == {} -def test_pipeline_execute_populated(): +def test_pipeline_stream_populated(): """ - test execute pipeline with fully populated doc ref + test stream pipeline with fully populated doc ref """ from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1.types import ExecutePipelineResponse @@ -200,7 +200,7 @@ def test_pipeline_execute_populated(): ] ppl_1 = _make_pipeline(client=client) - results = list(ppl_1.execute()) + results = list(ppl_1.stream()) assert len(results) == 1 assert mock_rpc.call_count == 1 request = mock_rpc.call_args[0][0] @@ -220,9 +220,9 @@ def test_pipeline_execute_populated(): assert response.data() == {"key": "str_val"} -def test_pipeline_execute_multiple(): +def test_pipeline_stream_multiple(): """ - test execute pipeline with multiple docs and responses + test stream pipeline with multiple docs and responses """ from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1.types import ExecutePipelineResponse @@ -256,7 +256,7 @@ def test_pipeline_execute_multiple(): ] ppl_1 = _make_pipeline(client=client) - results = list(ppl_1.execute()) + results = list(ppl_1.stream()) assert len(results) == 4 assert mock_rpc.call_count == 1 request = mock_rpc.call_args[0][0] @@ -269,9 +269,9 @@ def test_pipeline_execute_multiple(): assert response.data() == {"key": idx} -def test_pipeline_execute_with_transaction(): +def test_pipeline_stream_with_transaction(): """ - test execute pipeline with fully populated doc ref + 
    test stream pipeline with transaction context
     """
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import ExecutePipelineRequest
@@ -288,10 +288,62 @@ def test_pipeline_stream_with_transaction():
     mock_rpc.return_value = [ExecutePipelineResponse()]
     ppl_1 = _make_pipeline(client=client)
 
-    list(ppl_1.execute(transaction=transaction))
+    list(ppl_1.stream(transaction=transaction))
     assert mock_rpc.call_count == 1
     request = mock_rpc.call_args[0][0]
     assert isinstance(request, ExecutePipelineRequest)
     assert request.structured_pipeline == ppl_1._to_pb()
     assert request.database == "projects/A/databases/B"
     assert request.transaction == b"123"
+
+
+def test_pipeline_execute_stream_equivalence():
+    """
+    Pipeline.execute should provide the same results as pipeline.stream, as a list
+    """
+    from google.cloud.firestore_v1.types import Document
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import Value
+    from google.cloud.firestore_v1.client import Client
+
+    real_client = Client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = client._firestore_api.execute_pipeline
+
+    mock_rpc.return_value = [
+        ExecutePipelineResponse(
+            results=[
+                Document(
+                    name="test/my_doc",
+                    fields={"key": Value(string_value="str_val")},
+                )
+            ],
+        )
+    ]
+    ppl_1 = _make_pipeline(client=client)
+
+    stream_results = list(ppl_1.stream())
+    execute_results = ppl_1.execute()
+    assert stream_results == execute_results
+    assert stream_results[0].data()["key"] == "str_val"
+    assert execute_results[0].data()["key"] == "str_val"
+
+
+def test_pipeline_execute_stream_equivalence_mocked():
+    """
+    pipeline.execute should call pipeline.stream internally
+    """
+    ppl_1 = _make_pipeline()
+    expected_data = [object(), object()]
+    expected_arg = object()
+    with mock.patch.object(ppl_1, "stream") as mock_stream:
+        mock_stream.return_value = expected_data
+        stream_results = ppl_1.execute(expected_arg)
+        assert mock_stream.call_count == 1
+        assert mock_stream.call_args[0] == ()
+        assert len(mock_stream.call_args[1]) == 1
+        assert mock_stream.call_args[1]["transaction"] == expected_arg
+        assert stream_results == expected_data
From a818f52a26c20b278642a94aaab358fe542cb711 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Mon, 16 Jun 2025 14:28:20 -0700
Subject: [PATCH 128/131] removed converter reference

---
 google/cloud/firestore_v1/pipeline_result.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py
index 84eeccecd..ada855fea 100644
--- a/google/cloud/firestore_v1/pipeline_result.py
+++ b/google/cloud/firestore_v1/pipeline_result.py
@@ -23,6 +23,7 @@
     from google.cloud.firestore_v1.base_document import BaseDocumentReference
     from google.protobuf.timestamp_pb2 import Timestamp
     from google.cloud.firestore_v1.types.document import Value as ValueProto
+    from google.cloud.firestore_v1.vector import Vector
 
 
 class PipelineResult:
@@ -109,16 +110,12 @@ def __eq__(self, other: object) -> bool:
             return NotImplemented
         return (self._ref == other._ref) and (self._fields_pb == other._fields_pb)
 
-    def data(self) -> Any:
+    def data(self) -> dict | "Vector" | None:
         """
         Retrieves all fields in the result.
 
-        If a converter was provided to this `PipelineResult`, the result of the
-        converter's `from_firestore` method is returned.
- Returns: - The data, either as a custom object (if a converter is used) or a dictionary. - Returns `None` if the document doesn't exist. + The data in dictionary format, or `None` if the document doesn't exist. """ if self._fields_pb is None: return None From 8a9c3eccc23b76894e7d95c285f4b11184e7fd78 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 16 Jun 2025 14:55:50 -0700 Subject: [PATCH 129/131] made stages private --- ...pipeline_stages.py => _pipeline_stages.py} | 0 google/cloud/firestore_v1/async_pipeline.py | 2 +- google/cloud/firestore_v1/base_pipeline.py | 28 +++++++++++++++---- google/cloud/firestore_v1/pipeline.py | 2 +- google/cloud/firestore_v1/pipeline_source.py | 13 ++++----- tests/unit/v1/test_async_pipeline.py | 25 +++++++++++------ tests/unit/v1/test_pipeline.py | 25 +++++++++++------ tests/unit/v1/test_pipeline_source.py | 2 +- tests/unit/v1/test_pipeline_stages.py | 2 +- 9 files changed, 66 insertions(+), 33 deletions(-) rename google/cloud/firestore_v1/{pipeline_stages.py => _pipeline_stages.py} (100%) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py similarity index 100% rename from google/cloud/firestore_v1/pipeline_stages.py rename to google/cloud/firestore_v1/_pipeline_stages.py diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 9fe0c8756..471c33093 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import AsyncIterable, TYPE_CHECKING -from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline if TYPE_CHECKING: # pragma: NO COVER diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 851ad94ad..0950e08bb 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -14,12 +14,13 @@ from __future__ import annotations from typing import Iterable, TYPE_CHECKING -from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1.types.pipeline import ( StructuredPipeline as StructuredPipeline_pb, ) from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest from google.cloud.firestore_v1.pipeline_result import PipelineResult +from google.cloud.firestore_v1.pipeline_expressions import Expr from google.cloud.firestore_v1 import _helpers if TYPE_CHECKING: # pragma: NO COVER @@ -37,7 +38,23 @@ class _BasePipeline: Use `client.collection.("...").pipeline()` to create pipeline instances. """ - def __init__(self, client: Client | AsyncClient, *stages: stages.Stage): + def __init__(self, client: Client | AsyncClient): + """ + Initializes a new pipeline. + + Pipelines should not be instantiated directly. Instead, + call client.pipeline() to create an instance + + Args: + client: The client associated with the pipeline + """ + self._client = client + self.stages = tuple() + + @classmethod + def _create_with_stages( + cls, client: Client | AsyncClient, *stages + ) -> _BasePipeline: """ Initializes a new pipeline with the given stages. 
@@ -47,8 +64,9 @@ def __init__(self, client: Client | AsyncClient, *stages: stages.Stage): client: The client associated with the pipeline *stages: Initial stages for the pipeline. """ - self._client = client - self.stages = tuple(stages) + new_instance = cls(client) + new_instance.stages = tuple(stages) + return new_instance def __repr__(self): cls_str = type(self).__name__ @@ -69,7 +87,7 @@ def _append(self, new_stage): """ Create a new Pipeline object with a new stage appended """ - return self.__class__(self._client, *self.stages, new_stage) + return self.__class__._create_with_stages(self._client, *self.stages, new_stage) def _prep_execute_request( self, transaction: BaseTransaction | None diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index f578e00b6..9f568f925 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import Iterable, TYPE_CHECKING -from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline if TYPE_CHECKING: # pragma: NO COVER diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index 507282881..f2f081fee 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import Generic, TypeVar, TYPE_CHECKING -from google.cloud.firestore_v1 import pipeline_stages as stages +from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline if TYPE_CHECKING: # pragma: NO COVER @@ -30,19 +30,16 @@ class PipelineSource(Generic[PipelineType]): A factory for creating Pipeline instances, which provide a framework for building data transformation and query pipelines for Firestore. - Start by calling client.pipeline() to obtain an instance of PipelineSource. - From there, you can use the provided methods .collection() to specify the - data source for your pipeline. - - This class is typically used to start building Firestore pipelines. It allows you to define - the initial data source for a pipeline. + Not meant to be instantiated directly. Instead, start by calling client.pipeline() + to obtain an instance of PipelineSource. From there, you can use the provided + methods to specify the data source for your pipeline. 
""" def __init__(self, client: Client | AsyncClient): self.client = client def _create_pipeline(self, source_stage): - return self.client._pipeline_cls(self.client, source_stage) + return self.client._pipeline_cls._create_with_stages(self.client, source_stage) def collection(self, path: str) -> PipelineType: """ diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index 57f125b0e..396c00b15 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -19,7 +19,7 @@ def _make_async_pipeline(*args, client=mock.Mock()): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline - return AsyncPipeline(client, *args) + return AsyncPipeline._create_with_stages(client, *args) async def _async_it(list): @@ -30,9 +30,18 @@ async def _async_it(list): def test_ctor(): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + client = object() + instance = AsyncPipeline(client) + assert instance._client == client + assert len(instance.stages) == 0 + + +def test_create(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + client = object() stages = [object() for i in range(10)] - instance = AsyncPipeline(client, *stages) + instance = AsyncPipeline._create_with_stages(client, *stages) assert instance._client == client assert len(instance.stages) == 10 assert instance.stages[0] == stages[0] @@ -54,7 +63,7 @@ def test_async_pipeline_repr_single_stage(): def test_async_pipeline_repr_multiple_stage(): - from google.cloud.firestore_v1.pipeline_stages import GenericStage, Collection + from google.cloud.firestore_v1._pipeline_stages import GenericStage, Collection stage_1 = Collection("path") stage_2 = GenericStage("second", 2) @@ -71,7 +80,7 @@ def test_async_pipeline_repr_multiple_stage(): def test_async_pipeline_repr_long(): - from google.cloud.firestore_v1.pipeline_stages import GenericStage + from google.cloud.firestore_v1._pipeline_stages import GenericStage num_stages = 100 stage_list = [GenericStage("custom", i) for i in range(num_stages)] @@ -83,7 +92,7 @@ def test_async_pipeline_repr_long(): def test_async_pipeline__to_pb(): from google.cloud.firestore_v1.types.pipeline import StructuredPipeline - from google.cloud.firestore_v1.pipeline_stages import GenericStage + from google.cloud.firestore_v1._pipeline_stages import GenericStage stage_1 = GenericStage("first") stage_2 = GenericStage("second") @@ -96,7 +105,7 @@ def test_async_pipeline__to_pb(): def test_async_pipeline_append(): """append should create a new pipeline with the additional stage""" - from google.cloud.firestore_v1.pipeline_stages import GenericStage + from google.cloud.firestore_v1._pipeline_stages import GenericStage stage_1 = GenericStage("first") ppl_1 = _make_async_pipeline(stage_1, client=object()) @@ -118,7 +127,7 @@ async def test_async_pipeline_stream_empty(): """ from google.cloud.firestore_v1.types import ExecutePipelineResponse from google.cloud.firestore_v1.types import ExecutePipelineRequest - from google.cloud.firestore_v1.pipeline_stages import GenericStage + from google.cloud.firestore_v1._pipeline_stages import GenericStage client = mock.Mock() client.project = "A" @@ -145,7 +154,7 @@ async def test_async_pipeline_stream_no_doc_ref(): from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1.types import ExecutePipelineResponse from google.cloud.firestore_v1.types import ExecutePipelineRequest - from google.cloud.firestore_v1.pipeline_stages import GenericStage + from 
google.cloud.firestore_v1._pipeline_stages import GenericStage from google.cloud.firestore_v1.pipeline_result import PipelineResult client = mock.Mock() diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index 55f0a1145..8389717fa 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -18,15 +18,24 @@ def _make_pipeline(*args, client=mock.Mock()): from google.cloud.firestore_v1.pipeline import Pipeline - return Pipeline(client, *args) + return Pipeline._create_with_stages(client, *args) def test_ctor(): from google.cloud.firestore_v1.pipeline import Pipeline + client = object() + instance = Pipeline(client) + assert instance._client == client + assert len(instance.stages) == 0 + + +def test_create(): + from google.cloud.firestore_v1.pipeline import Pipeline + client = object() stages = [object() for i in range(10)] - instance = Pipeline(client, *stages) + instance = Pipeline._create_with_stages(client, *stages) assert instance._client == client assert len(instance.stages) == 10 assert instance.stages[0] == stages[0] @@ -48,7 +57,7 @@ def test_pipeline_repr_single_stage(): def test_pipeline_repr_multiple_stage(): - from google.cloud.firestore_v1.pipeline_stages import GenericStage, Collection + from google.cloud.firestore_v1._pipeline_stages import GenericStage, Collection stage_1 = Collection("path") stage_2 = GenericStage("second", 2) @@ -65,7 +74,7 @@ def test_pipeline_repr_multiple_stage(): def test_pipeline_repr_long(): - from google.cloud.firestore_v1.pipeline_stages import GenericStage + from google.cloud.firestore_v1._pipeline_stages import GenericStage num_stages = 100 stage_list = [GenericStage("custom", i) for i in range(num_stages)] @@ -77,7 +86,7 @@ def test_pipeline_repr_long(): def test_pipeline__to_pb(): from google.cloud.firestore_v1.types.pipeline import StructuredPipeline - from google.cloud.firestore_v1.pipeline_stages import GenericStage + from google.cloud.firestore_v1._pipeline_stages import GenericStage stage_1 = GenericStage("first") stage_2 = GenericStage("second") @@ -90,7 +99,7 @@ def test_pipeline__to_pb(): def test_pipeline_append(): """append should create a new pipeline with the additional stage""" - from google.cloud.firestore_v1.pipeline_stages import GenericStage + from google.cloud.firestore_v1._pipeline_stages import GenericStage stage_1 = GenericStage("first") ppl_1 = _make_pipeline(stage_1, client=object()) @@ -111,7 +120,7 @@ def test_pipeline_stream_empty(): """ from google.cloud.firestore_v1.types import ExecutePipelineResponse from google.cloud.firestore_v1.types import ExecutePipelineRequest - from google.cloud.firestore_v1.pipeline_stages import GenericStage + from google.cloud.firestore_v1._pipeline_stages import GenericStage client = mock.Mock() client.project = "A" @@ -136,7 +145,7 @@ def test_pipeline_stream_no_doc_ref(): from google.cloud.firestore_v1.types import Document from google.cloud.firestore_v1.types import ExecutePipelineResponse from google.cloud.firestore_v1.types import ExecutePipelineRequest - from google.cloud.firestore_v1.pipeline_stages import GenericStage + from google.cloud.firestore_v1._pipeline_stages import GenericStage from google.cloud.firestore_v1.pipeline_result import PipelineResult client = mock.Mock() diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py index 6eedfed36..cd8b56b68 100644 --- a/tests/unit/v1/test_pipeline_source.py +++ b/tests/unit/v1/test_pipeline_source.py @@ -17,7 +17,7 @@ from 
google.cloud.firestore_v1.async_pipeline import AsyncPipeline
 from google.cloud.firestore_v1.client import Client
 from google.cloud.firestore_v1.async_client import AsyncClient
-from google.cloud.firestore_v1 import pipeline_stages as stages
+from google.cloud.firestore_v1 import _pipeline_stages as stages
diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py
index d0568dcf5..59d808d63 100644
--- a/tests/unit/v1/test_pipeline_stages.py
+++ b/tests/unit/v1/test_pipeline_stages.py
@@ -14,7 +14,7 @@
 
 import pytest
 
-import google.cloud.firestore_v1.pipeline_stages as stages
+import google.cloud.firestore_v1._pipeline_stages as stages
 from google.cloud.firestore_v1.pipeline_expressions import Constant
 from google.cloud.firestore_v1.types.document import Value
 from google.cloud.firestore_v1._helpers import GeoPoint
From 06a2084e159def5ffc5b9647c1be99ef8c691a9b Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Mon, 16 Jun 2025 14:57:54 -0700
Subject: [PATCH 130/131] added generic_stage method to base_pipeline

---
 google/cloud/firestore_v1/base_pipeline.py | 23 +++++++++++
 tests/unit/v1/test_async_pipeline.py       | 48 +++++++++++++---------
 tests/unit/v1/test_pipeline.py             | 48 ++++++++++++++--------
 3 files changed, 82 insertions(+), 37 deletions(-)

diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py
index 0950e08bb..929311769 100644
--- a/google/cloud/firestore_v1/base_pipeline.py
+++ b/google/cloud/firestore_v1/base_pipeline.py
@@ -126,3 +126,26 @@ def _execute_response_helper(
             doc._pb.create_time if doc.create_time else None,
             doc._pb.update_time if doc.update_time else None,
         )
+
+    def generic_stage(self, name: str, *params: Expr) -> "_BasePipeline":
+        """
+        Adds a generic, named stage to the pipeline with specified parameters.
+
+        This method provides a flexible way to extend the pipeline's functionality
+        by adding custom stages. Each generic stage is defined by a unique `name`
+        and a set of `params` that control its behavior.
+
+        Example:
+            >>> # Assume we don't have a built-in "where" stage
+            >>> pipeline = client.pipeline().collection("books")
+            >>> pipeline = pipeline.generic_stage("where", Field.of("published").lt(900))
+            >>> pipeline = pipeline.select("title", "author")
+
+        Args:
+            name: The name of the generic stage.
+            *params: A sequence of `Expr` objects representing the parameters for the stage.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.GenericStage(name, *params))
diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py
index 396c00b15..3abc3619b 100644
--- a/tests/unit/v1/test_async_pipeline.py
+++ b/tests/unit/v1/test_async_pipeline.py
@@ -15,6 +15,8 @@
 import mock
 import pytest
 
+from google.cloud.firestore_v1 import _pipeline_stages as stages
+
 
 def _make_async_pipeline(*args, client=mock.Mock()):
     from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
@@ -63,11 +65,9 @@ def test_async_pipeline_repr_single_stage():
 
 
 def test_async_pipeline_repr_multiple_stage():
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage, Collection
-
-    stage_1 = Collection("path")
-    stage_2 = GenericStage("second", 2)
-    stage_3 = GenericStage("third", 3)
+    stage_1 = stages.Collection("path")
+    stage_2 = stages.GenericStage("second", 2)
+    stage_3 = stages.GenericStage("third", 3)
     ppl = _make_async_pipeline(stage_1, stage_2, stage_3)
     repr_str = repr(ppl)
     assert repr_str == (
@@ -80,10 +80,8 @@ def test_async_pipeline_repr_long():
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage
-
     num_stages = 100
-    stage_list = [GenericStage("custom", i) for i in range(num_stages)]
+    stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)]
     ppl = _make_async_pipeline(*stage_list)
     repr_str = repr(ppl)
     assert repr_str.count("GenericStage") == num_stages
@@ -92,10 +90,9 @@ def test_async_pipeline__to_pb():
     from google.cloud.firestore_v1.types.pipeline import StructuredPipeline
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage
 
-    stage_1 = GenericStage("first")
-    stage_2 = GenericStage("second")
+    stage_1 = stages.GenericStage("first")
+    stage_2 = stages.GenericStage("second")
     ppl = _make_async_pipeline(stage_1, stage_2)
     pb = ppl._to_pb()
     assert isinstance(pb, StructuredPipeline)
@@ -105,11 +102,9 @@ def test_async_pipeline_append():
     """append should create a new pipeline with the additional stage"""
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage
-
-    stage_1 = GenericStage("first")
+    stage_1 = stages.GenericStage("first")
     ppl_1 = _make_async_pipeline(stage_1, client=object())
-    stage_2 = GenericStage("second")
+    stage_2 = stages.GenericStage("second")
     ppl_2 = ppl_1._append(stage_2)
     assert ppl_1 != ppl_2
     assert len(ppl_1.stages) == 1
@@ -127,7 +122,6 @@ async def test_async_pipeline_stream_empty():
     """
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import ExecutePipelineRequest
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage
 
     client = mock.Mock()
     client.project = "A"
     client._database = "B"
     mock_rpc = mock.AsyncMock()
     client._firestore_api.execute_pipeline = mock_rpc
     mock_rpc.return_value = _async_it([ExecutePipelineResponse()])
-    ppl_1 = _make_async_pipeline(GenericStage("s"), client=client)
+    ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client)
 
     results = [r async for r in ppl_1.stream()]
     assert results == []
@@ -154,7 +148,6 @@ async def test_async_pipeline_stream_no_doc_ref():
     from google.cloud.firestore_v1.types import Document
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import ExecutePipelineRequest
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage
     from google.cloud.firestore_v1.pipeline_result import PipelineResult
 
     client = mock.Mock()
@@ -165,7 +158,7 @@ async def test_async_pipeline_stream_no_doc_ref():
     mock_rpc.return_value = _async_it(
         [ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9})]
     )
-    ppl_1 = _make_async_pipeline(GenericStage("s"), client=client)
+    ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client)
 
     results = [r async for r in ppl_1.stream()]
     assert len(results) == 1
@@ -381,3 +374,20 @@ async def test_async_pipeline_stream_stream_equivalence_mocked():
     assert len(mock_stream.call_args[1]) == 1
     assert mock_stream.call_args[1]["transaction"] == expected_arg
     assert stream_results == expected_data
+
+
+@pytest.mark.parametrize(
+    "method,args,result_cls",
+    [
+        ("generic_stage", ("name",), stages.GenericStage),
+        ("generic_stage", ("name", mock.Mock()), stages.GenericStage),
+    ],
+)
+def test_async_pipeline_methods(method, args, result_cls):
+    start_ppl = _make_async_pipeline()
+    method_ptr = getattr(start_ppl, method)
+    result_ppl = method_ptr(*args)
+    assert result_ppl != start_ppl
+    assert len(start_ppl.stages) == 0
+    assert len(result_ppl.stages) == 1
+    assert isinstance(result_ppl.stages[0], result_cls)
diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py
index 8389717fa..6a3fef3ac 100644
--- a/tests/unit/v1/test_pipeline.py
+++ b/tests/unit/v1/test_pipeline.py
@@ -13,6 +13,9 @@
 # limitations under the License
 
 import mock
+import pytest
+
+from google.cloud.firestore_v1 import _pipeline_stages as stages
 
 
 def _make_pipeline(*args, client=mock.Mock()):
@@ -57,11 +60,9 @@ def test_pipeline_repr_single_stage():
 
 
 def test_pipeline_repr_multiple_stage():
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage, Collection
-
-    stage_1 = Collection("path")
-    stage_2 = GenericStage("second", 2)
-    stage_3 = GenericStage("third", 3)
+    stage_1 = stages.Collection("path")
+    stage_2 = stages.GenericStage("second", 2)
+    stage_3 = stages.GenericStage("third", 3)
     ppl = _make_pipeline(stage_1, stage_2, stage_3)
     repr_str = repr(ppl)
     assert repr_str == (
@@ -74,10 +75,8 @@ def test_pipeline_repr_long():
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage
-
     num_stages = 100
-    stage_list = [GenericStage("custom", i) for i in range(num_stages)]
+    stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)]
     ppl = _make_pipeline(*stage_list)
     repr_str = repr(ppl)
     assert repr_str.count("GenericStage") == num_stages
@@ -86,10 +85,9 @@ def test_pipeline__to_pb():
     from google.cloud.firestore_v1.types.pipeline import StructuredPipeline
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage
 
-    stage_1 = GenericStage("first")
-    stage_2 = GenericStage("second")
+    stage_1 = stages.GenericStage("first")
+    stage_2 = stages.GenericStage("second")
     ppl = _make_pipeline(stage_1, stage_2)
     pb = ppl._to_pb()
     assert isinstance(pb, StructuredPipeline)
@@ -99,11 +97,10 @@ def test_pipeline_append():
     """append should create a new pipeline with the additional stage"""
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage
 
-    stage_1 = GenericStage("first")
+    stage_1 = stages.GenericStage("first")
     ppl_1 = _make_pipeline(stage_1, client=object())
-    stage_2 = GenericStage("second")
+    stage_2 = stages.GenericStage("second")
     ppl_2 = ppl_1._append(stage_2)
     assert ppl_1 != ppl_2
     assert len(ppl_1.stages) == 1
@@ -120,14 +117,13 @@ def test_pipeline_stream_empty():
     """
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import ExecutePipelineRequest
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage
 
     client = mock.Mock()
     client.project = "A"
     client._database = "B"
     mock_rpc = client._firestore_api.execute_pipeline
     mock_rpc.return_value = [ExecutePipelineResponse()]
-    ppl_1 = _make_pipeline(GenericStage("s"), client=client)
+    ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client)
 
     results = list(ppl_1.stream())
     assert results == []
@@ -145,7 +141,6 @@ def test_pipeline_stream_no_doc_ref():
     from google.cloud.firestore_v1.types import Document
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import ExecutePipelineRequest
-    from google.cloud.firestore_v1._pipeline_stages import GenericStage
     from google.cloud.firestore_v1.pipeline_result import PipelineResult
 
     client = mock.Mock()
@@ -155,7 +150,7 @@ def test_pipeline_stream_no_doc_ref():
     mock_rpc.return_value = [
         ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9})
     ]
-    ppl_1 = _make_pipeline(GenericStage("s"), client=client)
+    ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client)
 
     results = list(ppl_1.stream())
     assert len(results) == 1
@@ -356,3 +351,20 @@ def test_pipeline_execute_stream_equivalence_mocked():
     assert len(mock_stream.call_args[1]) == 1
     assert mock_stream.call_args[1]["transaction"] == expected_arg
     assert stream_results == expected_data
+
+
+@pytest.mark.parametrize(
+    "method,args,result_cls",
+    [
+        ("generic_stage", ("name",), stages.GenericStage),
+        ("generic_stage", ("name", mock.Mock()), stages.GenericStage),
+    ],
+)
+def test_pipeline_methods(method, args, result_cls):
+    start_ppl = _make_pipeline()
+    method_ptr = getattr(start_ppl, method)
+    result_ppl = method_ptr(*args)
+    assert result_ppl != start_ppl
+    assert len(start_ppl.stages) == 0
+    assert len(result_ppl.stages) == 1
+    assert isinstance(result_ppl.stages[0], result_cls)

From 13389b877dd71668b8595a7b37d5707f9b982810 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Mon, 16 Jun 2025 15:15:26 -0700
Subject: [PATCH 131/131] fixed mypy

---
 google/cloud/firestore_v1/base_pipeline.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py
index 929311769..dde906fe6 100644
--- a/google/cloud/firestore_v1/base_pipeline.py
+++ b/google/cloud/firestore_v1/base_pipeline.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from __future__ import annotations
-from typing import Iterable, TYPE_CHECKING
+from typing import Iterable, Sequence, TYPE_CHECKING
 from google.cloud.firestore_v1 import _pipeline_stages as stages
 from google.cloud.firestore_v1.types.pipeline import (
     StructuredPipeline as StructuredPipeline_pb,
 )
@@ -49,7 +49,7 @@ def __init__(self, client: Client | AsyncClient):
             client: The client associated with the pipeline
         """
         self._client = client
-        self.stages = tuple()
+        self.stages: Sequence[stages.Stage] = tuple()
 
     @classmethod
     def _create_with_stages(
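
The parametrized tests added above pin down the builder contract these patches converge on: every stage method returns a new pipeline and leaves the receiver untouched, while the final patch types the stage container as an immutable Sequence. Below is a minimal, self-contained sketch of that contract; the MiniStage and MiniPipeline names are hypothetical stand-ins for illustration, not classes from google.cloud.firestore_v1:

# Sketch only: MiniStage / MiniPipeline are hypothetical stand-ins, not the
# real firestore_v1 classes.
from typing import Any, Sequence


class MiniStage:
    """Stand-in for a pipeline stage such as stages.GenericStage."""

    def __init__(self, name: str, *params: Any):
        self.name = name
        self.params = params


class MiniPipeline:
    """Stand-in for BasePipeline: stages are held in an immutable tuple."""

    def __init__(self, *initial_stages: MiniStage):
        # Annotating as Sequence[...] mirrors the mypy fix above: a bare
        # tuple() gives the type checker no element type to infer.
        self.stages: Sequence[MiniStage] = tuple(initial_stages)

    def _append(self, new_stage: MiniStage) -> "MiniPipeline":
        # Build a new pipeline instead of mutating in place; the original
        # object stays unchanged, which is what test_pipeline_append asserts.
        return MiniPipeline(*self.stages, new_stage)

    def generic_stage(self, name: str, *params: Any) -> "MiniPipeline":
        return self._append(MiniStage(name, *params))


ppl_1 = MiniPipeline()
ppl_2 = ppl_1.generic_stage("custom", 1)
assert len(ppl_1.stages) == 0  # the starting pipeline is untouched
assert len(ppl_2.stages) == 1
assert isinstance(ppl_2.stages[0], MiniStage)

Typing the attribute as Sequence rather than tuple keeps subclasses free to choose their own container while still presenting a read-only view to callers, the same trade-off the Sequence[stages.Stage] annotation makes in the mypy fix.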