diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 63fee19fa..c66321793 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -315,6 +315,47 @@ def find_nearest( stages.FindNearest(field, vector, distance_measure, options) ) + def replace_with( + self, + field: Selectable, + ) -> "_BasePipeline": + """ + Fully overwrites all fields in a document with those coming from a nested map. + + This stage allows you to emit a map value as a document. Each key of the map becomes a field + on the document that contains the corresponding value. + + Example: + Input document: + ```json + { + "name": "John Doe Jr.", + "parents": { + "father": "John Doe Sr.", + "mother": "Jane Doe" + } + } + ``` + + >>> # Emit the 'parents' map as the document + >>> pipeline = client.pipeline().collection("people").replace_with(Field.of("parents")) + + Output document: + ```json + { + "father": "John Doe Sr.", + "mother": "Jane Doe" + } + ``` + + Args: + field: The `Selectable` field containing the map whose content will + replace the document. + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.ReplaceWith(field)) + def sort(self, *orders: stages.Ordering) -> "_BasePipeline": """ Sorts the documents from previous stages based on one or more `Ordering` criteria. diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 95ce32021..37829465e 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -362,6 +362,17 @@ def _pb_args(self) -> list[Value]: return [f._to_pb() for f in self.fields] +class ReplaceWith(Stage): + """Replaces the document content with the value of a specified field.""" + + def __init__(self, field: Selectable): + super().__init__("replace_with") + self.field = Field(field) if isinstance(field, str) else field + + def _pb_args(self): + return [self.field._to_pb(), Value(string_value="full_replace")] + + class Sample(Stage): """Performs pseudo-random sampling of documents.""" diff --git a/tests/system/pipeline_e2e/general.yaml b/tests/system/pipeline_e2e/general.yaml index 23e98cf3d..8ff3f60d2 100644 --- a/tests/system/pipeline_e2e/general.yaml +++ b/tests/system/pipeline_e2e/general.yaml @@ -655,4 +655,33 @@ tests: fieldReferenceValue: tags_alias index: fieldReferenceValue: index - name: select \ No newline at end of file + name: select + - description: replaceWith + pipeline: + - Collection: books + - Where: + - Function.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - ReplaceWith: + - Field: awards + assert_results: + - hugo: True + nebula: False + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - fieldReferenceValue: awards + - stringValue: full_replace + name: replace_with \ No newline at end of file diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index b6d353f1a..e203f6d69 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -369,6 +369,8 @@ def test_pipeline_execute_stream_equivalence_mocked(): ("name", [0.1], "cosine", stages.FindNearestOptions(10)), stages.FindNearest, ), + ("replace_with", ("name",), stages.ReplaceWith), + ("replace_with", (Field.of("n"),), stages.ReplaceWith), ("sort", (Field.of("n").descending(),), stages.Sort), ("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort), ("sample", (10,), stages.Sample), diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index d3b7dfbf2..ec7f4901e 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -555,6 +555,14 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): BooleanExpression._from_query_filter_pb(document_pb.Value(), mock_client) +class TestFunction: + def test_equals(self): + assert expr.Function.sqrt("1") == expr.Function.sqrt("1") + assert expr.Function.sqrt("1") != expr.Function.sqrt("2") + assert expr.Function.sqrt("1") != expr.Function.sum("1") + assert expr.Function.sqrt("1") != object() + + class TestArray: """Tests for the array class""" @@ -618,15 +626,7 @@ def test_w_exprs(self): ) -class TestFunction: - def test_equals(self): - assert expr.Function.sqrt("1") == expr.Function.sqrt("1") - assert expr.Function.sqrt("1") != expr.Function.sqrt("2") - assert expr.Function.sqrt("1") != expr.Function.sum("1") - assert expr.Function.sqrt("1") != object() - - -class TestExpressionMethods: +class TestExpressionessionMethods: """ contains test methods for each Expression method """ diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index 18c9d6790..a2d466f47 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -562,6 +562,39 @@ def test_to_pb(self): assert len(result.options) == 0 +class TestReplaceWith: + def _make_one(self, *args, **kwargs): + return stages.ReplaceWith(*args, **kwargs) + + @pytest.mark.parametrize( + "in_field,expected_field", + [ + ("test", Field.of("test")), + ("test", Field.of("test")), + ("test", Field.of("test")), + (Field.of("test"), Field.of("test")), + (Field.of("test"), Field.of("test")), + ], + ) + def test_ctor(self, in_field, expected_field): + instance = self._make_one(in_field) + assert instance.field == expected_field + assert instance.name == "replace_with" + + def test_repr(self): + instance = self._make_one("test") + repr_str = repr(instance) + assert repr_str == "ReplaceWith(field=Field.of('test'))" + + def test_to_pb(self): + instance = self._make_one(Field.of("test")) + result = instance._to_pb() + assert result.name == "replace_with" + assert len(result.args) == 2 + assert result.args[0].field_reference_value == "test" + assert result.args[1].string_value == "full_replace" + + class TestSample: class TestSampleOptions: def test_ctor_percent(self):