Skip to content

Commit 9ce8356

Browse files
committed
fix(deepseek): reset tool_choice handling to prevent infinite loop when returnDirect=false (#4617)
Signed-off-by: Kuntal Maity <[email protected]>
1 parent 838801f commit 9ce8356

File tree

2 files changed

+128
-1
lines changed

2 files changed

+128
-1
lines changed

models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
* backed by {@link DeepSeekApi}.
7777
*
7878
* @author Geng Rong
79+
* @last Updated By : @kuntal1461
7980
*/
8081
public class DeepSeekChatModel implements ChatModel {
8182

@@ -193,7 +194,10 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
193194
}).toList();
194195

195196
// Current usage
196-
DeepSeekApi.Usage usage = completionEntity.getBody().usage();
197+
DeepSeekApi.Usage usage = null;
198+
if (completionEntity != null && completionEntity.getBody() != null) {
199+
usage = chatCompletion.usage();
200+
}
197201
Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
198202
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage,
199203
previousChatResponse);
@@ -216,6 +220,10 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
216220
.build();
217221
}
218222
else {
223+
// Reset tool choice to AUTO to prevent forcing repeated tool calls.
224+
if (prompt.getOptions() instanceof DeepSeekChatOptions options) {
225+
options.setToolChoice(ChatCompletionRequest.ToolChoiceBuilder.AUTO);
226+
}
219227
// Send the tool execution result back to the model.
220228
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
221229
response);
@@ -305,6 +313,10 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
305313
.build());
306314
}
307315
else {
316+
// Reset tool choice to AUTO to prevent forcing repeated tool calls.
317+
if (prompt.getOptions() instanceof DeepSeekChatOptions options) {
318+
options.setToolChoice(ChatCompletionRequest.ToolChoiceBuilder.AUTO);
319+
}
308320
// Send the tool execution result back to the model.
309321
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
310322
response);
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.deepseek;
18+
19+
import java.time.Instant;
20+
import java.util.List;
21+
import java.util.concurrent.atomic.AtomicInteger;
22+
23+
import org.junit.jupiter.api.Test;
24+
import org.mockito.ArgumentCaptor;
25+
26+
import org.springframework.ai.chat.model.ChatResponse;
27+
import org.springframework.ai.chat.prompt.Prompt;
28+
import org.springframework.ai.deepseek.api.DeepSeekApi;
29+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion;
30+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion.Choice;
31+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionFinishReason;
32+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage;
33+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ChatCompletionFunction;
34+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ToolCall;
35+
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest;
36+
import org.springframework.ai.tool.function.FunctionToolCallback;
37+
import org.springframework.http.ResponseEntity;
38+
39+
import static org.assertj.core.api.Assertions.assertThat;
40+
import static org.mockito.Mockito.mock;
41+
import static org.mockito.Mockito.when;
42+
43+
/**
44+
* Verifies that DeepSeekChatModel resets tool_choice to AUTO when resubmitting tool
45+
* results (returnDirect=false) to avoid infinite tool call loops.
46+
* @author Kuntal Maity
47+
*/
48+
class DeepSeekChatModelToolChoiceResetTests {
49+
50+
@Test
51+
void resetsToolChoiceToAutoOnToolResultPushback() {
52+
// Arrange: mock API to return a tool call first, then a normal assistant message
53+
DeepSeekApi api = mock(DeepSeekApi.class);
54+
55+
// Capture requests to verify tool_choice on the second call
56+
ArgumentCaptor<ChatCompletionRequest> reqCaptor = ArgumentCaptor.forClass(ChatCompletionRequest.class);
57+
58+
AtomicInteger apiCalls = new AtomicInteger(0);
59+
when(api.chatCompletionEntity(reqCaptor.capture())).thenAnswer(invocation -> {
60+
int call = apiCalls.incrementAndGet();
61+
if (call == 1) {
62+
// First response: model requests tool call
63+
ChatCompletionMessage msg = new ChatCompletionMessage("", // content
64+
ChatCompletionMessage.Role.ASSISTANT, null, null, List.of(new ToolCall("call_1", "function",
65+
new ChatCompletionFunction("getMarineYetiDescription", "{}"))),
66+
null, null);
67+
ChatCompletion cc = new ChatCompletion("id-1",
68+
List.of(new Choice(ChatCompletionFinishReason.TOOL_CALLS, 0, msg, null)),
69+
Instant.now().getEpochSecond(), DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getName(), null,
70+
"chat.completion", null);
71+
return ResponseEntity.ok(cc);
72+
}
73+
else {
74+
// Second response: normal assistant message
75+
ChatCompletionMessage msg = new ChatCompletionMessage("Marine yeti is orange.",
76+
ChatCompletionMessage.Role.ASSISTANT, null, null, null, null, null);
77+
ChatCompletion cc = new ChatCompletion("id-2",
78+
List.of(new Choice(ChatCompletionFinishReason.STOP, 0, msg, null)),
79+
Instant.now().getEpochSecond(), DeepSeekApi.ChatModel.DEEPSEEK_CHAT.getName(), null,
80+
"chat.completion", null);
81+
return ResponseEntity.ok(cc);
82+
}
83+
});
84+
85+
// Tool callback increments counter; returnDirect defaults to false
86+
AtomicInteger toolInvocations = new AtomicInteger(0);
87+
var tool = FunctionToolCallback.builder("getMarineYetiDescription", () -> {
88+
toolInvocations.incrementAndGet();
89+
return "Marine yeti is orange";
90+
}).build();
91+
92+
DeepSeekChatOptions options = DeepSeekChatOptions.builder()
93+
.model(DeepSeekApi.ChatModel.DEEPSEEK_CHAT)
94+
.toolCallbacks(List.of(tool))
95+
.toolChoice(ChatCompletionRequest.ToolChoiceBuilder.FUNCTION("getMarineYetiDescription"))
96+
.build();
97+
98+
DeepSeekChatModel model = DeepSeekChatModel.builder().deepSeekApi(api).defaultOptions(options).build();
99+
100+
// Act
101+
ChatResponse response = model.call(new Prompt("What is the color of a marine yeti?"));
102+
103+
// Assert: API was called twice (tool call, then final text)
104+
assertThat(apiCalls.get()).isEqualTo(2);
105+
// Second request tool_choice should be AUTO
106+
assertThat(reqCaptor.getAllValues()).hasSize(2);
107+
Object secondToolChoice = reqCaptor.getAllValues().get(1).toolChoice();
108+
assertThat(secondToolChoice).isEqualTo(ChatCompletionRequest.ToolChoiceBuilder.AUTO);
109+
// Tool executes exactly once
110+
assertThat(toolInvocations.get()).isEqualTo(1);
111+
// And final content is normal text
112+
assertThat(response.getResult().getOutput().getText()).containsIgnoringCase("orange");
113+
}
114+
115+
}

0 commit comments

Comments
 (0)