Skip to content

Commit

Permalink
internal/gaby: search API
Browse files Browse the repository at this point in the history
Add the /api/search route to gaby.

It is like /search, but for programs instead of people.

Fixes golang/go#22.

For golang/go#22.

Change-Id: I17518611d803e35401a08acbb2b215e49d273f9d
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/614677
Reviewed-by: Tatiana Bradley <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
  • Loading branch information
jba committed Sep 23, 2024
1 parent bf51777 commit 88eb357
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 28 deletions.
5 changes: 5 additions & 0 deletions internal/gaby/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,11 @@ func (g *Gaby) newServer(report func(error)) *http.ServeMux {
// /search: display a form for vector similarity search.
// /search?q=...: perform a search using the value of q as input.
mux.HandleFunc("GET /search", g.handleSearch)

// /api/search: perform a vector similarity search.
// POST because the arguments to the request are in the body.
mux.HandleFunc("POST /api/search", g.handleSearchAPI)

return mux
}

Expand Down
137 changes: 114 additions & 23 deletions internal/gaby/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@ package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"net/url"
"path"
"strings"

"github.com/google/safehtml/template"
"golang.org/x/oscar/internal/llm"
Expand All @@ -22,12 +26,6 @@ type searchPage struct {
Results []searchResult
}

type searchResult struct {
Title string
VResult storage.VectorResult
IDIsURL bool // VResult.ID can be parsed as a URL
}

func (g *Gaby) handleSearch(w http.ResponseWriter, r *http.Request) {
data, err := g.doSearch(r)
if err != nil {
Expand All @@ -44,15 +42,12 @@ func (g *Gaby) doSearch(r *http.Request) ([]byte, error) {
}
if page.Query != "" {
var err error
page.Results, err = g.search(page.Query)
page.Results, err = g.search(r.Context(), &searchRequest{EmbedDoc: llm.EmbedDoc{Text: page.Query}})
if err != nil {
return nil, err
}
// Round scores to three decimal places.
const r = 1e3
for i := range page.Results {
sp := &page.Results[i].VResult.Score
*sp = math.Round(*sp*r) / r
page.Results[i].round()
}
}
var buf bytes.Buffer
Expand All @@ -62,28 +57,90 @@ func (g *Gaby) doSearch(r *http.Request) ([]byte, error) {
return buf.Bytes(), nil
}

// Maximum number of search results to return.
const maxResults = 20
type searchRequest struct {
Threshold float64 // lowest score to keep; default 0. Max is 1.
Limit int // max results (fewer if Threshold is set); 0 means no limit
llm.EmbedDoc
}
type searchResult struct {
Kind string // kind of document: issue, doc page, etc.
Title string
storage.VectorResult
}

// Round rounds r.Score to three decimal places.
func (r *searchResult) round() {
r.Score = math.Round(r.Score*1e3) / 1e3
}

// Maximum number of search results to return by default.
const defaultLimit = 20

// search does a search for query over Gaby's vector database.
func (g *Gaby) search(query string) ([]searchResult, error) {
vecs, err := g.embed.EmbedDocs(context.Background(), []llm.EmbedDoc{{Title: "", Text: query}})
func (g *Gaby) search(ctx context.Context, sreq *searchRequest) ([]searchResult, error) {
vecs, err := g.embed.EmbedDocs(ctx, []llm.EmbedDoc{sreq.EmbedDoc})
if err != nil {
return nil, fmt.Errorf("EmbedDocs: %w", err)
}
vec := vecs[0]

limit := defaultLimit
if sreq.Limit > 0 {
limit = sreq.Limit
}
// Search uses normalized dot product, so higher numbers are better.
// Max is 1, min is 0.
threshold := 0.0
if sreq.Threshold > 0 {
threshold = sreq.Threshold
}

var srs []searchResult
for _, r := range g.vector.Search(vec, maxResults) {
title := "?"
for _, r := range g.vector.Search(vec, limit) {
if r.Score < threshold {
break
}
title := ""
if d, ok := g.docs.Get(r.ID); ok {
title = d.Title
}
_, err := url.Parse(r.ID)
srs = append(srs, searchResult{title, r, err == nil})
srs = append(srs, searchResult{
Kind: docIDKind(r.ID),
Title: title,
VectorResult: r,
})
}
return srs, nil
}

// docIDKind determines the kind of document from its ID.
// It returns the empty string if it cannot do so.
func docIDKind(id string) string {
u, err := url.Parse(id)
if err != nil {
return ""
}
hp := path.Join(u.Host, u.Path)
switch {
case strings.HasPrefix(hp, "github.com/golang/go/issues/"):
return "GitHubIssue"
case strings.HasPrefix(hp, "go.dev/wiki/"):
return "GoWiki"
case strings.HasPrefix(hp, "go.dev/doc/"):
return "GoDocumentation"
case strings.HasPrefix(hp, "go.dev/ref/"):
return "GoReference"
case strings.HasPrefix(hp, "go.dev/blog/"):
return "GoBlog"
case strings.HasPrefix(hp, "go.dev/"):
return "GoDevPage"
default:
return ""
}
}

// This template assumes that if a result's Kind is non-empty, it is a URL,
// and vice versa.
var searchPageTmpl = template.Must(template.New("").Parse(`
<!doctype html>
<html>
Expand All @@ -110,16 +167,50 @@ var searchPageTmpl = template.Must(template.New("").Parse(`
{{with .Results -}}
{{- range . -}}
<p>{{with .Title}}{{.}}: {{end -}}
{{if .IDIsURL -}}
{{with .VResult}}<a href="{{.ID}}">{{.ID}}</a>{{end -}}
{{if .Kind -}}
<a href="{{.ID}}">{{.ID}}</a>
{{else -}}
{{.VResult.ID -}}
{{.ID -}}
{{end -}}
{{" "}}({{.VResult.Score}})</p>
{{" "}}({{.Score}})</p>
{{end}}
{{- else -}}
{{if .Query}}No results.{{end}}
{{- end}}
</body>
</html>
`))

func (g *Gaby) handleSearchAPI(w http.ResponseWriter, r *http.Request) {
sreq, err := readJSONBody[searchRequest](r)
if err != nil {
// The error could also come from failing to read the body, but then the
// connection is probably broken so it doesn't matter what status we send.
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
sres, err := g.search(r.Context(), sreq)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
data, err := json.Marshal(sres)
if err != nil {
http.Error(w, "json.Marshal: "+err.Error(), http.StatusInternalServerError)
return
}
_, _ = w.Write(data)
}

func readJSONBody[T any](r *http.Request) (*T, error) {
defer r.Body.Close()
data, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
t := new(T)
if err := json.Unmarshal(data, t); err != nil {
return nil, err
}
return t, nil
}
126 changes: 121 additions & 5 deletions internal/gaby/search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,37 @@ package main

import (
"bytes"
"context"
"encoding/json"
"fmt"
"slices"
"strings"
"testing"

"golang.org/x/oscar/internal/docs"
"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/testutil"
)

func TestSearchPageTemplate(t *testing.T) {
page := searchPage{
Query: "some query",
Results: []searchResult{
{
Kind: "Example",
Title: "t1",
VResult: storage.VectorResult{
VectorResult: storage.VectorResult{
ID: "https://example.com/x",
Score: 0.987654321,
},
IDIsURL: true,
},
{
VResult: storage.VectorResult{
Kind: "",
VectorResult: storage.VectorResult{
ID: "https://example.com/y",
Score: 0.876,
},
IDIsURL: false,
},
},
}
Expand All @@ -40,7 +47,7 @@ func TestSearchPageTemplate(t *testing.T) {
}
wants := []string{page.Query}
for _, sr := range page.Results {
wants = append(wants, sr.VResult.ID)
wants = append(wants, sr.VectorResult.ID)
}
got := buf.String()
t.Logf("%s", got)
Expand All @@ -50,3 +57,112 @@ func TestSearchPageTemplate(t *testing.T) {
}
}
}

func TestKind(t *testing.T) {
for _, test := range []struct {
id, want string
}{
{"something", ""},
{"https://go.dev/x", "GoDevPage"},
{"https://go.dev/blog/xxx", "GoBlog"},
{"https://go.dev/doc/x", "GoDocumentation"},
{"https://go.dev/ref/x", "GoReference"},
{"https://go.dev/wiki/x", "GoWiki"},
{"https://github.com/golang/go/issues/123", "GitHubIssue"},
} {
got := docIDKind(test.id)
if got != test.want {
t.Errorf("%q: got %q, want %q", test.id, got, test.want)
}
}
}

func TestSearch(t *testing.T) {
ctx := context.Background()
lg := testutil.Slogger(t)
embedder := llm.QuoteEmbedder()
db := storage.MemDB()
vdb := storage.MemVectorDB(db, lg, "")
corpus := docs.New(db)

for i := 0; i < 10; i++ {
id := fmt.Sprintf("id%d", i)
doc := llm.EmbedDoc{Title: fmt.Sprintf("title%d", i), Text: fmt.Sprintf("text-%s", strings.Repeat("x", i))}
corpus.Add(id, doc.Title, doc.Text)
vec, err := embedder.EmbedDocs(ctx, []llm.EmbedDoc{doc})
if err != nil {
t.Fatal(err)
}
vdb.Set(id, vec[0])
}
g := &Gaby{
embed: embedder,
db: db,
vector: vdb,
docs: corpus,
}
sreq := &searchRequest{
Threshold: 0,
Limit: 2,
EmbedDoc: llm.EmbedDoc{Title: "title3", Text: "text-xxx"},
}
sres, err := g.search(ctx, sreq)
if err != nil {
t.Fatal(err)
}
for i := range sres {
sres[i].round()
}

want := []searchResult{
{
Kind: "",
Title: "title3",
VectorResult: storage.VectorResult{ID: "id3", Score: 1.0},
},
{
Kind: "",
Title: "title4",
VectorResult: storage.VectorResult{ID: "id4", Score: 0.56},
},
}

if !slices.Equal(sres, want) {
t.Errorf("got %v\nwant %v", sres, want)
}

sreq.Threshold = 0.9
sres, err = g.search(ctx, sreq)
if err != nil {
t.Fatal(err)
}
if len(sres) != 1 {
t.Errorf("got %d results, want 1", len(sres))
}
}

func TestSearchJSON(t *testing.T) {
// Confirm that we can unmarshal a search request, and marshal a response.
postBody := `{"Limit": 10, "Threshold": 0.8, "Title": "t", "Text": "text"}`
var gotReq searchRequest
if err := json.Unmarshal([]byte(postBody), &gotReq); err != nil {
t.Fatal(err)
}
wantReq := searchRequest{Limit: 10, Threshold: 0.8, EmbedDoc: llm.EmbedDoc{Title: "t", Text: "text"}}
if gotReq != wantReq {
t.Errorf("got %+v, want %+v", gotReq, wantReq)
}

res := []searchResult{
{Kind: "K", Title: "t", VectorResult: storage.VectorResult{ID: "id", Score: .5}},
}
bytes, err := json.Marshal(res)
if err != nil {
t.Fatal(err)
}
got := string(bytes)
want := `[{"Kind":"K","Title":"t","ID":"id","Score":0.5}]`
if got != want {
t.Errorf("\ngot %s\nwant %s", got, want)
}
}

0 comments on commit 88eb357

Please sign in to comment.