diff --git a/db/init.sql b/db/init.sql index 7b9532b..91d0257 100644 --- a/db/init.sql +++ b/db/init.sql @@ -17,6 +17,19 @@ CREATE TABLE ollama_embeddings ( metadata JSONB ); +CREATE TABLE bedrock_embeddings_1024 ( + id BIGSERIAL PRIMARY KEY, + doc_id TEXT NOT NULL, + model TEXT NOT NULL, + embedding VECTOR(1024) NOT NULL, + metadata JSONB, + created_at TIMESTAMPTZ DEFAULT now() +); + +CREATE INDEX bedrock_embeddings_1024_ivf_cos ON bedrock_embeddings_1024 +USING ivfflat (embedding vector_cosine_ops) +WITH (lists = 100); + -- Index for efficient vector similarity search for OpenAI embeddings CREATE INDEX openai_embeddings_idx ON openai_embeddings USING ivfflat (embedding vector_cosine_ops) diff --git a/examples/bedrock/README.md b/examples/bedrock/README.md new file mode 100644 index 0000000..9feb74c --- /dev/null +++ b/examples/bedrock/README.md @@ -0,0 +1,24 @@ +# Bedrock Example + +Minimal example that uses **AWS Bedrock** through the `BedrockBackend` to: +1) generate text +2) create embeddings + +## Prerequisites + +- Go 1.22+ +- AWS credentials configured (env/SharedConfig) +- Environment variables: + - `AWS_REGION` (e.g. `us-east-1`) + - `BEDROCK_TEXT_MODEL` (e.g. `amazon.titan-text-express-v1`) + - `BEDROCK_EMBED_MODEL` (e.g. `amazon.titan-embed-text-v1`) + +## Run + +```bash +export AWS_REGION=us-east-1 +export BEDROCK_TEXT_MODEL=amazon.titan-text-express-v1 +export BEDROCK_EMBED_MODEL=amazon.titan-embed-text-v1 + +go run ./examples/bedrock "Explain vector databases in 2 sentences." +``` 
diff --git a/examples/bedrock/main.go b/examples/bedrock/main.go new file mode 100644 index 0000000..fc6ff6c --- /dev/null +++ b/examples/bedrock/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "time" + + "github.com/stackloklabs/gorag/pkg/backend" +) + +func getenv(key string) string { + v := os.Getenv(key) + if v == "" { + log.Fatalf("missing env %s", key) + } + return v +} + +func main() { + region := getenv("AWS_REGION") + textModel := getenv("BEDROCK_TEXT_MODEL") + embedModel := getenv("BEDROCK_EMBED_MODEL") + + prompt := "Explain Retrieval-Augmented Generation (RAG) in 3 sentences." + if len(os.Args) > 1 { + prompt = os.Args[1] + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + br, err := backend.NewBedrockBackend(ctx, region, textModel, embedModel) + if err != nil { + log.Fatalf("init bedrock backend: %v", err) + } + + out, err := br.Generate(ctx, prompt, map[string]any{ + "textGenerationConfig": map[string]any{ + "maxTokenCount": 512, + "temperature": 0.2, + "topP": 0.9, + "stopSequences": []string{}, + }, + }) + + if err != nil { + log.Fatalf("generate: %v", err) + } + fmt.Println("=== Completion ===") + fmt.Println(out) + + vecs, err := br.Embed(ctx, []string{ + "RAG augments LLMs with external knowledge.", + "Vector databases enable efficient similarity search.", + }, nil) + if err != nil { + log.Fatalf("embed: %v", err) + } + fmt.Println("=== Embeddings ===") + fmt.Printf("count=%d dim=%d\n", len(vecs), len(vecs[0])) +} diff --git a/go.mod b/go.mod index ffa3d10..5f406e3 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,21 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2 v1.36.6 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.11 // indirect + github.com/aws/aws-sdk-go-v2/config v1.29.18 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.71 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.33 // indirect + 
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.37 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.37 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.31.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.18 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.25.6 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.34.1 // indirect + github.com/aws/smithy-go v1.22.4 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 // indirect diff --git a/pkg/backend/bedrock_backend.go b/pkg/backend/bedrock_backend.go new file mode 100644 index 0000000..00c53f2 --- /dev/null +++ b/pkg/backend/bedrock_backend.go @@ -0,0 +1,198 @@ +package backend + +import ( + "context" + "encoding/json" + "errors" + "strings" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" +) + +type bedrockInvoker interface { + InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) +} + +type BedrockBackend struct { + client bedrockInvoker + textModel string + embedModel string +} + +func NewBedrockBackend(ctx context.Context, region, textModel, embedModel string) (*BedrockBackend, error) { + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + return nil, err + } + return &BedrockBackend{ + client: bedrockruntime.NewFromConfig(cfg), + textModel: textModel, + embedModel: embedModel, + }, nil +} + +func newBedrockBackendWithClient(textModel, embedModel string, c bedrockInvoker) *BedrockBackend { + return 
&BedrockBackend{ + client: c, + textModel: textModel, + embedModel: embedModel, + } +} + +func (b *BedrockBackend) Generate(ctx context.Context, prompt string, params map[string]any) (string, error) { + if b.textModel == "" { + return "", errors.New("text model is not set") + } + body := map[string]any{ + "inputText": prompt, + } + for k, v := range params { + body[k] = v + } + req, err := json.Marshal(body) + if err != nil { + return "", err + } + out, err := b.client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ + ModelId: awsString(b.textModel), + ContentType: awsString("application/json"), + Body: req, + }) + if err != nil { + return "", err + } + return parseText(out.Body) +} + +func (b *BedrockBackend) GenerateStream(ctx context.Context, prompt string, params map[string]any, onToken func(string) error) error { + txt, err := b.Generate(ctx, prompt, params) + if err != nil { + return err + } + for _, t := range strings.Split(txt, " ") { + if err := onToken(t + " "); err != nil { + return err + } + } + return nil +} + +func (b *BedrockBackend) Embed(ctx context.Context, texts []string, params map[string]any) ([][]float32, error) { + if b.embedModel == "" { + return nil, errors.New("embed model is not set") + } + outVecs := make([][]float32, 0, len(texts)) + for _, t := range texts { + body := map[string]any{ + "inputText": t, + } + for k, v := range params { + body[k] = v + } + req, err := json.Marshal(body) + if err != nil { + return nil, err + } + out, err := b.client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ + ModelId: awsString(b.embedModel), + ContentType: awsString("application/json"), + Body: req, + }) + if err != nil { + return nil, err + } + vec, err := parseEmbedding(out.Body) + if err != nil { + return nil, err + } + outVecs = append(outVecs, vec) + } + return outVecs, nil +} +func parseText(b []byte) (string, error) { + var m map[string]any + if err := json.Unmarshal(b, &m); err != nil { + return "", err + } + + // Titan: 
// firstString returns the value of the first key in keys whose value in m is
// a string, and reports whether one was found.
func firstString(m map[string]any, keys ...string) (string, bool) {
	for _, k := range keys {
		if s, ok := m[k].(string); ok {
			return s, true
		}
	}
	return "", false
}

// firstContentText treats v as a decoded JSON array of content blocks and
// returns the "text" of the first block that has a string "text" field.
// Blocks without text are skipped instead of ending the search.
func firstContentText(v any) (string, bool) {
	blocks, ok := v.([]any)
	if !ok {
		return "", false
	}
	for _, blk := range blocks {
		if m, ok := blk.(map[string]any); ok {
			if s, ok := m["text"].(string); ok {
				return s, true
			}
		}
	}
	return "", false
}

// parseText extracts the generated text from a JSON response body.
//
// Shapes are tried in this order and the first match is returned:
//
//	{"results":[{"outputText":"..."}]}   (Titan)
//	{"results":[{"text":"..."}]}
//	{"results":[{"message":{"content":[{"text":"..."}]}}]}
//	{"outputText":"..."}
//	{"completion":"..."}
//	{"generation":"..."}
//	{"content":[{"text":"..."}]}
//
// It returns the JSON decoding error for malformed input and
// "unexpected response schema" when no shape matches.
func parseText(b []byte) (string, error) {
	var m map[string]any
	if err := json.Unmarshal(b, &m); err != nil {
		return "", err
	}

	if results, ok := m["results"].([]any); ok {
		for _, it := range results {
			entry, ok := it.(map[string]any)
			if !ok {
				continue
			}
			if s, ok := firstString(entry, "outputText", "text"); ok {
				return s, nil
			}
			if msg, ok := entry["message"].(map[string]any); ok {
				if s, ok := firstContentText(msg["content"]); ok {
					return s, nil
				}
			}
		}
	}

	if s, ok := firstString(m, "outputText", "completion", "generation"); ok {
		return s, nil
	}
	if s, ok := firstContentText(m["content"]); ok {
		return s, nil
	}

	return "", errors.New("unexpected response schema")
}

// parseEmbedding extracts an embedding vector from a JSON response body.
//
// A top-level "embedding" array is preferred; otherwise the first entry of a
// top-level "results" array that carries an "embedding" array is used. It
// returns "embedding not found" when neither is present, and the conversion
// error when an element is not numeric.
func parseEmbedding(b []byte) ([]float32, error) {
	var m map[string]any
	if err := json.Unmarshal(b, &m); err != nil {
		return nil, err
	}
	if v, ok := m["embedding"].([]any); ok {
		return toFloat32Slice(v)
	}
	if results, ok := m["results"].([]any); ok {
		for _, it := range results {
			if entry, ok := it.(map[string]any); ok {
				if v, ok := entry["embedding"].([]any); ok {
					return toFloat32Slice(v)
				}
			}
		}
	}
	return nil, errors.New("embedding not found")
}

// toFloat32Slice converts a slice of float64 or float32 values to []float32.
// Any other element type yields "invalid embedding element type".
func toFloat32Slice(v []any) ([]float32, error) {
	res := make([]float32, len(v))
	for i, x := range v {
		switch n := x.(type) {
		case float64:
			// encoding/json decodes every JSON number into float64.
			res[i] = float32(n)
		case float32:
			res[i] = n
		default:
			return nil, errors.New("invalid embedding element type")
		}
	}
	return res, nil
}

// awsString returns a pointer to a copy of s, for the SDK's *string fields.
func awsString(s string) *string {
	return &s
}
+import ( + "context" + "encoding/json" + "errors" + "reflect" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" +) + +type fakeClient struct { + responses [][]byte + err error + calls int + inputs []*bedrockruntime.InvokeModelInput +} + +func (f *fakeClient) InvokeModel(ctx context.Context, in *bedrockruntime.InvokeModelInput, _ ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) { + f.calls++ + f.inputs = append(f.inputs, in) + if f.err != nil { + return nil, f.err + } + body := f.responses[f.calls-1] + return &bedrockruntime.InvokeModelOutput{ + Body: body, + ContentType: awsString("application/json"), + }, nil +} + +func TestAwsString(t *testing.T) { + v := "x" + if got := awsString(v); got == nil || *got != v { + t.Fatalf("awsString failed") + } +} + +func TestToFloat32Slice(t *testing.T) { + got, err := toFloat32Slice([]any{1.2, float32(3.4)}) + if err != nil { + t.Fatalf("err: %v", err) + } + want := []float32{1.2, 3.4} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %v want %v", got, want) + } + _, err = toFloat32Slice([]any{"nope"}) + if err == nil { + t.Fatalf("expected error") + } +} + +func TestParseEmbedding(t *testing.T) { + b1, _ := json.Marshal(map[string]any{"embedding": []any{0.1, 0.2}}) + v, err := parseEmbedding(b1) + if err != nil || !reflect.DeepEqual(v, []float32{0.1, 0.2}) { + t.Fatalf("direct embedding failed: %v %v", v, err) + } + + b2, _ := json.Marshal(map[string]any{"results": []any{ + map[string]any{"embedding": []any{1.0, 2.0, 3.0}}, + }}) + v, err = parseEmbedding(b2) + if err != nil || !reflect.DeepEqual(v, []float32{1, 2, 3}) { + t.Fatalf("results embedding failed: %v %v", v, err) + } + + _, err = parseEmbedding([]byte(`{"foo":"bar"}`)) + if err == nil || err.Error() != "embedding not found" { + t.Fatalf("want embedding not found, got %v", err) + } +} + +func TestGenerate_ErrorWhenNoModel(t *testing.T) { + b := newBedrockBackendWithClient("", "", &fakeClient{}) + _, err 
:= b.Generate(context.Background(), "p", nil) + if err == nil || !strings.Contains(err.Error(), "text model is not set") { + t.Fatalf("expected model not set error") + } +} + +func TestEmbed_ErrorWhenNoModel(t *testing.T) { + b := newBedrockBackendWithClient("m", "", &fakeClient{}) + _, err := b.Embed(context.Background(), []string{"a"}, nil) + if err == nil || !strings.Contains(err.Error(), "embed model is not set") { + t.Fatalf("expected embed model not set error") + } +} + +func TestGenerate_SuccessBranches(t *testing.T) { + makeResp := func(m map[string]any) []byte { b, _ := json.Marshal(m); return b } + tests := []struct { + name string + payload map[string]any + expected string + }{ + {"outputText", map[string]any{"outputText": "ok1"}, "ok1"}, + {"completion", map[string]any{"completion": "ok2"}, "ok2"}, + {"generation", map[string]any{"generation": "ok3"}, "ok3"}, + {"content[0].text", map[string]any{"content": []any{map[string]any{"text": "ok4"}}}, "ok4"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fc := &fakeClient{responses: [][]byte{makeResp(tt.payload)}} + b := newBedrockBackendWithClient("text-model", "embed-model", fc) + got, err := b.Generate(context.Background(), "hello", map[string]any{"temp": 0.1}) + if err != nil { + t.Fatalf("err: %v", err) + } + if got != tt.expected { + t.Fatalf("got %q want %q", got, tt.expected) + } + if fc.calls != 1 { + t.Fatalf("expected 1 call") + } + var in map[string]any + _ = json.Unmarshal(fc.inputs[0].Body, &in) + if in["inputText"] != "hello" { + t.Fatalf("prompt not forwarded") + } + if in["temp"] != 0.1 { + t.Fatalf("params not forwarded") + } + }) + } +} + +func TestGenerate_UnexpectedSchema(t *testing.T) { + resp, _ := json.Marshal(map[string]any{"foo": "bar"}) + fc := &fakeClient{responses: [][]byte{resp}} + b := newBedrockBackendWithClient("text-model", "embed-model", fc) + _, err := b.Generate(context.Background(), "p", nil) + if err == nil || !strings.Contains(err.Error(), 
"unexpected response schema") { + t.Fatalf("expected unexpected response schema, got %v", err) + } +} + +func TestGenerateStream(t *testing.T) { + resp, _ := json.Marshal(map[string]any{"outputText": "a b c"}) + fc := &fakeClient{responses: [][]byte{resp}} + b := newBedrockBackendWithClient("text-model", "embed-model", fc) + var got []string + err := b.GenerateStream(context.Background(), "p", nil, func(s string) error { + got = append(got, s) + return nil + }) + if err != nil { + t.Fatalf("err: %v", err) + } + if strings.Join(got, "") != "a b c " { + t.Fatalf("bad stream: %v", got) + } +} + +func TestEmbed_SuccessAndInvalidElement(t *testing.T) { + okBody, _ := json.Marshal(map[string]any{"embedding": []any{0.1, 0.2}}) + badBody, _ := json.Marshal(map[string]any{"embedding": []any{"x"}}) + fc := &fakeClient{responses: [][]byte{okBody, badBody}} + b := newBedrockBackendWithClient("t", "e", fc) + + vecs, err := b.Embed(context.Background(), []string{"a"}, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + if len(vecs) != 1 || !reflect.DeepEqual(vecs[0], []float32{0.1, 0.2}) { + t.Fatalf("unexpected vecs: %v", vecs) + } + + _, err = b.Embed(context.Background(), []string{"b"}, nil) + if err == nil { + t.Fatalf("expected error on invalid element") + } +} + +func TestFakeClient_PropagatesError(t *testing.T) { + fc := &fakeClient{err: errors.New("boom")} + b := newBedrockBackendWithClient("t", "e", fc) + _, err := b.Generate(context.Background(), "p", nil) + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected boom") + } +}