diff --git a/firestore/pipeline.go b/firestore/pipeline.go index 830eff99a611..3ee9d764ec82 100644 --- a/firestore/pipeline.go +++ b/firestore/pipeline.go @@ -484,3 +484,24 @@ func (p *Pipeline) FindNearest(vectorField any, queryVector any, measure Pipelin } return p.append(stage) } + +// RawStage adds a generic stage to the pipeline. +// This method provides a flexible way to extend the pipeline's functionality by adding custom stages. +// +// Example: +// +// // Assume we don't have a built-in "where" stage +// client.Pipeline().Collection("books"). +// RawStage( +// NewRawStage("where"). +// WithArguments( +// LessThan(FieldOf("published"), 1900), +// ), +// ). +// Select("title", "author") +func (p *Pipeline) RawStage(stage *RawStage) *Pipeline { + if p.err != nil { + return p + } + return p.append(stage) +} diff --git a/firestore/pipeline_integration_test.go b/firestore/pipeline_integration_test.go index 2e33f142cfa3..47f188682611 100644 --- a/firestore/pipeline_integration_test.go +++ b/firestore/pipeline_integration_test.go @@ -288,6 +288,36 @@ func TestIntegration_PipelineStages(t *testing.T) { t.Errorf("got title %q, want 'The Great Gatsby'", data["title"]) } }) + t.Run("RawStage", func(t *testing.T) { + // Using RawStage to perform a Limit operation + iter := client.Pipeline().Collection(coll.ID).RawStage(NewRawStage("limit").WithArguments(3)).Execute(ctx) + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 3 { + t.Errorf("got %d documents, want 3", len(results)) + } + + // Using RawStage to perform a Select operation with options + iter = client.Pipeline().Collection(coll.ID).RawStage(NewRawStage("select").WithArguments(map[string]interface{}{"title": FieldOf("title")})).Limit(1).Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if _, ok := data["title"]; !ok { + t.Error("missing 'title' field") + } + if _, ok := data["genre"]; ok { + t.Error("unexpected 'genre' field") + } + }) t.Run("RemoveFields", func(t *testing.T) { iter := client.Pipeline().Collection(coll.ID). Limit(1). diff --git a/firestore/pipeline_stage.go b/firestore/pipeline_stage.go index e05aac5cfd56..f6d41256297d 100644 --- a/firestore/pipeline_stage.go +++ b/firestore/pipeline_stage.go @@ -16,6 +16,7 @@ package firestore import ( "fmt" + "reflect" "strings" pb "cloud.google.com/go/firestore/apiv1/firestorepb" @@ -524,3 +525,61 @@ func newWhereStage(condition BooleanExpr) (*whereStage, error) { stagePb: newUnaryStage(stageNameWhere, argsPb), }}, nil } + +// RawStageOptions holds the options for a RawStage. +type RawStageOptions map[string]any + +// RawStage is a generic stage in the pipeline. +// It provides a flexible way to extend the pipeline's functionality by adding custom +// stages. It also allows the users to call the stages that are supported by the Firestore backend +// but not yet available in the current SDK version. +type RawStage struct { + stageName string + args []any + options RawStageOptions +} + +// NewRawStage creates a new RawStage with the given name. +func NewRawStage(name string) *RawStage { + return &RawStage{stageName: name} +} + +// WithArguments sets the arguments for the RawStage. +func (s *RawStage) WithArguments(args ...any) *RawStage { + s.args = args + return s +} + +// WithOptions sets the options for the RawStage. +func (s *RawStage) WithOptions(options RawStageOptions) *RawStage { + s.options = options + return s +} + +func (s *RawStage) name() string { return s.stageName } + +func (s *RawStage) toProto() (*pb.Pipeline_Stage, error) { + argsPb := make([]*pb.Value, len(s.args)) + for i, arg := range s.args { + val, _, err := toProtoValue(reflect.ValueOf(arg)) + if err != nil { + return nil, fmt.Errorf("firestore: error converting raw stage argument %d: %w", i, err) + } + argsPb[i] = val + } + + optionsPb := make(map[string]*pb.Value, len(s.options)) + for key, val := range s.options { + valPb, _, err := toProtoValue(reflect.ValueOf(val)) + if err != nil { + return nil, fmt.Errorf("firestore: error converting raw stage option %q: %w", key, err) + } + optionsPb[key] = valPb + } + + return &pb.Pipeline_Stage{ + Name: s.name(), + Args: argsPb, + Options: optionsPb, + }, nil +}