This commit introduces support for the ollama.ai backend. Ollama is an open-source project for running LLaMA-derived and other large language models locally. It supports many different models, many of them geared towards code generation. The Ollama backend is selected via the `--backend` flag or the `AIAC_BACKEND` environment variable set to `"ollama"`. Ollama does not currently support authentication, so the only related flag is `--ollama-url`, which sets the API server's URL; if it is not provided, the default URL (http://localhost:11434/api) is used. With this commit, `aiac` does not yet support a scenario where the API server is running behind an authenticating proxy. Resolves: #77
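As an illustration of the new backend's programmatic API, here is a minimal sketch, assuming the `ollama` package lives under `libaiac/` in the repository module (next to `libaiac/types`, as the imports in the diff suggest); the prompt text is hypothetical:

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/gofireflyio/aiac/v4/libaiac/ollama"
)

func main() {
    // Passing nil options falls back to the default local server
    // (http://localhost:11434/api).
    client := ollama.NewClient(nil)

    // Generate code once, without conversation context, using the
    // backend's default model (codellama).
    res, err := client.Complete(
        context.TODO(),
        client.DefaultModel(),
        "generate Terraform for an S3 bucket", // hypothetical prompt
    )
    if err != nil {
        log.Fatal(err)
    }

    fmt.Println(res.Code)
}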
Showing 8 changed files with 321 additions and 9 deletions.
@@ -0,0 +1,87 @@
package ollama

import (
    "context"
    "fmt"
    "strings"

    "github.com/gofireflyio/aiac/v4/libaiac/types"
)

// Conversation is a struct used to converse with an Ollama chat model. It
// maintains all messages sent/received in order to maintain context, just
// like using ChatGPT.
type Conversation struct {
    client   *Client
    model    types.Model
    messages []types.Message
}

type chatResponse struct {
    Message types.Message `json:"message"`
    Done    bool          `json:"done"`
}

// Chat initiates a conversation with an Ollama chat model. A conversation
// maintains context, allowing you to send further instructions that modify
// the output of previous requests, just like using the ChatGPT website.
func (client *Client) Chat(model types.Model) types.Conversation {
    if model.Type != types.ModelTypeChat {
        return nil
    }

    return &Conversation{
        client: client,
        model:  model,
    }
}

// Send sends the provided message to the API and returns a Response object.
// To maintain context, all previous messages (whether from you to the API or
// vice versa) are sent as well, allowing you to ask the API to modify the
// code it already generated.
func (conv *Conversation) Send(ctx context.Context, prompt string, msgs ...types.Message) (
    res types.Response,
    err error,
) {
    var answer chatResponse

    if len(msgs) > 0 {
        conv.messages = append(conv.messages, msgs...)
    }

    conv.messages = append(conv.messages, types.Message{
        Role:    "user",
        Content: prompt,
    })

    err = conv.client.NewRequest("POST", "/chat").
        JSONBody(map[string]interface{}{
            "model":    conv.model.Name,
            "messages": conv.messages,
            "options": map[string]interface{}{
                "temperature": 0.2,
            },
            "stream": false,
        }).
        Into(&answer).
        RunContext(ctx)
    if err != nil {
        return res, fmt.Errorf("failed sending prompt: %w", err)
    }

    if !answer.Done {
        return res, fmt.Errorf("%w: unexpected truncated response", types.ErrResultTruncated)
    }

    conv.messages = append(conv.messages, answer.Message)

    res.FullOutput = strings.TrimSpace(answer.Message.Content)

    var ok bool
    if res.Code, ok = types.ExtractCode(res.FullOutput); !ok {
        res.Code = res.FullOutput
    }

    return res, nil
}
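As a usage sketch (the prompts are hypothetical, and `ctx` is assumed to be an existing `context.Context`): a conversation accumulates messages, so a follow-up request can refine the previous answer:

conv := client.Chat(ollama.ModelCodeLlama)

// First turn: generate the initial code.
res, err := conv.Send(ctx, "generate a Dockerfile for a Go application")
if err != nil {
    return err
}

// Second turn: Send resends all previous messages automatically, so the
// model has the context it needs to amend its earlier answer.
res, err = conv.Send(ctx, "switch to a multi-stage build to shrink the image")
if err != nil {
    return err
}

fmt.Println(res.Code)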
@@ -0,0 +1,69 @@
package ollama

import (
    "encoding/json"
    "fmt"
    "io"
    "net/http"

    "github.com/gofireflyio/aiac/v4/libaiac/types"
    "github.com/ido50/requests"
)

// DefaultAPIURL is the default URL for a local Ollama API server.
const DefaultAPIURL = "http://localhost:11434/api"

// Client is a structure used to continuously generate IaC code via Ollama.
type Client struct {
    *requests.HTTPClient
}

// NewClientOptions is a struct containing all the parameters accepted by the
// NewClient constructor.
type NewClientOptions struct {
    // URL is the URL of the API server (including the /api path prefix).
    // Defaults to DefaultAPIURL.
    URL string
}

// NewClient creates a new instance of the Client struct, with the provided
// input options. The Ollama API server is not contacted at this point.
func NewClient(opts *NewClientOptions) *Client {
    if opts == nil {
        opts = &NewClientOptions{}
    }

    if opts.URL == "" {
        opts.URL = DefaultAPIURL
    }

    cli := &Client{}

    cli.HTTPClient = requests.NewClient(opts.URL).
        Accept("application/json").
        ErrorHandler(func(
            httpStatus int,
            contentType string,
            body io.Reader,
        ) error {
            var res struct {
                Error string `json:"error"`
            }

            err := json.NewDecoder(body).Decode(&res)
            if err != nil {
                return fmt.Errorf(
                    "%w %s",
                    types.ErrUnexpectedStatus,
                    http.StatusText(httpStatus),
                )
            }

            return fmt.Errorf(
                "%w: %s",
                types.ErrRequestFailed,
                res.Error,
            )
        })

    return cli
}
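A minimal sketch of pointing the client at a non-local server (the hostname below is hypothetical); note that the /api prefix must be part of the URL, since all endpoint paths are resolved relative to this base:

client := ollama.NewClient(&ollama.NewClientOptions{
    // Hypothetical remote host; the /api prefix is required.
    URL: "http://ollama.example.internal:11434/api",
})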
@@ -0,0 +1,52 @@
package ollama

import (
    "context"
    "fmt"
    "strings"

    "github.com/gofireflyio/aiac/v4/libaiac/types"
)

type completionResponse struct {
    Response string `json:"response"`
    Done     bool   `json:"done"`
}

// Complete sends a request to Ollama's generate API using the provided model
// and prompt, and returns the response.
func (client *Client) Complete(
    ctx context.Context,
    model types.Model,
    prompt string,
) (res types.Response, err error) {
    var answer completionResponse

    err = client.NewRequest("POST", "/generate").
        JSONBody(map[string]interface{}{
            "model":  model.Name,
            "prompt": prompt,
            "options": map[string]interface{}{
                "temperature": 0.2,
            },
            "stream": false,
        }).
        Into(&answer).
        RunContext(ctx)
    if err != nil {
        return res, fmt.Errorf("failed sending prompt: %w", err)
    }

    if !answer.Done {
        return res, fmt.Errorf("%w: unexpected truncated response", types.ErrResultTruncated)
    }

    res.FullOutput = strings.TrimSpace(answer.Response)

    var ok bool
    if res.Code, ok = types.ExtractCode(res.FullOutput); !ok {
        res.Code = res.FullOutput
    }

    return res, nil
}
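Note the fallback at the end of Complete: types.ExtractCode attempts to pull the code portion out of the full output (presumably from a Markdown code fence, since models often wrap code in explanatory prose), and when none is found, res.Code simply carries the full output. A caller that only wants code can therefore read res.Code unconditionally. A short sketch, with a hypothetical prompt:

res, err := client.Complete(ctx, ollama.ModelDeepseekCoder, "write a Makefile for a Go project")
if err != nil {
    return err
}

fmt.Println(res.Code) // extracted code if a block was found, otherwise the full output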
@@ -0,0 +1,63 @@
package ollama

import (
    "github.com/gofireflyio/aiac/v4/libaiac/types"
)

var (
    // ModelCodeLlama represents the codellama model.
    ModelCodeLlama = types.Model{"codellama", 0, types.ModelTypeChat}

    // ModelDeepseekCoder represents the deepseek-coder model.
    ModelDeepseekCoder = types.Model{"deepseek-coder", 0, types.ModelTypeChat}

    // ModelWizardCoder represents the wizard-coder model.
    ModelWizardCoder = types.Model{"wizard-coder", 0, types.ModelTypeChat}

    // ModelPhindCodeLlama represents the phind-codellama model.
    ModelPhindCodeLlama = types.Model{"phind-codellama", 0, types.ModelTypeChat}

    // ModelCodeUp represents the codeup model.
    ModelCodeUp = types.Model{"codeup", 0, types.ModelTypeChat}

    // ModelStarCoder represents the starcoder model.
    ModelStarCoder = types.Model{"starcoder", 0, types.ModelTypeChat}

    // ModelSQLCoder represents the sqlcoder model.
    ModelSQLCoder = types.Model{"sqlcoder", 0, types.ModelTypeChat}

    // ModelStableCode represents the stablecode model.
    ModelStableCode = types.Model{"stablecode", 0, types.ModelTypeChat}

    // ModelMagicoder represents the magicoder model.
    ModelMagicoder = types.Model{"magicoder", 0, types.ModelTypeChat}

    // ModelCodeBooga represents the codebooga model.
    ModelCodeBooga = types.Model{"codebooga", 0, types.ModelTypeChat}

    // SupportedModels is a list of all language models supported by this
    // backend implementation.
    SupportedModels = []types.Model{
        ModelCodeLlama,
        ModelDeepseekCoder,
        ModelWizardCoder,
        ModelPhindCodeLlama,
        ModelCodeUp,
        ModelStarCoder,
        ModelSQLCoder,
        ModelStableCode,
        ModelMagicoder,
        ModelCodeBooga,
    }
)

// ListModels returns a list of all the models supported by this backend
// implementation.
func (client *Client) ListModels() []types.Model {
    return SupportedModels
}

// DefaultModel returns the default model used by this backend.
func (client *Client) DefaultModel() types.Model {
    return ModelCodeLlama
}
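For completeness, a quick sketch of enumerating the supported models, e.g. to back a list-models style command (assuming the same package as above):

client := ollama.NewClient(nil)

for _, model := range client.ListModels() {
    fmt.Println(model.Name)
}

fmt.Println("default:", client.DefaultModel().Name)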