diff --git a/firestore/pipeline.go b/firestore/pipeline.go index a7048d496c30..830eff99a611 100644 --- a/firestore/pipeline.go +++ b/firestore/pipeline.go @@ -166,8 +166,8 @@ func (p *Pipeline) Offset(offset int) *Pipeline { // // client.Pipeline().Collection("users").Select("info.email") // client.Pipeline().Collection("users").Select(FieldOf("info.email")) -// client.Pipeline().Collection("users").Select(FieldOfPath([]string{"info", "email"})) -// client.Pipeline().Collection("users").Select(FieldOfPath([]string{"info", "email"})) +// client.Pipeline().Collection("users").Select(FieldOf([]string{"info", "email"})) +// client.Pipeline().Collection("users").Select(FieldOf([]string{"info", "email"})) // client.Pipeline().Collection("users").Select(Add("age", 5).As("agePlus5")) func (p *Pipeline) Select(fieldpathsOrSelectables ...any) *Pipeline { if p.err != nil { @@ -341,7 +341,7 @@ func (p *Pipeline) UnnestWithAlias(fieldpath any, alias string, opts *UnnestOpti case string: fieldExpr = FieldOf(v) case FieldPath: - fieldExpr = FieldOfPath(v) + fieldExpr = FieldOf(v) default: p.err = errInvalidArg(fieldpath, "string", "FieldPath") return p diff --git a/firestore/pipeline_aggregate.go b/firestore/pipeline_aggregate.go index 0143b39c2a55..f9cc8a68ed74 100644 --- a/firestore/pipeline_aggregate.go +++ b/firestore/pipeline_aggregate.go @@ -44,7 +44,7 @@ func newBaseAggregateFunction(name string, fieldOrExpr any) *baseAggregateFuncti case string: valueExpr = FieldOf(value) case FieldPath: - valueExpr = FieldOfPath(value) + valueExpr = FieldOf(value) case Expr: valueExpr = value default: diff --git a/firestore/pipeline_field.go b/firestore/pipeline_field.go index fa1edb6f7740..578e7ede37e2 100644 --- a/firestore/pipeline_field.go +++ b/firestore/pipeline_field.go @@ -28,21 +28,23 @@ type field struct { fieldPath FieldPath } -// FieldOf creates a new field [Expr] from a field path string. -func FieldOf(path string) Expr { - fieldPath, err := parseDotSeparatedString(path) - if err != nil { - return &field{baseExpr: &baseExpr{err: err}} +// FieldOf creates a new field [Expr] from a dot separated field path string or [FieldPath]. +func FieldOf[T string | FieldPath](path T) Expr { + var fieldPath FieldPath + switch p := any(path).(type) { + case string: + fp, err := parseDotSeparatedString(p) + if err != nil { + return &field{baseExpr: &baseExpr{err: err}} + } + fieldPath = fp + case FieldPath: + fieldPath = p } - return FieldOfPath(fieldPath) -} -// FieldOfPath creates a new field [Expr] for the given [FieldPath]. -func FieldOfPath(fieldPath FieldPath) Expr { if err := fieldPath.validate(); err != nil { return &field{baseExpr: &baseExpr{err: err}} } - pbVal := &pb.Value{ ValueType: &pb.Value_FieldReferenceValue{ FieldReferenceValue: fieldPath.toServiceFieldPath(), diff --git a/firestore/pipeline_integration_test.go b/firestore/pipeline_integration_test.go index 5b6c98e6e57b..f3936349ef0c 100644 --- a/firestore/pipeline_integration_test.go +++ b/firestore/pipeline_integration_test.go @@ -868,10 +868,10 @@ func aggregateFuncs(t *testing.T) { want: map[string]interface{}{"sum_a": int64(3)}, }, { - name: "Sum - FieldOfPath Expr", + name: "Sum - FieldOf Path Expr", pipeline: client.Pipeline(). Collection(coll.ID). - Aggregate(Sum(FieldOfPath(FieldPath([]string{"a"}))).As("sum_a")), + Aggregate(Sum(FieldOf(FieldPath([]string{"a"}))).As("sum_a")), want: map[string]interface{}{"sum_a": int64(3)}, }, { diff --git a/firestore/pipeline_stage.go b/firestore/pipeline_stage.go index 0df166aae45f..e05aac5cfd56 100644 --- a/firestore/pipeline_stage.go +++ b/firestore/pipeline_stage.go @@ -213,7 +213,7 @@ func newFindNearestStage(vectorField any, queryVector any, measure PipelineDista case string: propertyExpr = FieldOf(v) case FieldPath: - propertyExpr = FieldOfPath(v) + propertyExpr = FieldOf(v) case Expr: propertyExpr = v default: @@ -300,7 +300,7 @@ func newRemoveFieldsStage(fieldpaths ...any) (*removeFieldsStage, error) { case string: fields[i] = FieldOf(v) case FieldPath: - fields[i] = FieldOfPath(v) + fields[i] = FieldOf(v) default: return nil, errInvalidArg(fp, "string", "FieldPath") } @@ -332,7 +332,7 @@ func newReplaceStage(fieldOrSelectable any) (*replaceStage, error) { case string: expr = FieldOf(v) case FieldPath: - expr = FieldOfPath(v) + expr = FieldOf(v) case Selectable: _, expr = v.getSelectionDetails() default: @@ -463,7 +463,7 @@ func newUnnestStage(fieldExpr Expr, alias string, opts *UnnestOptions) (*unnestS var indexFieldExpr Expr switch v := opts.IndexField.(type) { case FieldPath: - indexFieldExpr = FieldOfPath(v) + indexFieldExpr = FieldOf(v) case string: indexFieldExpr = FieldOf(v) default: diff --git a/firestore/pipeline_utils.go b/firestore/pipeline_utils.go index 0c2a49b61018..3cf0d0982eac 100644 --- a/firestore/pipeline_utils.go +++ b/firestore/pipeline_utils.go @@ -117,7 +117,7 @@ func asFieldExpr(val any) Expr { case Expr: return v case FieldPath: - return FieldOfPath(v) + return FieldOf(v) case string: return FieldOf(v) default: @@ -191,7 +191,7 @@ func fieldsOrSelectablesToSelectables(fieldsOrSelectables ...any) ([]Selectable, } s = FieldOf(v).(*field) case FieldPath: - s = FieldOfPath(v).(*field) + s = FieldOf(v).(*field) case Selectable: s = v default: