From 4b0989afe40fdd6ac16735677905789f32969292 Mon Sep 17 00:00:00 2001
From: MotorBottle <71703952+MotorBottle@users.noreply.github.com>
Date: Tue, 6 Aug 2024 23:44:37 +0800
Subject: [PATCH] feat: add Max Tokens and Context Window Setting Options for
 Ollama Channel (#1694)

* Update main.go with max_tokens param

* Update model.go with max_tokens param

* Update model.go

* Update main.go

* Update main.go

* Adds num_ctx param for Ollama Channel

* Added num_ctx param for ollama adapter

* Added num_ctx param for ollama adapter

* Improved data process logic
---
 relay/adaptor/ollama/main.go  | 8 ++++++--
 relay/adaptor/ollama/model.go | 2 ++
 relay/model/general.go        | 1 +
 3 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go
index 6a1d334d1a..43317ff66f 100644
--- a/relay/adaptor/ollama/main.go
+++ b/relay/adaptor/ollama/main.go
@@ -31,6 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 			TopP:             request.TopP,
 			FrequencyPenalty: request.FrequencyPenalty,
 			PresencePenalty:  request.PresencePenalty,
+			NumPredict:  	  request.MaxTokens,
+			NumCtx:  	  request.NumCtx,
 		},
 		Stream: request.Stream,
 	}
@@ -118,8 +120,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 	common.SetEventStreamHeaders(c)
 
 	for scanner.Scan() {
-		data := strings.TrimPrefix(scanner.Text(), "}")
-		data = data + "}"
+		data := scanner.Text()
+		if strings.HasPrefix(data, "}") {
+		    data = strings.TrimPrefix(data, "}") + "}"
+		}
 
 		var ollamaResponse ChatResponse
 		err := json.Unmarshal([]byte(data), &ollamaResponse)
diff --git a/relay/adaptor/ollama/model.go b/relay/adaptor/ollama/model.go
index 29430e1c7c..7039984fcc 100644
--- a/relay/adaptor/ollama/model.go
+++ b/relay/adaptor/ollama/model.go
@@ -7,6 +7,8 @@ type Options struct {
 	TopP             float64 `json:"top_p,omitempty"`
 	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
 	PresencePenalty  float64 `json:"presence_penalty,omitempty"`
+	NumPredict  	 int 	 `json:"num_predict,omitempty"`
+	NumCtx  	 int 	 `json:"num_ctx,omitempty"`
 }
 
 type Message struct {
diff --git a/relay/model/general.go b/relay/model/general.go
index 229a61c160..c34c1c2d5d 100644
--- a/relay/model/general.go
+++ b/relay/model/general.go
@@ -29,6 +29,7 @@ type GeneralOpenAIRequest struct {
 	Dimensions       int             `json:"dimensions,omitempty"`
 	Instruction      string          `json:"instruction,omitempty"`
 	Size             string          `json:"size,omitempty"`
+	NumCtx           int         	 `json:"num_ctx,omitempty"`
 }
 
 func (r GeneralOpenAIRequest) ParseInput() []string {