Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] Added support for flow auth and Firebase auth plugin. #722

Merged
merged 32 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
e3c5b76
Initial implementation of flow auth.
apascal07 Jul 31, 2024
bb0de8e
Update flow.go
apascal07 Jul 31, 2024
8aeb4c6
Added Firebase auth plugin.
apascal07 Aug 1, 2024
0fe14f6
Cleaned up interface.
apascal07 Aug 1, 2024
0195751
Update firebase.go
apascal07 Aug 1, 2024
aa0e987
Removed flow auth methods.
apascal07 Aug 1, 2024
323ca7d
Update servers_test.go
apascal07 Aug 1, 2024
9e83262
Generics overhaul.
apascal07 Aug 1, 2024
98a4f34
Added no-auth option.
apascal07 Aug 2, 2024
63cd6e0
Renames.
apascal07 Aug 2, 2024
fbba953
Update main.go
apascal07 Aug 2, 2024
f1bbcf1
Added an option of more define methods.
apascal07 Aug 2, 2024
c9f520c
Update flow.go
apascal07 Aug 2, 2024
398ebfc
Strongly typed auth, any input.
apascal07 Aug 2, 2024
9a253de
Changed FlowAuth to be map[string]any.
apascal07 Aug 3, 2024
63f47d5
Renamed Firebase auth file.
apascal07 Aug 5, 2024
d8e0e04
Added Firebase auth tests.
apascal07 Aug 5, 2024
78715c3
Update main.go
apascal07 Aug 5, 2024
16d0dd1
Update flow.go
apascal07 Aug 5, 2024
cd54567
Update flow.go
apascal07 Aug 6, 2024
8a83700
Added auth context on context.
apascal07 Aug 7, 2024
28736bb
Update flow.go
apascal07 Aug 7, 2024
d9f96fe
Resolved review comments.
apascal07 Aug 7, 2024
c235f8d
Replaced map[string]any with AuthContext.
apascal07 Aug 7, 2024
d5a9c1b
Update flow.go
apascal07 Aug 7, 2024
d97801f
Added docs.
apascal07 Aug 7, 2024
9d062cb
Added docs.
apascal07 Aug 7, 2024
d6bd88b
Update flows.go
apascal07 Aug 7, 2024
fd03d08
Moved auth docs to separate page.
apascal07 Aug 8, 2024
b2feede
Merge branch 'main' into ap-go-flow
apascal07 Aug 8, 2024
ed0a8f7
Fix.
apascal07 Aug 8, 2024
bd4d917
Update auth_test.go
apascal07 Aug 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 114 additions & 17 deletions go/genkit/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strconv"
"sync"
Expand Down Expand Up @@ -98,28 +99,74 @@ type Flow[In, Out, Stream any] struct {
tstate *tracing.State // set from the action when the flow is defined
inputSchema *jsonschema.Schema // Schema of the input to the flow
outputSchema *jsonschema.Schema // Schema of the output out of the flow
auth FlowAuth
// TODO: scheduler
// TODO: experimentalDurable
// TODO: authPolicy
// TODO: middleware
}

// runOptions configures a single flow run.
type runOptions struct {
authContext map[string]any // Auth context to pass to auth policy checker when calling a flow directly.
}

// flowOptions configures a flow.
type flowOptions struct {
auth FlowAuth // Auth provider and policy checker for the flow.
}

type noStream = func(context.Context, struct{}) error

// FlowAuth configures a auth context provider and a auth policy check for a flow.
type FlowAuth interface {
// ProvideAuthContext provides auth context from an auth header.
ProvideAuthContext(ctx context.Context, authHeader string) (map[string]any, error)

// CheckAuthPolicy checks auth context against policy.
CheckAuthPolicy(auth map[string]any, input any) error
}

// streamingCallback is the type of streaming callbacks.
type streamingCallback[Stream any] func(context.Context, Stream) error

// FlowOption modifies the flow with the provided option.
type FlowOption func(opts *flowOptions)

// FlowRunOption modifies a flow run with the provided option.
type FlowRunOption func(opts *runOptions)

// WithFlowAuth sets an auth provider and policy checker for the flow.
func WithFlowAuth(auth FlowAuth) FlowOption {
return func(f *flowOptions) {
if f.auth != nil {
log.Panic("auth already set in flow")
}
f.auth = auth
}
}

// WithLocalAuth configures an option to run or stream a flow with a local auth value.
func WithLocalAuth(authContext map[string]any) FlowRunOption {
return func(opts *runOptions) {
if opts.authContext != nil {
log.Panic("authContext already set in runOptions")
}
opts.authContext = authContext
}
}

// DefineFlow creates a Flow that runs fn, and registers it as an action.
//
// fn takes an input of type In and returns an output of type Out.
func DefineFlow[In, Out any](
name string,
fn func(ctx context.Context, input In) (Out, error),
opts ...FlowOption,
) *Flow[In, Out, struct{}] {
return defineFlow(registry.Global, name, core.Func[In, Out, struct{}](
func(ctx context.Context, input In, cb func(ctx context.Context, _ struct{}) error) (Out, error) {
return fn(ctx, input)
}))
}), opts...)
}

// DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action.
Expand All @@ -134,11 +181,12 @@ func DefineFlow[In, Out any](
func DefineStreamingFlow[In, Out, Stream any](
name string,
fn func(ctx context.Context, input In, callback func(context.Context, Stream) error) (Out, error),
opts ...FlowOption,
) *Flow[In, Out, Stream] {
return defineFlow(registry.Global, name, core.Func[In, Out, Stream](fn))
return defineFlow(registry.Global, name, core.Func[In, Out, Stream](fn), opts...)
}

func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.Func[In, Out, Stream]) *Flow[In, Out, Stream] {
func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.Func[In, Out, Stream], opts ...FlowOption) *Flow[In, Out, Stream] {
var i In
var o Out
f := &Flow[In, Out, Stream]{
Expand All @@ -148,12 +196,24 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.
outputSchema: base.InferJSONSchema(o),
// TODO: set stateStore?
}
flowOpts := &flowOptions{}
for _, opt := range opts {
opt(flowOpts)
}
f.auth = flowOpts.auth
metadata := map[string]any{
"inputSchema": f.inputSchema,
"outputSchema": f.outputSchema,
"requiresAuth": f.auth != nil,
}
afunc := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) {
tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true")
// Only non-durable flows have an auth policy so can safely assume Start.Input.
if inst.Start != nil {
if err := f.checkAuthPolicy(inst.Auth, any(inst.Start.Input)); err != nil {
return nil, err
}
}
return f.runInstruction(ctx, inst, streamingCallback[Stream](cb))
}
core.DefineActionInRegistry(r, "", f.name, atype.Flow, metadata, nil, afunc)
Expand All @@ -167,18 +227,19 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.
// A flowInstruction is an instruction to follow with a flow.
// It is the input for the flow's action.
// Exactly one field will be non-nil.
type flowInstruction[I any] struct {
Start *startInstruction[I] `json:"start,omitempty"`
type flowInstruction[In any] struct {
Start *startInstruction[In] `json:"start,omitempty"`
Resume *resumeInstruction `json:"resume,omitempty"`
Schedule *scheduleInstruction[I] `json:"schedule,omitempty"`
Schedule *scheduleInstruction[In] `json:"schedule,omitempty"`
RunScheduled *runScheduledInstruction `json:"runScheduled,omitempty"`
State *stateInstruction `json:"state,omitempty"`
Retry *retryInstruction `json:"retry,omitempty"`
Auth map[string]any `json:"auth,omitempty"`
}

// A startInstruction starts a flow.
type startInstruction[I any] struct {
Input I `json:"input,omitempty"`
type startInstruction[In any] struct {
Input In `json:"input,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
}

Expand All @@ -189,9 +250,9 @@ type resumeInstruction struct {
}

// A scheduleInstruction schedules a flow to start at a later time.
type scheduleInstruction[I any] struct {
type scheduleInstruction[In any] struct {
DelaySecs float64 `json:"delay,omitempty"`
Input I `json:"input,omitempty"`
Input In `json:"input,omitempty"`
}

// A runScheduledInstruction starts a scheduled flow.
Expand Down Expand Up @@ -324,7 +385,7 @@ func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowIn
// Name returns the name that the flow was defined with.
func (f *Flow[In, Out, Stream]) Name() string { return f.name }

func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) {
func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, authHeader string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) {
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
if err := base.ValidateJSON(input, f.inputSchema); err != nil {
return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err}
Expand All @@ -333,6 +394,13 @@ func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessa
if err := json.Unmarshal(input, &in); err != nil {
return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err}
}
authContext, err := f.provideAuthContext(ctx, authHeader)
if err != nil {
return nil, &base.HTTPError{Code: http.StatusUnauthorized, Err: err}
}
if err := f.checkAuthPolicy(authContext, in); err != nil {
return nil, &base.HTTPError{Code: http.StatusForbidden, Err: err}
}
// If there is a callback, wrap it to turn an S into a json.RawMessage.
var callback streamingCallback[Stream]
if cb != nil {
Expand Down Expand Up @@ -361,6 +429,28 @@ func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessa
return json.Marshal(res.Response)
}

// provideAuthContext provides auth context for the given auth header if flow auth is configured.
func (f *Flow[In, Out, Stream]) provideAuthContext(ctx context.Context, authHeader string) (map[string]any, error) {
if f.auth != nil {
authContext, err := f.auth.ProvideAuthContext(ctx, authHeader)
if err != nil {
return nil, fmt.Errorf("unauthorized: %w", err)
}
return authContext, nil
}
return nil, nil
}

// checkAuthPolicy checks auth context against the policy if flow auth is configured.
func (f *Flow[In, Out, Stream]) checkAuthPolicy(authContext map[string]any, input any) error {
if f.auth != nil {
if err := f.auth.CheckAuthPolicy(authContext, input); err != nil {
return fmt.Errorf("permission denied for resource: %w", err)
}
}
return nil
}

// start starts executing the flow with the given input.
func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb streamingCallback[Stream]) (_ *flowState[In, Out], err error) {
flowID, err := generateFlowID()
Expand Down Expand Up @@ -569,11 +659,18 @@ func Run[Out any](ctx context.Context, name string, f func() (Out, error)) (Out,

// Run runs the flow in the context of another flow. The flow must run to completion when started
// (that is, it must not have interrupts).
func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In) (Out, error) {
return f.run(ctx, input, nil)
func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In, opts ...FlowRunOption) (Out, error) {
return f.run(ctx, input, nil, opts...)
}

func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) {
func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(context.Context, Stream) error, opts ...FlowRunOption) (Out, error) {
runOpts := &runOptions{}
for _, opt := range opts {
opt(runOpts)
}
if err := f.checkAuthPolicy(runOpts.authContext, input); err != nil {
return base.Zero[Out](), err
}
state, err := f.start(ctx, input, cb)
if err != nil {
return base.Zero[Out](), err
Expand Down Expand Up @@ -602,7 +699,7 @@ type StreamFlowValue[Out, Stream any] struct {
// again.
//
// Otherwise the Stream field of the passed [StreamFlowValue] holds a streamed result.
func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(*StreamFlowValue[Out, Stream], error) bool) {
func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...FlowRunOption) func(func(*StreamFlowValue[Out, Stream], error) bool) {
return func(yield func(*StreamFlowValue[Out, Stream], error) bool) {
cb := func(ctx context.Context, s Stream) error {
if ctx.Err() != nil {
Expand All @@ -613,7 +710,7 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(
}
return nil
}
output, err := f.run(ctx, input, cb)
output, err := f.run(ctx, input, cb, opts...)
if err != nil {
yield(nil, err)
} else {
Expand Down
78 changes: 32 additions & 46 deletions go/genkit/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,13 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"net/http"
"os"
"os/signal"
"strconv"
"sync"
"sync/atomic"
"syscall"
"time"

"github.com/firebase/genkit/go/core/logger"
Expand Down Expand Up @@ -76,7 +73,7 @@ type flow interface {

// runJSON uses encoding/json to unmarshal the input,
// calls Flow.start, then returns the marshaled result.
runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error)
runJSON(ctx context.Context, authHeader string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error)
}

// startServer starts an HTTP server listening on the address.
Expand Down Expand Up @@ -163,19 +160,17 @@ func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) erro
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
return &base.HTTPError{Code: http.StatusBadRequest, Err: err}
}
stream := false
if s := r.FormValue("stream"); s != "" {
var err error
stream, err = strconv.ParseBool(s)
if err != nil {
return err
}
stream, err := parseBoolQueryParam(r, "stream")
if err != nil {
return err
}
logger.FromContext(ctx).Debug("running action",
"key", body.Key,
"stream", stream)
var callback streamingCallback[json.RawMessage]
if stream {
w.Header().Set("Content-Type", "text/plain")
w.Header().Set("Transfer-Encoding", "chunked")
// Stream results are newline-separated JSON.
callback = func(ctx context.Context, msg json.RawMessage) error {
_, err := fmt.Fprintf(w, "%s\n", msg)
Expand Down Expand Up @@ -328,29 +323,42 @@ func newFlowServeMux(r *registry.Registry, flows []string) *http.ServeMux {

func nonDurableFlowHandler(f flow) func(http.ResponseWriter, *http.Request) error {
return func(w http.ResponseWriter, r *http.Request) error {
var body struct {
Data json.RawMessage `json:"data"`
}
defer r.Body.Close()
input, err := io.ReadAll(r.Body)
if err != nil {
return err
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
return &base.HTTPError{Code: http.StatusBadRequest, Err: err}
}
stream, err := parseBoolQueryParam(r, "stream")
if err != nil {
return err
}
var callback streamingCallback[json.RawMessage]
if stream {
// TODO: implement streaming.
return &base.HTTPError{Code: http.StatusNotImplemented, Err: errors.New("streaming")}
} else {
// TODO: telemetry
out, err := f.runJSON(r.Context(), json.RawMessage(input), nil)
if err != nil {
return err
w.Header().Set("Content-Type", "text/plain")
w.Header().Set("Transfer-Encoding", "chunked")
// Stream results are newline-separated JSON.
callback = func(ctx context.Context, msg json.RawMessage) error {
_, err := fmt.Fprintf(w, "%s\n", msg)
if err != nil {
return err
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
return nil
}
// Responses for non-streaming, non-durable flows are passed back
// with the flow result stored in a field called "result."
_, err = fmt.Fprintf(w, `{"result": %s}\n`, out)
}
// TODO: telemetry
out, err := f.runJSON(r.Context(), r.Header.Get("Authorization"), body.Data, callback)
if err != nil {
return err
}
// Responses for non-streaming, non-durable flows are passed back
// with the flow result stored in a field called "result."
_, err = fmt.Fprintf(w, `{"result": %s}\n`, out)
return err
}
}

Expand All @@ -365,28 +373,6 @@ func serverAddress(arg, envVar, defaultValue string) string {
return defaultValue
}

func listenAndServe(addr string, mux *http.ServeMux) error {
server := &http.Server{
Addr: addr,
Handler: mux,
}
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGTERM)
go func() {
<-sigCh
slog.Info("received SIGTERM, shutting down server")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
slog.Error("server shutdown failed", "err", err)
} else {
slog.Info("server shutdown successfully")
}
}()
slog.Info("listening", "addr", addr)
return server.ListenAndServe()
}

// requestID is a unique ID for each request.
var requestID atomic.Int64

Expand Down
Loading
Loading