Skip to content

Commit

Permalink
Merge pull request #64 from erikdubbelboer/NewEmbeddingFuncOllamaWithURL
Browse files Browse the repository at this point in the history
Take Ollama base URL as parameter
  • Loading branch information
philippgille authored Apr 21, 2024
2 parents 36a6eb3 + f51b592 commit d56117d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
12 changes: 8 additions & 4 deletions embed_ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ import (
"sync"
)

// TODO: Turn into const and use as default, but allow user to pass custom URL
// as well as custom API key, in case Ollama runs on a remote (secured) server.
var baseURLOllama = "http://localhost:11434/api"
const defaultBaseURLOllama = "http://localhost:11434/api"

type ollamaResponse struct {
Embedding []float32 `json:"embedding"`
Expand All @@ -23,7 +21,13 @@ type ollamaResponse struct {
// using Ollama's embedding API. You can pass any model that Ollama supports and
// that supports embeddings. A good one as of 2024-03-02 is "nomic-embed-text".
// See https://ollama.com/library/nomic-embed-text
func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
// baseURLOllama is the base URL of the Ollama API. If it's empty,
// "http://localhost:11434/api" is used.
func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc {
if baseURLOllama == "" {
baseURLOllama = defaultBaseURLOllama
}

// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
// and it might have to be a long timeout, depending on the text length.
Expand Down
5 changes: 1 addition & 4 deletions embed_ollama_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,8 @@ func TestNewEmbeddingFuncOllama(t *testing.T) {
if err != nil {
t.Fatal("unexpected error:", err)
}
// TODO: It's bad to overwrite a global var for testing. Follow-up with a change
// to allow passing custom URLs to the function.
baseURLOllama = strings.Replace(baseURLOllama, "11434", u.Port(), 1)

f := NewEmbeddingFuncOllama(model)
f := NewEmbeddingFuncOllama(model, strings.Replace(defaultBaseURLOllama, "11434", u.Port(), 1))
res, err := f(context.Background(), prompt)
if err != nil {
t.Fatal("expected nil, got", err)
Expand Down
2 changes: 1 addition & 1 deletion examples/rag-wikipedia-ollama/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func main() {
// variable to be set.
// For this example we choose to use a locally running embedding model though.
// It requires Ollama to serve its API at "http://localhost:11434/api".
collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel))
collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel, ""))
if err != nil {
panic(err)
}
Expand Down

0 comments on commit d56117d

Please sign in to comment.