Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const (

type ChannelMessage struct {
interfaces.BifrostRequest
Response chan *interfaces.CompletionResult
Response chan *interfaces.BifrostResponse
Err chan error
Type RequestType
}
Expand Down Expand Up @@ -179,7 +179,7 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan
defer bifrost.wg[provider.GetProviderKey()].Done()

for req := range queue {
var result *interfaces.CompletionResult
var result *interfaces.BifrostResponse
var err error

key, err := bifrost.SelectKeyFromProviderForModel(provider, req.Model)
Expand Down Expand Up @@ -234,13 +234,13 @@ func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelPr
return queue, nil
}

func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.CompletionResult, error) {
func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) {
queue, err := bifrost.GetProviderQueue(providerKey)
if err != nil {
return nil, err
}

responseChan := make(chan *interfaces.CompletionResult)
responseChan := make(chan *interfaces.BifrostResponse)
errorChan := make(chan error)

for _, plugin := range bifrost.plugins {
Expand Down Expand Up @@ -273,13 +273,13 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo
}
}

func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.CompletionResult, error) {
func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) {
queue, err := bifrost.GetProviderQueue(providerKey)
if err != nil {
return nil, err
}

responseChan := make(chan *interfaces.CompletionResult)
responseChan := make(chan *interfaces.BifrostResponse)
errorChan := make(chan error)

for _, plugin := range bifrost.plugins {
Expand Down
1 change: 1 addition & 0 deletions interfaces/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Key struct {
Weight float64 `json:"weight"`
}

// TODO one get config method
type Account interface {
GetInitiallyConfiguredProviderKeys() ([]SupportedModelProvider, error)
GetKeysForProvider(provider Provider) ([]Key, error)
Expand Down
2 changes: 1 addition & 1 deletion interfaces/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ type BifrostRequest struct {

type Plugin interface {
PreHook(ctx context.Context, req *BifrostRequest) (context.Context, *BifrostRequest, error)
PostHook(ctx context.Context, result *CompletionResult) (*CompletionResult, error)
PostHook(ctx context.Context, result *BifrostResponse) (*BifrostResponse, error)
}
Loading