Skip to content

Commit 2112a9d

Browse files
committed
feat(chat): fix ExtraBody embedding and add comprehensive tests
1 parent c18b4e2 commit 2112a9d

File tree

2 files changed

+123
-2
lines changed

2 files changed

+123
-2
lines changed

chat.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,12 +482,28 @@ func (c *Client) CreateChatCompletion(
482482
return
483483
}
484484

485+
// The body map is used to dynamically construct the request payload for the chat completion API.
486+
// Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields
487+
// based on their presence, avoiding unnecessary or empty fields in the request.
488+
extraBody := request.ExtraBody
489+
request.ExtraBody = nil
490+
491+
// Serialize request to JSON
492+
jsonData, err := json.Marshal(request)
493+
if err != nil {
494+
return
495+
}
496+
497+
// Deserialize JSON to map[string]any
498+
var body map[string]any
499+
_ = json.Unmarshal(jsonData, &body)
500+
485501
req, err := c.newRequest(
486502
ctx,
487503
http.MethodPost,
488504
c.fullURL(urlSuffix, withModel(request.Model)),
489-
withBody(request),
490-
withExtraBody(request.ExtraBody),
505+
withBody(body), // Main request body.
506+
withExtraBody(extraBody), // Merge ExtraBody fields.
491507
)
492508
if err != nil {
493509
return

chat_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,111 @@ func TestChatCompletionsFunctions(t *testing.T) {
756756
})
757757
}
758758

759+
func TestChatCompletionsWithExtraBody(t *testing.T) {
760+
client, server, teardown := setupOpenAITestServer()
761+
defer teardown()
762+
763+
// Register a custom handler that checks if ExtraBody fields are properly embedded
764+
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
765+
// Read the request body
766+
reqBody, err := io.ReadAll(r.Body)
767+
if err != nil {
768+
http.Error(w, "could not read request", http.StatusInternalServerError)
769+
return
770+
}
771+
772+
// Parse the request body into a map to check all fields
773+
var requestBody map[string]any
774+
if err := json.Unmarshal(reqBody, &requestBody); err != nil {
775+
http.Error(w, fmt.Sprintf("could not parse request: %v, body: %s", err, string(reqBody)), http.StatusInternalServerError)
776+
return
777+
}
778+
779+
// Check that ExtraBody fields are present in the root level
780+
if _, exists := requestBody["custom_field"]; !exists {
781+
w.Header().Set("Content-Type", "application/json")
782+
w.WriteHeader(http.StatusBadRequest)
783+
json.NewEncoder(w).Encode(map[string]string{"error": "custom_field not found in request body"})
784+
return
785+
}
786+
787+
if _, exists := requestBody["another_field"]; !exists {
788+
w.Header().Set("Content-Type", "application/json")
789+
w.WriteHeader(http.StatusBadRequest)
790+
json.NewEncoder(w).Encode(map[string]string{"error": "another_field not found in request body"})
791+
return
792+
}
793+
794+
// Check that regular fields are still present
795+
if _, exists := requestBody["model"]; !exists {
796+
w.Header().Set("Content-Type", "application/json")
797+
w.WriteHeader(http.StatusBadRequest)
798+
json.NewEncoder(w).Encode(map[string]string{"error": "model not found in request body"})
799+
return
800+
}
801+
802+
if _, exists := requestBody["messages"]; !exists {
803+
w.Header().Set("Content-Type", "application/json")
804+
w.WriteHeader(http.StatusBadRequest)
805+
json.NewEncoder(w).Encode(map[string]string{"error": "messages not found in request body"})
806+
return
807+
}
808+
809+
// ExtraBody should not be present in the final request
810+
if _, exists := requestBody["extra_body"]; exists {
811+
w.Header().Set("Content-Type", "application/json")
812+
w.WriteHeader(http.StatusBadRequest)
813+
json.NewEncoder(w).Encode(map[string]string{"error": "extra_body should not be present in final request"})
814+
return
815+
}
816+
817+
// Return a success response
818+
res := openai.ChatCompletionResponse{
819+
ID: "test-id",
820+
Object: "chat.completion",
821+
Created: time.Now().Unix(),
822+
Model: "gpt-3.5-turbo",
823+
Choices: []openai.ChatCompletionChoice{
824+
{
825+
Index: 0,
826+
Message: openai.ChatCompletionMessage{
827+
Role: openai.ChatMessageRoleAssistant,
828+
Content: "Hello!",
829+
},
830+
FinishReason: openai.FinishReasonStop,
831+
},
832+
},
833+
Usage: openai.Usage{
834+
PromptTokens: 5,
835+
CompletionTokens: 5,
836+
TotalTokens: 10,
837+
},
838+
}
839+
840+
resBytes, _ := json.Marshal(res)
841+
w.Header().Set("Content-Type", "application/json")
842+
w.WriteHeader(http.StatusOK)
843+
w.Write(resBytes)
844+
})
845+
846+
// Test the ExtraBody functionality
847+
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
848+
Model: openai.GPT3Dot5Turbo,
849+
Messages: []openai.ChatCompletionMessage{
850+
{
851+
Role: openai.ChatMessageRoleUser,
852+
Content: "Hello!",
853+
},
854+
},
855+
ExtraBody: map[string]any{
856+
"custom_field": "custom_value",
857+
"another_field": 123,
858+
},
859+
})
860+
861+
checks.NoError(t, err, "CreateChatCompletion with ExtraBody error")
862+
}
863+
759864
func TestAzureChatCompletions(t *testing.T) {
760865
client, server, teardown := setupAzureTestServer()
761866
defer teardown()

0 commit comments

Comments
 (0)