Add support for the Ollama backend
This commit introduces support for the ollama.ai backend. Ollama is
an open source LLM backend meant for local usage. It supports many
different models, many of them geared towards code generation.

Usage of the Ollama backend is enabled via the `--backend` flag or
the `AIAC_BACKEND` environment variable set to `"ollama"`. Ollama
doesn't currently support authentication, so the only related flag
is `--ollama-url`, which sets the API server's URL; when omitted,
the default URL (http://localhost:11434/api) is used. With this
commit, `aiac` does not yet support a scenario where the API server
is running behind an authenticating proxy.

Resolves: #77
ido50 committed Feb 6, 2024
1 parent 9d24ddf commit bfc9a45
Showing 8 changed files with 321 additions and 9 deletions.
21 changes: 19 additions & 2 deletions README.md
@@ -33,8 +33,8 @@ Generator.
## Description

`aiac` is a command line tool to generate IaC (Infrastructure as Code) templates,
-configurations, utilities, queries and more via [LLM](https://en.wikipedia.org/wiki/Large_language_model) providers such as [OpenAI](https://openai.com/)
-and [Amazon Bedrock](https://aws.amazon.com/bedrock/). The CLI allows you to ask a model to generate templates
+configurations, utilities, queries and more via [LLM](https://en.wikipedia.org/wiki/Large_language_model) providers such as [OpenAI](https://openai.com/),
+[Amazon Bedrock](https://aws.amazon.com/bedrock/) and [Ollama](https://ollama.ai/). The CLI allows you to ask a model to generate templates
for different scenarios (e.g. "get terraform for AWS EC2"). It composes an
appropriate request to the selected provider, and stores the resulting code to
a file, and/or prints it to standard output.
@@ -90,6 +90,11 @@ For **Amazon Bedrock**, you will need an AWS account with Bedrock enabled, and
access to relevant models (currently Amazon Titan and Anthropic Claude are
supported by `aiac`). Refer to the [Bedrock documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) for more information.

For **Ollama**, you only need the URL to the local Ollama API server, including
the /api path prefix. This defaults to http://localhost:11434/api. Ollama does
not provide an authentication mechanism, though one may be in place if the API
server is running behind a proxy; this scenario is not currently supported by
`aiac`.

### Installation

Via `brew`:
@@ -132,6 +137,12 @@ For **Amazon Bedrock**:
command line flags. These values default to "default" and "us-east-1",
respectively.

For **Ollama**:

1. Nothing is needed except the URL to the API server, if the default one is
   not used. Provide it via the `--ollama-url` flag or the `OLLAMA_API_URL`
   environment variable, and remember to include the /api path prefix.

#### Command Line

By default, aiac prints the extracted code to standard output and opens an
@@ -186,6 +197,12 @@ the AWS region and profile:

The default model when using Bedrock is "amazon.titan-text-lite-v1".

To generate code via Ollama, provide the `--backend` flag:

aiac get terraform for eks --backend=ollama

The default model when using Ollama is "codellama".

#### Via Docker

All the same instructions apply, except you execute a `docker` image:
20 changes: 19 additions & 1 deletion libaiac/libaiac.go
@@ -8,6 +8,7 @@ import (
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gofireflyio/aiac/v4/libaiac/bedrock"
"github.com/gofireflyio/aiac/v4/libaiac/ollama"
"github.com/gofireflyio/aiac/v4/libaiac/openai"
"github.com/gofireflyio/aiac/v4/libaiac/types"
)
@@ -32,6 +33,9 @@ const (

// BackendBedrock represents the Amazon Bedrock LLM provider.
BackendBedrock BackendName = "bedrock"

// BackendOllama represents the Ollama LLM provider.
BackendOllama BackendName = "ollama"
)

// Decode is used by the kong library to map CLI-provided values to the Model
@@ -49,6 +53,8 @@ func (b *BackendName) Decode(ctx *kong.DecodeContext) error {
*b = BackendOpenAI
case string(BackendBedrock):
*b = BackendBedrock
case string(BackendOllama):
*b = BackendOllama
default:
return fmt.Errorf("%w %s", types.ErrUnsupportedBackend, provided)
}
@@ -60,7 +66,7 @@ func (b *BackendName) Decode(ctx *kong.DecodeContext) error {
// constructor.
type NewClientOptions struct {
// Backend is the name of the backend to use. Use the available constants,
-// e.g. BackendOpenAI or BackendBedrock. Defaults to openai.
+// e.g. BackendOpenAI, BackendBedrock or BackendOllama. Defaults to openai.
Backend BackendName

// ----------------------
@@ -86,6 +92,14 @@ type NewClientOptions struct {

// AWSProfile is the name of the AWS profile to use. Defaults to "default".
AWSProfile string

// ---------------------
// Ollama configuration
// ---------------------

// OllamaURL is the URL to the Ollama API server, including the /api path
// prefix. Defaults to http://localhost:11434/api.
OllamaURL string
}

const (
@@ -118,6 +132,10 @@ func NewClient(opts *NewClientOptions) *Client {
Credentials: cfg.Credentials,
Region: opts.AWSRegion,
})
case BackendOllama:
backend = ollama.NewClient(&ollama.NewClientOptions{
URL: opts.OllamaURL,
})
default:
// default to openai
backend = openai.NewClient(&openai.NewClientOptions{
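
For illustration only (not part of this commit's diff), selecting the new backend through the `libaiac` package might look roughly like the following sketch. It assumes the module import path used elsewhere in this commit and the exported names shown above (`NewClient`, `NewClientOptions`, `BackendOllama`, `OllamaURL`):

```go
package main

import "github.com/gofireflyio/aiac/v4/libaiac"

func main() {
	// Select the Ollama backend; OllamaURL may be left empty to fall back
	// to the default http://localhost:11434/api.
	client := libaiac.NewClient(&libaiac.NewClientOptions{
		Backend:   libaiac.BackendOllama,
		OllamaURL: "http://localhost:11434/api",
	})

	// The returned *libaiac.Client is then used exactly as with the
	// OpenAI or Bedrock backends.
	_ = client
}
```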
87 changes: 87 additions & 0 deletions libaiac/ollama/chat.go
@@ -0,0 +1,87 @@
package ollama

import (
"context"
"fmt"
"strings"

"github.com/gofireflyio/aiac/v4/libaiac/types"
)

// Conversation is a struct used to converse with an Ollama chat model. It
// maintains all messages sent/received in order to maintain context just like
// using ChatGPT.
type Conversation struct {
client *Client
model types.Model
messages []types.Message
}

type chatResponse struct {
Message types.Message `json:"message"`
Done bool `json:"done"`
}

// Chat initiates a conversation with an Ollama chat model. A conversation
// maintains context, allowing you to send further instructions to modify the output
// from previous requests, just like using the ChatGPT website.
func (client *Client) Chat(model types.Model) types.Conversation {
if model.Type != types.ModelTypeChat {
return nil
}

return &Conversation{
client: client,
model: model,
}
}

// Send sends the provided message to the API and returns a Response object.
// To maintain context, all previous messages (whether from you to the API or
// vice-versa) are sent as well, allowing you to ask the API to modify the
// code it already generated.
func (conv *Conversation) Send(ctx context.Context, prompt string, msgs ...types.Message) (
res types.Response,
err error,
) {
var answer chatResponse

if len(msgs) > 0 {
conv.messages = append(conv.messages, msgs...)
}

conv.messages = append(conv.messages, types.Message{
Role: "user",
Content: prompt,
})

err = conv.client.NewRequest("POST", "/chat").
JSONBody(map[string]interface{}{
"model": conv.model.Name,
"messages": conv.messages,
"options": map[string]interface{}{
"temperature": 0.2,
},
"stream": false,
}).
Into(&answer).
RunContext(ctx)
if err != nil {
return res, fmt.Errorf("failed sending prompt: %w", err)
}

if !answer.Done {
return res, fmt.Errorf("%w: unexpected truncated response", types.ErrResultTruncated)
}

conv.messages = append(conv.messages, answer.Message)

res.FullOutput = strings.TrimSpace(answer.Message.Content)

var ok bool
if res.Code, ok = types.ExtractCode(res.FullOutput); !ok {
res.Code = res.FullOutput
}

return res, nil
}
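
As a usage illustration (again, not part of the diff), driving the conversation API in the `ollama` package directly might look like this minimal sketch. It assumes the exported names above, the `ModelCodeLlama` model defined in models.go below, and that the `types.Conversation` interface exposes the same `Send` signature as the concrete type:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/gofireflyio/aiac/v4/libaiac/ollama"
)

func main() {
	// An empty URL falls back to the default http://localhost:11434/api.
	client := ollama.NewClient(&ollama.NewClientOptions{})

	// Chat returns a conversation that accumulates message history.
	conv := client.Chat(ollama.ModelCodeLlama)

	res, err := conv.Send(context.Background(), "generate terraform for an AWS S3 bucket")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(res.Code)

	// Follow-up messages reuse the accumulated context, so the model can
	// modify the code it already generated.
	res, err = conv.Send(context.Background(), "now enable versioning on the bucket")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(res.Code)
}
```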
69 changes: 69 additions & 0 deletions libaiac/ollama/client.go
@@ -0,0 +1,69 @@
package ollama

import (
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/gofireflyio/aiac/v4/libaiac/types"
"github.com/ido50/requests"
)

// DefaultAPIURL is the default URL for a local Ollama API server
const DefaultAPIURL = "http://localhost:11434/api"

// Client is a structure used to continuously generate IaC code via Ollama
type Client struct {
*requests.HTTPClient
}

// NewClientOptions is a struct containing all the parameters accepted by the
// NewClient constructor.
type NewClientOptions struct {
// URL is the URL of the API server (including the /api path prefix). Defaults to DefaultAPIURL.
URL string
}

// NewClient creates a new instance of the Client struct, with the provided
// input options. The Ollama API server is not contacted at this point.
func NewClient(opts *NewClientOptions) *Client {
if opts == nil {
opts = &NewClientOptions{}
}

if opts.URL == "" {
opts.URL = DefaultAPIURL
}

cli := &Client{}

cli.HTTPClient = requests.NewClient(opts.URL).
Accept("application/json").
ErrorHandler(func(
httpStatus int,
contentType string,
body io.Reader,
) error {
var res struct {
Error string `json:"error"`
}

err := json.NewDecoder(body).Decode(&res)
if err != nil {
return fmt.Errorf(
"%w %s",
types.ErrUnexpectedStatus,
http.StatusText(httpStatus),
)
}

return fmt.Errorf(
"%w: %s",
types.ErrRequestFailed,
res.Error,
)
})

return cli
}
52 changes: 52 additions & 0 deletions libaiac/ollama/completion.go
@@ -0,0 +1,52 @@
package ollama

import (
"context"
"fmt"
"strings"

"github.com/gofireflyio/aiac/v4/libaiac/types"
)

type completionResponse struct {
Response string `json:"response"`
Done bool `json:"done"`
}

// Complete sends a request to Ollama's generate API using the provided model
// and prompt, and returns the response.
func (client *Client) Complete(
ctx context.Context,
model types.Model,
prompt string,
) (res types.Response, err error) {
var answer completionResponse

err = client.NewRequest("POST", "/generate").
JSONBody(map[string]interface{}{
"model": model.Name,
"prompt": prompt,
"options": map[string]interface{}{
"temperature": 0.2,
},
"stream": false,
}).
Into(&answer).
RunContext(ctx)
if err != nil {
return res, fmt.Errorf("failed sending prompt: %w", err)
}

if !answer.Done {
return res, fmt.Errorf("%w: unexpected truncated response", types.ErrResultTruncated)
}

res.FullOutput = strings.TrimSpace(answer.Response)

var ok bool
if res.Code, ok = types.ExtractCode(res.FullOutput); !ok {
res.Code = res.FullOutput
}

return res, nil
}
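
To contrast with the chat flow, a single-shot completion might look like this sketch (illustrative only; it assumes the exported names above and relies on the nil-options fallback to DefaultAPIURL shown in client.go):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/gofireflyio/aiac/v4/libaiac/ollama"
)

func main() {
	// nil options fall back to DefaultAPIURL.
	client := ollama.NewClient(nil)

	// Complete performs a single-shot request against the /generate
	// endpoint; unlike Chat, no message history is kept between calls.
	res, err := client.Complete(
		context.Background(),
		ollama.ModelCodeLlama,
		"generate a Dockerfile for a Go web server",
	)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(res.Code)
}
```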
63 changes: 63 additions & 0 deletions libaiac/ollama/models.go
@@ -0,0 +1,63 @@
package ollama

import (
"github.com/gofireflyio/aiac/v4/libaiac/types"
)

var (
// ModelCodeLlama represents the codellama model
ModelCodeLlama = types.Model{"codellama", 0, types.ModelTypeChat}

// ModelDeepseekCoder represents the deepseek-coder model
ModelDeepseekCoder = types.Model{"deepseek-coder", 0, types.ModelTypeChat}

// ModelWizardCoder represents the wizard-coder model
ModelWizardCoder = types.Model{"wizard-coder", 0, types.ModelTypeChat}

// ModelPhindCodeLlama represents the phind-codellama model
ModelPhindCodeLlama = types.Model{"phind-codellama", 0, types.ModelTypeChat}

// ModelCodeUp represents the codeup model
ModelCodeUp = types.Model{"codeup", 0, types.ModelTypeChat}

// ModelStarCoder represents the starcoder model
ModelStarCoder = types.Model{"starcoder", 0, types.ModelTypeChat}

// ModelSQLCoder represents the sqlcoder model
ModelSQLCoder = types.Model{"sqlcoder", 0, types.ModelTypeChat}

// ModelStableCode represents the stablecode model
ModelStableCode = types.Model{"stablecode", 0, types.ModelTypeChat}

// ModelMagicoder represents the magicoder model
ModelMagicoder = types.Model{"magicoder", 0, types.ModelTypeChat}

// ModelCodeBooga represents the codebooga model
ModelCodeBooga = types.Model{"codebooga", 0, types.ModelTypeChat}

// SupportedModels is a list of all language models supported by this
// backend implementation.
SupportedModels = []types.Model{
ModelCodeLlama,
ModelDeepseekCoder,
ModelWizardCoder,
ModelPhindCodeLlama,
ModelCodeUp,
ModelStarCoder,
ModelSQLCoder,
ModelStableCode,
ModelMagicoder,
ModelCodeBooga,
}
)

// ListModels returns a list of all the models supported by this backend
// implementation.
func (client *Client) ListModels() []types.Model {
return SupportedModels
}

// DefaultModel returns the default model used by this backend.
func (client *Client) DefaultModel() types.Model {
return ModelCodeLlama
}