Skip to content

Commit 49506de

Browse files
YunKuiLu authored and ilayaperumalg committed
feat(zhipuai): Support glm-4.1v-thinking-flash model
- Add reasoning_content fields to ChatCompletionMessage
- Add ZhiPuAiAssistantMessage as a subclass of AssistantMessage to support returning CoT content
- Add integration tests
- Fix "RestClient.Builder bean not found" exception for ZhiPu's image auto-config

Signed-off-by: YunKui Lu <[email protected]>
1 parent c2103b0 commit 49506de

File tree

7 files changed

+205
-19
lines changed

7 files changed

+205
-19
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiImageAutoConfiguration.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
2323
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
2424
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
25+
import org.springframework.beans.factory.ObjectProvider;
2526
import org.springframework.boot.autoconfigure.AutoConfiguration;
2627
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
2728
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
@@ -51,8 +52,8 @@ public class ZhiPuAiImageAutoConfiguration {
5152
@Bean
5253
@ConditionalOnMissingBean
5354
public ZhiPuAiImageModel zhiPuAiImageModel(ZhiPuAiConnectionProperties commonProperties,
54-
ZhiPuAiImageProperties imageProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate,
55-
ResponseErrorHandler responseErrorHandler) {
55+
ZhiPuAiImageProperties imageProperties, ObjectProvider<RestClient.Builder> restClientBuilderProvider,
56+
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {
5657

5758
String apiKey = StringUtils.hasText(imageProperties.getApiKey()) ? imageProperties.getApiKey()
5859
: commonProperties.getApiKey();
@@ -63,7 +64,9 @@ public ZhiPuAiImageModel zhiPuAiImageModel(ZhiPuAiConnectionProperties commonPro
6364
Assert.hasText(apiKey, "ZhiPuAI API key must be set");
6465
Assert.hasText(baseUrl, "ZhiPuAI base URL must be set");
6566

66-
var zhiPuAiImageApi = new ZhiPuAiImageApi(baseUrl, apiKey, restClientBuilder, responseErrorHandler);
67+
// TODO add ZhiPuAiApi support for image
68+
var zhiPuAiImageApi = new ZhiPuAiImageApi(baseUrl, apiKey,
69+
restClientBuilderProvider.getIfAvailable(RestClient::builder), responseErrorHandler);
6770

6871
return new ZhiPuAiImageModel(zhiPuAiImageApi, imageProperties.getOptions(), retryTemplate);
6972
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.zhipuai;
18+
19+
import java.util.List;
20+
import java.util.Map;
21+
import java.util.Objects;
22+
23+
import org.springframework.ai.chat.messages.AssistantMessage;
24+
import org.springframework.ai.content.Media;
25+
26+
/**
27+
* @author YunKui Lu
28+
*/
29+
public class ZhiPuAiAssistantMessage extends AssistantMessage {
30+
31+
/**
32+
* The CoT content of the message.
33+
*/
34+
private String reasoningContent;
35+
36+
public ZhiPuAiAssistantMessage(String content) {
37+
super(content);
38+
}
39+
40+
public ZhiPuAiAssistantMessage(String content, String reasoningContent, Map<String, Object> properties,
41+
List<ToolCall> toolCalls, List<Media> media) {
42+
super(content, properties, toolCalls, media);
43+
this.reasoningContent = reasoningContent;
44+
}
45+
46+
public String getReasoningContent() {
47+
return this.reasoningContent;
48+
}
49+
50+
public ZhiPuAiAssistantMessage setReasoningContent(String reasoningContent) {
51+
this.reasoningContent = reasoningContent;
52+
return this;
53+
}
54+
55+
@Override
56+
public boolean equals(Object o) {
57+
if (this == o) {
58+
return true;
59+
}
60+
if (!(o instanceof ZhiPuAiAssistantMessage that)) {
61+
return false;
62+
}
63+
if (!super.equals(o)) {
64+
return false;
65+
}
66+
return Objects.equals(this.reasoningContent, that.reasoningContent);
67+
}
68+
69+
@Override
70+
public int hashCode() {
71+
return Objects.hash(super.hashCode(), this.reasoningContent);
72+
}
73+
74+
@Override
75+
public String toString() {
76+
return "ZhiPuAiAssistantMessage{" + "media=" + this.media + ", messageType=" + this.messageType + ", metadata="
77+
+ this.metadata + ", reasoningContent='" + this.reasoningContent + '\'' + ", textContent='"
78+
+ this.textContent + '\'' + '}';
79+
}
80+
81+
}

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -227,7 +227,11 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met
227227
toolCall.function().name(), toolCall.function().arguments()))
228228
.toList();
229229

230-
var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
230+
String textContent = choice.message().content();
231+
String reasoningContent = choice.message().reasoningContent();
232+
233+
var assistantMessage = new ZhiPuAiAssistantMessage(textContent, reasoningContent, metadata, toolCalls,
234+
List.of());
231235
String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
232236
var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
233237
return new Generation(assistantMessage, generationMetadata);
@@ -511,7 +515,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
511515
}).toList();
512516
}
513517
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
514-
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls));
518+
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null));
515519
}
516520
else if (message.getMessageType() == MessageType.TOOL) {
517521
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
@@ -522,7 +526,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
522526
return toolMessage.getResponses()
523527
.stream()
524528
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
525-
tr.id(), null))
529+
tr.id(), null, null))
526530
.toList();
527531
}
528532
else {

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,14 @@ public enum ChatModel implements ChatModelDescription {
377377

378378
GLM_4_Flash("glm-4-flash"),
379379

380-
GLM_3_Turbo("GLM-3-Turbo"); // @formatter:on
380+
GLM_3_Turbo("GLM-3-Turbo"),
381+
382+
// --- Visual Reasoning Models ---
383+
384+
GLM_4_Thinking_FlashX("glm-4.1v-thinking-flashx"),
385+
386+
GLM_4_Thinking_Flash("glm-4.1v-thinking-flash");
387+
// @formatter:on
381388

382389
public final String value;
383390

@@ -772,7 +779,8 @@ public record ChatCompletionMessage(// @formatter:off
772779
@JsonProperty("role") Role role,
773780
@JsonProperty("name") String name,
774781
@JsonProperty("tool_call_id") String toolCallId,
775-
@JsonProperty("tool_calls") List<ToolCall> toolCalls) { // @formatter:on
782+
@JsonProperty("tool_calls") List<ToolCall> toolCalls,
783+
@JsonProperty("reasoning_content") String reasoningContent) { // @formatter:on
776784

777785
/**
778786
* Create a chat completion message with the given content and role. All other
@@ -781,7 +789,7 @@ public record ChatCompletionMessage(// @formatter:off
781789
* @param role The role of the author of this message.
782790
*/
783791
public ChatCompletionMessage(Object content, Role role) {
784-
this(content, role, null, null, null);
792+
this(content, role, null, null, null, null);
785793
}
786794

787795
/**

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -85,6 +85,8 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
8585
private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
8686
String content = (current.content() != null ? current.content()
8787
: (previous.content() != null) ? previous.content() : "");
88+
String reasoningContent = (current.reasoningContent() != null ? current.reasoningContent()
89+
: (previous.reasoningContent() != null ? previous.reasoningContent() : ""));
8890
Role role = (current.role() != null ? current.role() : previous.role());
8991
role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null
9092
String name = (current.name() != null ? current.name() : previous.name());
@@ -118,7 +120,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
118120
toolCalls.add(lastPreviousTooCall);
119121
}
120122
}
121-
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls);
123+
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, reasoningContent);
122124
}
123125

124126
private ToolCall merge(ToolCall previous, ToolCall current) {

models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public void toolFunctionCall() {
124124

125125
// extend conversation with function response.
126126
messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL,
127-
functionName, toolCall.id(), null));
127+
functionName, toolCall.id(), null, null));
128128
}
129129
}
130130

models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -49,6 +49,7 @@
4949
import org.springframework.ai.converter.ListOutputConverter;
5050
import org.springframework.ai.converter.MapOutputConverter;
5151
import org.springframework.ai.tool.function.FunctionToolCallback;
52+
import org.springframework.ai.zhipuai.ZhiPuAiAssistantMessage;
5253
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
5354
import org.springframework.ai.zhipuai.ZhiPuAiTestConfiguration;
5455
import org.springframework.ai.zhipuai.api.MockWeatherService;
@@ -60,6 +61,7 @@
6061
import org.springframework.core.io.ClassPathResource;
6162
import org.springframework.core.io.Resource;
6263
import org.springframework.util.MimeTypeUtils;
64+
import org.springframework.util.StringUtils;
6365

6466
import static org.assertj.core.api.Assertions.assertThat;
6567

@@ -312,7 +314,29 @@ void multiModalityEmbeddedImage(String modelName) throws IOException {
312314
}
313315

314316
@ParameterizedTest(name = "{0} : {displayName} ")
315-
@ValueSource(strings = { "glm-4v" })
317+
@ValueSource(strings = { "glm-4.1v-thinking-flash" })
318+
void reasonerMultiModalityEmbeddedImageThinkingModel(String modelName) throws IOException {
319+
var imageData = new ClassPathResource("/test.png");
320+
321+
var userMessage = UserMessage.builder()
322+
.text("Explain what do you see on this picture?")
323+
.media(List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)))
324+
.build();
325+
326+
var response = this.chatModel
327+
.call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().model(modelName).build()));
328+
329+
logger.info(response.getResult().getOutput().getText());
330+
assertThat(response.getResult().getOutput().getText()).containsAnyOf("bananas", "apple", "bowl", "basket",
331+
"fruit stand");
332+
333+
logger.info(((ZhiPuAiAssistantMessage) response.getResult().getOutput()).getReasoningContent());
334+
assertThat(((ZhiPuAiAssistantMessage) response.getResult().getOutput()).getReasoningContent())
335+
.containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand");
336+
}
337+
338+
@ParameterizedTest(name = "{0} : {displayName} ")
339+
@ValueSource(strings = { "glm-4v", "glm-4.1v-thinking-flash" })
316340
void multiModalityImageUrl(String modelName) throws IOException {
317341

318342
var userMessage = UserMessage.builder()
@@ -331,8 +355,9 @@ void multiModalityImageUrl(String modelName) throws IOException {
331355
"fruit stand");
332356
}
333357

334-
@Test
335-
void streamingMultiModalityImageUrl() throws IOException {
358+
@ParameterizedTest(name = "{0} : {displayName} ")
359+
@ValueSource(strings = { "glm-4.1v-thinking-flash" })
360+
void reasonerMultiModalityImageUrl(String modelName) throws IOException {
336361

337362
var userMessage = UserMessage.builder()
338363
.text("Explain what do you see on this picture?")
@@ -342,8 +367,32 @@ void streamingMultiModalityImageUrl() throws IOException {
342367
.build()))
343368
.build();
344369

345-
Flux<ChatResponse> response = this.streamingChatModel.stream(new Prompt(List.of(userMessage),
346-
ZhiPuAiChatOptions.builder().model(ZhiPuAiApi.ChatModel.GLM_4V.getValue()).build()));
370+
ChatResponse response = this.chatModel
371+
.call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().model(modelName).build()));
372+
373+
logger.info(response.getResult().getOutput().getText());
374+
assertThat(response.getResult().getOutput().getText()).containsAnyOf("bananas", "apple", "bowl", "basket",
375+
"fruit stand");
376+
377+
logger.info(((ZhiPuAiAssistantMessage) response.getResult().getOutput()).getReasoningContent());
378+
assertThat(((ZhiPuAiAssistantMessage) response.getResult().getOutput()).getReasoningContent())
379+
.containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand");
380+
}
381+
382+
@ParameterizedTest(name = "{0} : {displayName} ")
383+
@ValueSource(strings = { "glm-4v" })
384+
void streamingMultiModalityImageUrl(String modelName) throws IOException {
385+
386+
var userMessage = UserMessage.builder()
387+
.text("Explain what do you see on this picture?")
388+
.media(List.of(Media.builder()
389+
.mimeType(MimeTypeUtils.IMAGE_PNG)
390+
.data(URI.create("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"))
391+
.build()))
392+
.build();
393+
394+
Flux<ChatResponse> response = this.streamingChatModel
395+
.stream(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().model(modelName).build()));
347396

348397
String content = Objects.requireNonNull(response.collectList().block())
349398
.stream()
@@ -356,6 +405,45 @@ void streamingMultiModalityImageUrl() throws IOException {
356405
assertThat(content).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand");
357406
}
358407

408+
@ParameterizedTest(name = "{0} : {displayName} ")
409+
@ValueSource(strings = { "glm-4.1v-thinking-flash" })
410+
void reasonerStreamingMultiModalityImageUrl(String modelName) throws IOException {
411+
412+
var userMessage = UserMessage.builder()
413+
.text("Explain what do you see on this picture?")
414+
.media(List.of(Media.builder()
415+
.mimeType(MimeTypeUtils.IMAGE_PNG)
416+
.data(URI.create("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png"))
417+
.build()))
418+
.build();
419+
420+
Flux<ChatResponse> response = this.streamingChatModel
421+
.stream(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().model(modelName).build()));
422+
423+
List<ZhiPuAiAssistantMessage> streamingMessages = Objects.requireNonNull(response.collectList().block())
424+
.stream()
425+
.map(ChatResponse::getResults)
426+
.flatMap(List::stream)
427+
.map(m -> (ZhiPuAiAssistantMessage) m.getOutput())
428+
.toList();
429+
430+
String reasoningContent = streamingMessages.stream()
431+
.map(ZhiPuAiAssistantMessage::getReasoningContent)
432+
.filter(StringUtils::hasText)
433+
.collect(Collectors.joining());
434+
435+
String content = streamingMessages.stream()
436+
.map(AssistantMessage::getText)
437+
.filter(StringUtils::hasText)
438+
.collect(Collectors.joining());
439+
440+
logger.info("CoT: {}", reasoningContent);
441+
assertThat(reasoningContent).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand");
442+
443+
logger.info("Response: {}", content);
444+
assertThat(content).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand");
445+
}
446+
359447
record ActorsFilmsRecord(String actor, List<String> movies) {
360448

361449
}

0 commit comments

Comments
 (0)