From abd71c1beecab539db9595c5c6a5aab25d60d6c0 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 8 Aug 2024 13:57:09 -0700 Subject: [PATCH] [Go] Added support for flow auth and Firebase auth plugin. (#722) --- docs-go/_guides.yaml | 2 + docs-go/auth.md | 57 ++++++++ docs-go/flows.md | 2 +- go/genkit/flow.go | 148 ++++++++++++++++++--- go/genkit/servers.go | 78 +++++------ go/genkit/servers_test.go | 13 +- go/go.mod | 6 + go/go.sum | 26 ++++ go/internal/doc-snippets/flows.go | 53 ++++++++ go/plugins/firebase/auth.go | 114 ++++++++++++++++ go/plugins/firebase/auth_test.go | 211 ++++++++++++++++++++++++++++++ go/plugins/firebase/firebase.go | 46 +++++++ go/samples/firebase-auth/main.go | 72 ++++++++++ 13 files changed, 763 insertions(+), 65 deletions(-) create mode 100644 docs-go/auth.md create mode 100644 go/plugins/firebase/auth.go create mode 100644 go/plugins/firebase/auth_test.go create mode 100644 go/plugins/firebase/firebase.go create mode 100644 go/samples/firebase-auth/main.go diff --git a/docs-go/_guides.yaml b/docs-go/_guides.yaml index f76832e17..3b60e2c1a 100644 --- a/docs-go/_guides.yaml +++ b/docs-go/_guides.yaml @@ -21,6 +21,8 @@ toc: path: /docs/genkit-go/models - title: Creating flows path: /docs/genkit-go/flows + - title: Adding authentication to flows + path: /docs/genkit-go/auth - title: Prompting models path: /docs/genkit-go/prompts - title: Managing prompts diff --git a/docs-go/auth.md b/docs-go/auth.md new file mode 100644 index 000000000..8c7d92799 --- /dev/null +++ b/docs-go/auth.md @@ -0,0 +1,57 @@ +# Flow Authentication + +Genkit supports flow-level authentication, allowing you to secure your flows and ensure that only authorized users can execute them. This is particularly useful when deploying flows as HTTP endpoints. + +## Configuring Flow Authentication + +To add authentication to a flow, you can use the `WithFlowAuth` option when defining the flow. This option takes an implementation of the `FlowAuth` interface, which provides methods for handling authentication and authorization. + +Here's an example of how to define a flow with authentication: + +```golang +{% includecode github_path="firebase/genkit/go/internal/doc-snippets/flows.go" region_tag="auth" adjust_indentation="auto" %} +``` + +In this example, we're using the Firebase auth plugin to handle authentication. The `policy` function defines the authorization logic, checking if the user ID in the auth context matches the input user ID. + +## Using the Firebase Auth Plugin + +The Firebase auth plugin provides an easy way to integrate Firebase Authentication with your Genkit flows. Here's how to use it: + +1. Import the Firebase plugin: + + ```golang + import "github.com/firebase/genkit/go/plugins/firebase" + ``` + +2. Create a Firebase auth provider: + + ```golang + {% includecode github_path="firebase/genkit/go/internal/doc-snippets/flows.go" region_tag="auth-create" adjust_indentation="auto" %} + ``` + + The `NewAuth` function takes three arguments: + + - `ctx`: The context for Firebase initialization. + - `policy`: A function that defines your authorization logic. + - `required`: A boolean indicating whether authentication is required for direct calls. + +3. Use the auth provider when defining your flow: + + ```golang + {% includecode github_path="firebase/genkit/go/internal/doc-snippets/flows.go" region_tag="auth-define" adjust_indentation="auto" %} + ``` + +## Handling Authentication in HTTP Requests + +When your flow is deployed as an HTTP endpoint, the Firebase auth plugin will automatically handle authentication for incoming requests. It expects a Bearer token in the Authorization header of the HTTP request. + +## Running Authenticated Flows Locally + +When running authenticated flows locally or from within other flows, you can provide local authentication context using the `WithLocalAuth` option: + +```golang +{% includecode github_path="firebase/genkit/go/internal/doc-snippets/flows.go" region_tag="auth-run" adjust_indentation="auto" %} +``` + +This allows you to test authenticated flows without needing to provide a valid Firebase token. diff --git a/docs-go/flows.md b/docs-go/flows.md index ff447857c..864e8b2e1 100644 --- a/docs-go/flows.md +++ b/docs-go/flows.md @@ -89,7 +89,7 @@ then call `Init()`: ``` `Init` starts a `net/http` server that exposes your flows as HTTP -endpoints (for example, `http://localhost:3400/menuSuggestionFlow`). +endpoints (for example, `http://localhost:3400/menuSuggestionFlow`). The second parameter is an optional `Options` that specifies the following: diff --git a/go/genkit/flow.go b/go/genkit/flow.go index 635916a90..7f62ef024 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/http" "strconv" "sync" @@ -98,28 +99,85 @@ type Flow[In, Out, Stream any] struct { tstate *tracing.State // set from the action when the flow is defined inputSchema *jsonschema.Schema // Schema of the input to the flow outputSchema *jsonschema.Schema // Schema of the output out of the flow + auth FlowAuth // Auth provider and policy checker for the flow. // TODO: scheduler // TODO: experimentalDurable - // TODO: authPolicy // TODO: middleware } +// runOptions configures a single flow run. +type runOptions struct { + authContext AuthContext // Auth context to pass to auth policy checker when calling a flow directly. +} + +// flowOptions configures a flow. +type flowOptions struct { + auth FlowAuth // Auth provider and policy checker for the flow. +} + type noStream = func(context.Context, struct{}) error +// AuthContext is the type of the auth context passed to the auth policy checker. +type AuthContext map[string]any + +// FlowAuth configures an auth context provider and an auth policy check for a flow. +type FlowAuth interface { + // ProvideAuthContext sets the auth context on the given context by parsing an auth header. + // The parsing logic is provided by the auth provider. + ProvideAuthContext(ctx context.Context, authHeader string) (context.Context, error) + + // NewContext sets the auth context on the given context. This is used when + // the auth context is provided by the user, rather than by the auth provider. + NewContext(ctx context.Context, authContext AuthContext) context.Context + + // FromContext retrieves the auth context from the given context. + FromContext(ctx context.Context) AuthContext + + // CheckAuthPolicy checks the auth context against policy. + CheckAuthPolicy(ctx context.Context, input any) error +} + // streamingCallback is the type of streaming callbacks. type streamingCallback[Stream any] func(context.Context, Stream) error +// FlowOption modifies the flow with the provided option. +type FlowOption func(opts *flowOptions) + +// FlowRunOption modifies a flow run with the provided option. +type FlowRunOption func(opts *runOptions) + +// WithFlowAuth sets an auth provider and policy checker for the flow. +func WithFlowAuth(auth FlowAuth) FlowOption { + return func(f *flowOptions) { + if f.auth != nil { + log.Panic("auth already set in flow") + } + f.auth = auth + } +} + +// WithLocalAuth configures an option to run or stream a flow with a local auth value. +func WithLocalAuth(authContext AuthContext) FlowRunOption { + return func(opts *runOptions) { + if opts.authContext != nil { + log.Panic("authContext already set in runOptions") + } + opts.authContext = authContext + } +} + // DefineFlow creates a Flow that runs fn, and registers it as an action. // // fn takes an input of type In and returns an output of type Out. func DefineFlow[In, Out any]( name string, fn func(ctx context.Context, input In) (Out, error), + opts ...FlowOption, ) *Flow[In, Out, struct{}] { return defineFlow(registry.Global, name, core.Func[In, Out, struct{}]( func(ctx context.Context, input In, cb func(ctx context.Context, _ struct{}) error) (Out, error) { return fn(ctx, input) - })) + }), opts...) } // DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action. @@ -134,11 +192,12 @@ func DefineFlow[In, Out any]( func DefineStreamingFlow[In, Out, Stream any]( name string, fn func(ctx context.Context, input In, callback func(context.Context, Stream) error) (Out, error), + opts ...FlowOption, ) *Flow[In, Out, Stream] { - return defineFlow(registry.Global, name, core.Func[In, Out, Stream](fn)) + return defineFlow(registry.Global, name, core.Func[In, Out, Stream](fn), opts...) } -func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.Func[In, Out, Stream]) *Flow[In, Out, Stream] { +func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.Func[In, Out, Stream], opts ...FlowOption) *Flow[In, Out, Stream] { var i In var o Out f := &Flow[In, Out, Stream]{ @@ -148,12 +207,27 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core. outputSchema: base.InferJSONSchema(o), // TODO: set stateStore? } + flowOpts := &flowOptions{} + for _, opt := range opts { + opt(flowOpts) + } + f.auth = flowOpts.auth metadata := map[string]any{ "inputSchema": f.inputSchema, "outputSchema": f.outputSchema, + "requiresAuth": f.auth != nil, } afunc := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) { tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true") + // Only non-durable flows have an auth policy so can safely assume Start.Input. + if inst.Start != nil { + if f.auth != nil { + ctx = f.auth.NewContext(ctx, inst.Auth) + } + if err := f.checkAuthPolicy(ctx, any(inst.Start.Input)); err != nil { + return nil, err + } + } return f.runInstruction(ctx, inst, streamingCallback[Stream](cb)) } core.DefineActionInRegistry(r, "", f.name, atype.Flow, metadata, nil, afunc) @@ -167,18 +241,19 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core. // A flowInstruction is an instruction to follow with a flow. // It is the input for the flow's action. // Exactly one field will be non-nil. -type flowInstruction[I any] struct { - Start *startInstruction[I] `json:"start,omitempty"` +type flowInstruction[In any] struct { + Start *startInstruction[In] `json:"start,omitempty"` Resume *resumeInstruction `json:"resume,omitempty"` - Schedule *scheduleInstruction[I] `json:"schedule,omitempty"` + Schedule *scheduleInstruction[In] `json:"schedule,omitempty"` RunScheduled *runScheduledInstruction `json:"runScheduled,omitempty"` State *stateInstruction `json:"state,omitempty"` Retry *retryInstruction `json:"retry,omitempty"` + Auth map[string]any `json:"auth,omitempty"` } // A startInstruction starts a flow. -type startInstruction[I any] struct { - Input I `json:"input,omitempty"` +type startInstruction[In any] struct { + Input In `json:"input,omitempty"` Labels map[string]string `json:"labels,omitempty"` } @@ -189,9 +264,9 @@ type resumeInstruction struct { } // A scheduleInstruction schedules a flow to start at a later time. -type scheduleInstruction[I any] struct { +type scheduleInstruction[In any] struct { DelaySecs float64 `json:"delay,omitempty"` - Input I `json:"input,omitempty"` + Input In `json:"input,omitempty"` } // A runScheduledInstruction starts a scheduled flow. @@ -324,7 +399,7 @@ func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowIn // Name returns the name that the flow was defined with. func (f *Flow[In, Out, Stream]) Name() string { return f.name } -func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) { +func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, authHeader string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) { // Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process. if err := base.ValidateJSON(input, f.inputSchema); err != nil { return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err} @@ -333,6 +408,13 @@ func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessa if err := json.Unmarshal(input, &in); err != nil { return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err} } + newCtx, err := f.provideAuthContext(ctx, authHeader) + if err != nil { + return nil, &base.HTTPError{Code: http.StatusUnauthorized, Err: err} + } + if err := f.checkAuthPolicy(newCtx, in); err != nil { + return nil, &base.HTTPError{Code: http.StatusForbidden, Err: err} + } // If there is a callback, wrap it to turn an S into a json.RawMessage. var callback streamingCallback[Stream] if cb != nil { @@ -361,6 +443,28 @@ func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessa return json.Marshal(res.Response) } +// provideAuthContext provides auth context for the given auth header if flow auth is configured. +func (f *Flow[In, Out, Stream]) provideAuthContext(ctx context.Context, authHeader string) (context.Context, error) { + if f.auth != nil { + newCtx, err := f.auth.ProvideAuthContext(ctx, authHeader) + if err != nil { + return nil, fmt.Errorf("unauthorized: %w", err) + } + return newCtx, nil + } + return ctx, nil +} + +// checkAuthPolicy checks auth context against the policy if flow auth is configured. +func (f *Flow[In, Out, Stream]) checkAuthPolicy(ctx context.Context, input any) error { + if f.auth != nil { + if err := f.auth.CheckAuthPolicy(ctx, input); err != nil { + return fmt.Errorf("permission denied for resource: %w", err) + } + } + return nil +} + // start starts executing the flow with the given input. func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb streamingCallback[Stream]) (_ *flowState[In, Out], err error) { flowID, err := generateFlowID() @@ -569,11 +673,21 @@ func Run[Out any](ctx context.Context, name string, f func() (Out, error)) (Out, // Run runs the flow in the context of another flow. The flow must run to completion when started // (that is, it must not have interrupts). -func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In) (Out, error) { - return f.run(ctx, input, nil) +func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In, opts ...FlowRunOption) (Out, error) { + return f.run(ctx, input, nil, opts...) } -func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) { +func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(context.Context, Stream) error, opts ...FlowRunOption) (Out, error) { + runOpts := &runOptions{} + for _, opt := range opts { + opt(runOpts) + } + if runOpts.authContext != nil && f.auth != nil { + ctx = f.auth.NewContext(ctx, runOpts.authContext) + } + if err := f.checkAuthPolicy(ctx, input); err != nil { + return base.Zero[Out](), err + } state, err := f.start(ctx, input, cb) if err != nil { return base.Zero[Out](), err @@ -602,7 +716,7 @@ type StreamFlowValue[Out, Stream any] struct { // again. // // Otherwise the Stream field of the passed [StreamFlowValue] holds a streamed result. -func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(*StreamFlowValue[Out, Stream], error) bool) { +func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...FlowRunOption) func(func(*StreamFlowValue[Out, Stream], error) bool) { return func(yield func(*StreamFlowValue[Out, Stream], error) bool) { cb := func(ctx context.Context, s Stream) error { if ctx.Err() != nil { @@ -613,7 +727,7 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func( } return nil } - output, err := f.run(ctx, input, cb) + output, err := f.run(ctx, input, cb, opts...) if err != nil { yield(nil, err) } else { diff --git a/go/genkit/servers.go b/go/genkit/servers.go index 26f7fb022..161e3d191 100644 --- a/go/genkit/servers.go +++ b/go/genkit/servers.go @@ -27,16 +27,13 @@ import ( "encoding/json" "errors" "fmt" - "io" "io/fs" "log/slog" "net/http" "os" - "os/signal" "strconv" "sync" "sync/atomic" - "syscall" "time" "github.com/firebase/genkit/go/core/logger" @@ -76,7 +73,7 @@ type flow interface { // runJSON uses encoding/json to unmarshal the input, // calls Flow.start, then returns the marshaled result. - runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) + runJSON(ctx context.Context, authHeader string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) } // startServer starts an HTTP server listening on the address. @@ -163,19 +160,17 @@ func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) erro if err := json.NewDecoder(r.Body).Decode(&body); err != nil { return &base.HTTPError{Code: http.StatusBadRequest, Err: err} } - stream := false - if s := r.FormValue("stream"); s != "" { - var err error - stream, err = strconv.ParseBool(s) - if err != nil { - return err - } + stream, err := parseBoolQueryParam(r, "stream") + if err != nil { + return err } logger.FromContext(ctx).Debug("running action", "key", body.Key, "stream", stream) var callback streamingCallback[json.RawMessage] if stream { + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Transfer-Encoding", "chunked") // Stream results are newline-separated JSON. callback = func(ctx context.Context, msg json.RawMessage) error { _, err := fmt.Fprintf(w, "%s\n", msg) @@ -328,29 +323,42 @@ func newFlowServeMux(r *registry.Registry, flows []string) *http.ServeMux { func nonDurableFlowHandler(f flow) func(http.ResponseWriter, *http.Request) error { return func(w http.ResponseWriter, r *http.Request) error { + var body struct { + Data json.RawMessage `json:"data"` + } defer r.Body.Close() - input, err := io.ReadAll(r.Body) - if err != nil { - return err + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + return &base.HTTPError{Code: http.StatusBadRequest, Err: err} } stream, err := parseBoolQueryParam(r, "stream") if err != nil { return err } + var callback streamingCallback[json.RawMessage] if stream { - // TODO: implement streaming. - return &base.HTTPError{Code: http.StatusNotImplemented, Err: errors.New("streaming")} - } else { - // TODO: telemetry - out, err := f.runJSON(r.Context(), json.RawMessage(input), nil) - if err != nil { - return err + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Transfer-Encoding", "chunked") + // Stream results are newline-separated JSON. + callback = func(ctx context.Context, msg json.RawMessage) error { + _, err := fmt.Fprintf(w, "%s\n", msg) + if err != nil { + return err + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return nil } - // Responses for non-streaming, non-durable flows are passed back - // with the flow result stored in a field called "result." - _, err = fmt.Fprintf(w, `{"result": %s}\n`, out) + } + // TODO: telemetry + out, err := f.runJSON(r.Context(), r.Header.Get("Authorization"), body.Data, callback) + if err != nil { return err } + // Responses for non-streaming, non-durable flows are passed back + // with the flow result stored in a field called "result." + _, err = fmt.Fprintf(w, `{"result": %s}\n`, out) + return err } } @@ -365,28 +373,6 @@ func serverAddress(arg, envVar, defaultValue string) string { return defaultValue } -func listenAndServe(addr string, mux *http.ServeMux) error { - server := &http.Server{ - Addr: addr, - Handler: mux, - } - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGTERM) - go func() { - <-sigCh - slog.Info("received SIGTERM, shutting down server") - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := server.Shutdown(ctx); err != nil { - slog.Error("server shutdown failed", "err", err) - } else { - slog.Info("server shutdown successfully") - } - }() - slog.Info("listening", "addr", addr) - return server.ListenAndServe() -} - // requestID is a unique ID for each request. var requestID atomic.Int64 diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index e95369252..61a3fb1db 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -15,6 +15,7 @@ package genkit import ( + "bytes" "context" "encoding/json" "io" @@ -140,7 +141,17 @@ func TestProdServer(t *testing.T) { defer srv.Close() check := func(t *testing.T, input string, wantStatus, wantResult int) { - res, err := http.Post(srv.URL+"/inc", "application/json", strings.NewReader(input)) + type body struct { + Data json.RawMessage `json:"data"` + } + payload := body{ + Data: json.RawMessage([]byte(input)), + } + jsonPayload, err := json.Marshal(payload) + if err != nil { + t.Fatal(err) + } + res, err := http.Post(srv.URL+"/inc", "application/json", bytes.NewBuffer(jsonPayload)) if err != nil { t.Fatal(err) } diff --git a/go/go.mod b/go/go.mod index 1d2601c41..1f8186eae 100644 --- a/go/go.mod +++ b/go/go.mod @@ -6,6 +6,7 @@ require ( cloud.google.com/go/aiplatform v1.68.0 cloud.google.com/go/logging v1.10.0 cloud.google.com/go/vertexai v0.12.1-0.20240711230438-265963bd5b91 + firebase.google.com/go/v4 v4.14.1 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.22.0 github.com/aymerick/raymond v2.0.2+incompatible @@ -39,11 +40,14 @@ require ( cloud.google.com/go/auth v0.7.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.4.0 // indirect + cloud.google.com/go/firestore v1.15.0 // indirect cloud.google.com/go/iam v1.1.10 // indirect cloud.google.com/go/longrunning v0.5.9 // indirect cloud.google.com/go/monitoring v1.20.1 // indirect + cloud.google.com/go/storage v1.41.0 // indirect cloud.google.com/go/trace v1.10.9 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.46.0 // indirect + github.com/MicahParks/keyfunc v1.9.0 // indirect github.com/PuerkitoBio/purell v1.1.1 // indirect github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect @@ -61,6 +65,7 @@ require ( github.com/go-openapi/strfmt v0.23.0 // indirect github.com/go-openapi/swag v0.22.3 // indirect github.com/go-openapi/validate v0.21.0 // indirect + github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/s2a-go v0.1.7 // indirect @@ -84,6 +89,7 @@ require ( golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect + google.golang.org/appengine/v2 v2.0.2 // indirect google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240708141625-4ad9e859172b // indirect diff --git a/go/go.sum b/go/go.sum index 4fce3712f..2be2270c8 100644 --- a/go/go.sum +++ b/go/go.sum @@ -11,6 +11,8 @@ cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKF cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= cloud.google.com/go/compute/metadata v0.4.0 h1:vHzJCWaM4g8XIcm8kopr3XmDA4Gy/lblD3EhhSux05c= cloud.google.com/go/compute/metadata v0.4.0/go.mod h1:SIQh1Kkb4ZJ8zJ874fqVkslA29PRXuleyj6vOzlbK7M= +cloud.google.com/go/firestore v1.15.0 h1:/k8ppuWOtNuDHt2tsRV42yI21uaGnKDEQnRFeBpbFF8= +cloud.google.com/go/firestore v1.15.0/go.mod h1:GWOxFXcv8GZUtYpWHw/w6IuYNux/BtmeVTMmjrm4yhk= cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI= cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps= cloud.google.com/go/logging v1.10.0 h1:f+ZXMqyrSJ5vZ5pE/zr0xC8y/M9BLNzQeLBwfeZ+wY4= @@ -19,12 +21,16 @@ cloud.google.com/go/longrunning v0.5.9 h1:haH9pAuXdPAMqHvzX0zlWQigXT7B0+CL4/2nXX cloud.google.com/go/longrunning v0.5.9/go.mod h1:HD+0l9/OOW0za6UWdKJtXoFAX/BGg/3Wj8p10NeWF7c= cloud.google.com/go/monitoring v1.20.1 h1:XmM6uk4+mI2ZhWdI2n/2GNhJdpeQN+1VdG2UWEDhX48= cloud.google.com/go/monitoring v1.20.1/go.mod h1:FYSe/brgfuaXiEzOQFhTjsEsJv+WePyK71X7Y8qo6uQ= +cloud.google.com/go/storage v1.41.0 h1:RusiwatSu6lHeEXe3kglxakAmAbfV+rhtPqA6i8RBx0= +cloud.google.com/go/storage v1.41.0/go.mod h1:J1WCa/Z2FcgdEDuPUY8DxT5I+d9mFKsCepp5vR6Sq80= cloud.google.com/go/trace v1.10.9 h1:Cy6D1Zdz8up4mIPUWModTuIGDr3fh7AZaCnR+uyxpgA= cloud.google.com/go/trace v1.10.9/go.mod h1:vtWRnvEh+d8h2xljwxVwsdxxpoWZkxcNYnJF3FuJUV8= cloud.google.com/go/vertexai v0.12.1-0.20240711230438-265963bd5b91 h1:JwSkFKQ/yI97gCjMMnaEOZAigRpN53yiH6gJzik/OYA= cloud.google.com/go/vertexai v0.12.1-0.20240711230438-265963bd5b91/go.mod h1:KrfEQtFq2gqyHt4kZ+k1kIo5oy9Jw90yEHxgPsyl1bw= entgo.io/ent v0.13.1 h1:uD8QwN1h6SNphdCCzmkMN3feSUzNnVvV/WIkHKMbzOE= entgo.io/ent v0.13.1/go.mod h1:qCEmo+biw3ccBn9OyL4ZK5dfpwg++l1Gxwac5B1206A= +firebase.google.com/go/v4 v4.14.1 h1:4qiUETaFRWoFGE1XP5VbcEdtPX93Qs+8B/7KvP2825g= +firebase.google.com/go/v4 v4.14.1/go.mod h1:fgk2XshgNDEKaioKco+AouiegSI9oTWVqRaBdTTGBoM= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0 h1:n3T26hyfDl9RdgcOjWvOFMh1lCBNuZ0JQ/3DM5pou8Y= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.46.0/go.mod h1:3S7qK2nHOO2cLID3xk6H8f55D38XswhVFzKEk0nqIbY= @@ -34,6 +40,8 @@ github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.46.0/go.mod h1:V28hx+cUCZC9e3qcqszMb+Sbt8cQZtHTiXOmyDzoDOg= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.46.0 h1:xlfPHZ5QFvHad9KmrVDoaPpJUT/XluwNDMNHn+k7z/s= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.46.0/go.mod h1:mzI44HpPp75Z8/a1sJP1asdHdu7Wui7t10SZ9EEPPnM= +github.com/MicahParks/keyfunc v1.9.0 h1:lhKd5xrFHLNOWrDc4Tyb/Q1AJ4LCzQ48GVJyVIID3+o= +github.com/MicahParks/keyfunc v1.9.0/go.mod h1:IdnCilugA0O/99dW+/MkvlyrsX8+L8+x95xuVNtM5jw= github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tNFfI= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M= @@ -97,6 +105,9 @@ github.com/go-pg/pg/v10 v10.11.0 h1:CMKJqLgTrfpE/aOVeLdybezR2om071Vh38OLZjsyMI0= github.com/go-pg/pg/v10 v10.11.0/go.mod h1:4BpHRoxE61y4Onpof3x1a2SQvi9c+q1dJnrNdMjsroA= github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU= github.com/go-pg/zerochecker v0.2.0/go.mod h1:NJZ4wKL0NmTtz0GKCoJ8kym6Xn/EQzXRl2OnAe7MmDo= +github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= github.com/gobuffalo/depgen v0.0.0-20190329151759-d478694a28d3/go.mod h1:3STtPUQYuzV0gBVOY3vy6CfMm/ljR4pABfrTeHNLHUY= @@ -128,6 +139,7 @@ github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= @@ -150,6 +162,8 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= +github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -316,6 +330,7 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= @@ -339,6 +354,13 @@ golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190531175056-4c3a928424d2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= @@ -365,10 +387,14 @@ golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgw golang.org/x/tools v0.23.0 h1:SGsXPZ+2l4JsgaCKkx+FQ9YZ5XEtA1GZYuoDjenLjvg= golang.org/x/tools v0.23.0/go.mod h1:pnu6ufv6vQkll6szChhK3C3L/ruaIv5eBeztNG8wtsI= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= +golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= google.golang.org/api v0.188.0 h1:51y8fJ/b1AaaBRJr4yWm96fPcuxSo0JcegXE3DaHQHw= google.golang.org/api v0.188.0/go.mod h1:VR0d+2SIiWOYG3r/jdm7adPW9hI2aRv9ETOSCQ9Beag= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine/v2 v2.0.2 h1:MSqyWy2shDLwG7chbwBJ5uMyw6SNqJzhJHNDwYB0Akk= +google.golang.org/appengine/v2 v2.0.2/go.mod h1:PkgRUWz4o1XOvbqtWTkBtCitEJ5Tp4HoVEdMMYQR/8E= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= diff --git a/go/internal/doc-snippets/flows.go b/go/internal/doc-snippets/flows.go index fb5094520..b91f02e36 100644 --- a/go/internal/doc-snippets/flows.go +++ b/go/internal/doc-snippets/flows.go @@ -16,12 +16,14 @@ package snippets import ( "context" + "errors" "fmt" "log" "net/http" "strings" "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/firebase" ) func f1() { @@ -172,3 +174,54 @@ func deploy(ctx context.Context) { } // [END init] } + +func f5() { + // [START auth] + ctx := context.Background() + // Define an auth policy and create a Firebase auth provider + firebaseAuth, err := firebase.NewAuth(ctx, func(authContext genkit.AuthContext, input any) error { + // The type must match the input type of the flow. + userID := input.(string) + if authContext == nil || authContext["UID"] != userID { + return errors.New("user ID does not match") + } + return nil + }, true) + if err != nil { + log.Fatalf("failed to set up Firebase auth: %v", err) + } + // Define a flow with authentication + authenticatedFlow := genkit.DefineFlow( + "authenticated-flow", + func(ctx context.Context, userID string) (string, error) { + return fmt.Sprintf("Secure data for user %s", userID), nil + }, + genkit.WithFlowAuth(firebaseAuth), + ) + // [END auth] + _ = authenticatedFlow +} + +func f6() { + ctx := context.Background() + var policy func(authContext genkit.AuthContext, input any) error + required := true + // [START auth-create] + firebaseAuth, err := firebase.NewAuth(ctx, policy, required) + // [END auth-create] + _ = firebaseAuth + _ = err + userDataFunc := func(ctx context.Context, userID string) (string, error) { + return fmt.Sprintf("Secure data for user %s", userID), nil + } + // [START auth-define] + genkit.DefineFlow("secureUserFlow", userDataFunc, genkit.WithFlowAuth(firebaseAuth)) + // [END auth-define] + authenticatedFlow := genkit.DefineFlow("your-flow", userDataFunc, genkit.WithFlowAuth(firebaseAuth)) + // [START auth-run] + response, err := authenticatedFlow.Run(ctx, "user123", + genkit.WithLocalAuth(map[string]any{"UID": "user123"})) + // [END auth-run] + _ = response + _ = err +} diff --git a/go/plugins/firebase/auth.go b/go/plugins/firebase/auth.go new file mode 100644 index 000000000..de27795a6 --- /dev/null +++ b/go/plugins/firebase/auth.go @@ -0,0 +1,114 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firebase + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + + "firebase.google.com/go/v4/auth" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/internal/base" +) + +var authContextKey = base.NewContextKey[map[string]any]() + +type AuthClient interface { + VerifyIDToken(context.Context, string) (*auth.Token, error) +} + +// firebaseAuth is a Firebase auth provider. +type firebaseAuth struct { + client AuthClient // Auth client for verifying ID tokens. + policy func(genkit.AuthContext, any) error // Auth policy for checking auth context. + required bool // Whether auth is required for direct calls. +} + +// NewAuth creates a Firebase auth check. +func NewAuth(ctx context.Context, policy func(genkit.AuthContext, any) error, required bool) (genkit.FlowAuth, error) { + app, err := App(ctx) + if err != nil { + return nil, err + } + client, err := app.Auth(ctx) + if err != nil { + return nil, err + } + auth := &firebaseAuth{ + client: client, + policy: policy, + required: required, + } + return auth, nil +} + +// ProvideAuthContext provides auth context from an auth header and sets it on the context. +func (f *firebaseAuth) ProvideAuthContext(ctx context.Context, authHeader string) (context.Context, error) { + if authHeader == "" { + if f.required { + return nil, errors.New("authorization header is required but not provided") + } + return ctx, nil + } + const bearerPrefix = "bearer " + if !strings.HasPrefix(strings.ToLower(authHeader), bearerPrefix) { + return nil, errors.New("invalid authorization header format") + } + token := authHeader[len(bearerPrefix):] + authToken, err := f.client.VerifyIDToken(ctx, token) + if err != nil { + return nil, fmt.Errorf("error verifying ID token: %v", err) + } + authBytes, err := json.Marshal(authToken) + if err != nil { + return nil, err + } + var authContext genkit.AuthContext + if err = json.Unmarshal(authBytes, &authContext); err != nil { + return nil, err + } + return f.NewContext(ctx, authContext), nil +} + +// NewContext sets the auth context on the given context. +func (f *firebaseAuth) NewContext(ctx context.Context, authContext genkit.AuthContext) context.Context { + if ctx == nil { + return nil + } + return authContextKey.NewContext(ctx, authContext) +} + +// FromContext retrieves the auth context from the given context. +func (*firebaseAuth) FromContext(ctx context.Context) genkit.AuthContext { + if ctx == nil { + return nil + } + return authContextKey.FromContext(ctx) +} + +// CheckAuthPolicy checks auth context against policy. +func (f *firebaseAuth) CheckAuthPolicy(ctx context.Context, input any) error { + authContext := f.FromContext(ctx) + if authContext == nil { + if f.required { + return errors.New("auth is required") + } + return nil + } + return f.policy(authContext, input) +} diff --git a/go/plugins/firebase/auth_test.go b/go/plugins/firebase/auth_test.go new file mode 100644 index 000000000..6bd566064 --- /dev/null +++ b/go/plugins/firebase/auth_test.go @@ -0,0 +1,211 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firebase + +import ( + "context" + "errors" + "testing" + + "firebase.google.com/go/v4/auth" + "github.com/firebase/genkit/go/genkit" +) + +type mockAuthClient struct { + verifyIDTokenFunc func(context.Context, string) (*auth.Token, error) +} + +func (m *mockAuthClient) VerifyIDToken(ctx context.Context, token string) (*auth.Token, error) { + return m.verifyIDTokenFunc(ctx, token) +} + +func TestProvideAuthContext(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + tests := []struct { + name string + authHeader string + required bool + mockToken *auth.Token + mockError error + expectedUID string + expectedError string + }{ + { + name: "Valid token", + authHeader: "Bearer validtoken", + required: true, + mockToken: &auth.Token{ + UID: "user123", + Firebase: auth.FirebaseInfo{ + SignInProvider: "custom", + }, + }, + mockError: nil, + expectedUID: "user123", + expectedError: "", + }, + { + name: "Missing header when required", + authHeader: "", + required: true, + expectedUID: "", + expectedError: "authorization header is required but not provided", + }, + { + name: "Missing header when not required", + authHeader: "", + required: false, + expectedUID: "", + expectedError: "", + }, + { + name: "Invalid header format", + authHeader: "InvalidBearer token", + required: true, + expectedUID: "", + expectedError: "invalid authorization header format", + }, + { + name: "Token verification error", + authHeader: "Bearer invalidtoken", + required: true, + mockToken: nil, + mockError: errors.New("invalid token"), + expectedUID: "", + expectedError: "error verifying ID token: invalid token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := &mockAuthClient{ + verifyIDTokenFunc: func(ctx context.Context, token string) (*auth.Token, error) { + if token == "validtoken" { + return tt.mockToken, tt.mockError + } + return nil, tt.mockError + }, + } + + auth := &firebaseAuth{ + client: mockClient, + required: tt.required, + } + + newCtx, err := auth.ProvideAuthContext(ctx, tt.authHeader) + + if tt.expectedError != "" { + if err == nil || err.Error() != tt.expectedError { + t.Errorf("Expected error %q, got %v", tt.expectedError, err) + } + } else if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if tt.expectedUID != "" { + authContext := auth.FromContext(newCtx) + if authContext == nil { + t.Errorf("Expected non-nil auth context") + } else { + uid, ok := authContext["uid"].(string) + if !ok { + t.Errorf("Expected 'uid' to be a string, got %T", authContext["uid"]) + } else if uid != tt.expectedUID { + t.Errorf("Expected UID %q, got %q", tt.expectedUID, uid) + } + } + } else if auth.FromContext(newCtx) != nil && tt.authHeader != "" { + t.Errorf("Expected nil auth context, but got non-nil") + } + }) + } +} + +func TestCheckAuthPolicy(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + authContext genkit.AuthContext + input any + required bool + policy func(genkit.AuthContext, any) error + expectedError string + }{ + { + name: "Valid auth context and policy", + authContext: map[string]any{"uid": "user123"}, + input: "test input", + required: true, + policy: func(authContext genkit.AuthContext, in any) error { + return nil + }, + expectedError: "", + }, + { + name: "Policy error", + authContext: map[string]any{"uid": "user123"}, + input: "test input", + required: true, + policy: func(authContext genkit.AuthContext, in any) error { + return errors.New("policy error") + }, + expectedError: "policy error", + }, + { + name: "Missing auth context when required", + authContext: nil, + input: "test input", + required: true, + policy: nil, + expectedError: "auth is required", + }, + { + name: "Missing auth context when not required", + authContext: nil, + input: "test input", + required: false, + policy: nil, + expectedError: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := &firebaseAuth{ + required: tt.required, + policy: tt.policy, + } + + ctx := context.Background() + if tt.authContext != nil { + ctx = auth.NewContext(ctx, tt.authContext) + } + + err := auth.CheckAuthPolicy(ctx, tt.input) + + if tt.expectedError != "" { + if err == nil || err.Error() != tt.expectedError { + t.Errorf("Expected error %q, got %v", tt.expectedError, err) + } + } else if err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} diff --git a/go/plugins/firebase/firebase.go b/go/plugins/firebase/firebase.go new file mode 100644 index 000000000..e2f4a85b4 --- /dev/null +++ b/go/plugins/firebase/firebase.go @@ -0,0 +1,46 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firebase + +import ( + "context" + "sync" + + firebase "firebase.google.com/go/v4" + "firebase.google.com/go/v4/auth" +) + +type FirebaseApp interface { + Auth(ctx context.Context) (*auth.Client, error) +} + +var ( + app *firebase.App + mutex sync.Mutex +) + +// app returns a cached Firebase app. +func App(ctx context.Context) (FirebaseApp, error) { + mutex.Lock() + defer mutex.Unlock() + if app == nil { + newApp, err := firebase.NewApp(ctx, nil) + if err != nil { + return nil, err + } + app = newApp + } + return app, nil +} diff --git a/go/samples/firebase-auth/main.go b/go/samples/firebase-auth/main.go new file mode 100644 index 000000000..14e7bcd37 --- /dev/null +++ b/go/samples/firebase-auth/main.go @@ -0,0 +1,72 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "errors" + "fmt" + "log" + + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/firebase" +) + +func main() { + ctx := context.Background() + + policy := func(authContext genkit.AuthContext, input any) error { + user := input.(string) + if authContext == nil || authContext["UID"] != user { + return errors.New("user ID does not match") + } + return nil + } + firebaseAuth, err := firebase.NewAuth(ctx, policy, true) + if err != nil { + log.Fatalf("failed to set up Firebase auth: %v", err) + } + + flowWithRequiredAuth := genkit.DefineFlow("flow-with-required-auth", func(ctx context.Context, user string) (string, error) { + return fmt.Sprintf("info about user %q", user), nil + }, genkit.WithFlowAuth(firebaseAuth)) + + firebaseAuth, err = firebase.NewAuth(ctx, policy, false) + if err != nil { + log.Fatalf("failed to set up Firebase auth: %v", err) + } + + flowWithAuth := genkit.DefineFlow("flow-with-auth", func(ctx context.Context, user string) (string, error) { + return fmt.Sprintf("info about user %q", user), nil + }, genkit.WithFlowAuth(firebaseAuth)) + + genkit.DefineFlow("super-caller", func(ctx context.Context, _ struct{}) (string, error) { + // Auth is required so we have to pass local credentials. + resp1, err := flowWithRequiredAuth.Run(ctx, "admin-user", genkit.WithLocalAuth(map[string]any{"UID": "admin-user"})) + if err != nil { + return "", fmt.Errorf("flowWithRequiredAuth: %w", err) + } + // Auth is not required so we can just run the flow. + resp2, err := flowWithAuth.Run(ctx, "admin-user-2") + if err != nil { + return "", fmt.Errorf("flowWithAuth: %w", err) + } + return resp1 + ", " + resp2, nil + }) + + if err := genkit.Init(ctx, nil); err != nil { + log.Fatal(err) + } +}