Skip to content

Commit cae6364

Browse files
YunKuiLuilayaperumalg
authored andcommitted
feat(zhipuai): ZhipuAI add thinking and response_format parameter support
- Add `thinking` and `response_format` fields to `ZhiPuAiApi` and `ZhiPuAiChatOptions` - Add ZhiPuAiChatOptionsTests with 16 test methods covering all aspects of the class - Test builder pattern with all fields including responseFormat and thinking - Test copy functionality, setters, default values, and equals/hashCode - Test tool callbacks, tool names validation, and collection handling - Test stop sequences alias and fluent setters - Add documentation for response-format.type and thinking.type properties Signed-off-by: YunKui Lu <[email protected]>
1 parent 3e17e16 commit cae6364

File tree

7 files changed

+633
-135
lines changed

7 files changed

+633
-135
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/test/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiPropertiesTests.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
* {@link ZhiPuAiEmbeddingProperties}.
3838
*
3939
* @author Geng Rong
40+
* @author YunKui Lu
4041
*/
4142
public class ZhiPuAiPropertiesTests {
4243

@@ -243,7 +244,9 @@ public void chatOptionsTest() {
243244
"required": ["location", "lat", "lon", "unit"]
244245
}
245246
""",
246-
"spring.ai.zhipuai.chat.options.user=userXYZ"
247+
"spring.ai.zhipuai.chat.options.user=userXYZ",
248+
"spring.ai.zhipuai.chat.options.response-format.type=json_object",
249+
"spring.ai.zhipuai.chat.options.thinking.type=disabled"
247250
)
248251
// @formatter:on
249252
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
@@ -262,6 +265,8 @@ public void chatOptionsTest() {
262265
assertThat(chatProperties.getOptions().getTopP()).isEqualTo(0.56);
263266
assertThat(chatProperties.getOptions().getRequestId()).isEqualTo("RequestId");
264267
assertThat(chatProperties.getOptions().getDoSample()).isEqualTo(Boolean.TRUE);
268+
assertThat(chatProperties.getOptions().getResponseFormat().type()).isEqualTo("json_object");
269+
assertThat(chatProperties.getOptions().getThinking().type()).isEqualTo("disabled");
265270

266271
JSONAssert.assertEquals("{\"type\":\"function\",\"function\":{\"name\":\"toolChoiceFunctionName\"}}",
267272
chatProperties.getOptions().getToolChoice(), JSONCompareMode.LENIENT);

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java

Lines changed: 82 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.HashSet;
2323
import java.util.List;
2424
import java.util.Map;
25+
import java.util.Objects;
2526
import java.util.Set;
2627

2728
import com.fasterxml.jackson.annotation.JsonIgnore;
@@ -30,9 +31,11 @@
3031
import com.fasterxml.jackson.annotation.JsonProperty;
3132

3233
import org.springframework.ai.chat.prompt.ChatOptions;
34+
import org.springframework.ai.model.ModelOptionsUtils;
3335
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3436
import org.springframework.ai.tool.ToolCallback;
3537
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
38+
import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest;
3639
import org.springframework.lang.Nullable;
3740
import org.springframework.util.Assert;
3841

@@ -42,6 +45,7 @@
4245
* @author Geng Rong
4346
* @author Thomas Vitale
4447
* @author Ilayaperumal Gopinathan
48+
* @author YunKui Lu
4549
* @since 1.0.0 M1
4650
*/
4751
@JsonInclude(Include.NON_NULL)
@@ -104,6 +108,16 @@ public class ZhiPuAiChatOptions implements ToolCallingChatOptions {
104108
*/
105109
private @JsonProperty("do_sample") Boolean doSample;
106110

111+
/**
112+
* Control the format of the model output. Set to `json_object` to ensure the message is a valid JSON object.
113+
*/
114+
private @JsonProperty("response_format") ChatCompletionRequest.ResponseFormat responseFormat;
115+
116+
/**
117+
* Control whether to enable the large model's chain of thought. Available options: (default) enabled, disabled.
118+
*/
119+
private @JsonProperty("thinking") ChatCompletionRequest.Thinking thinking;
120+
107121
/**
108122
* Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests.
109123
*/
@@ -146,6 +160,8 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) {
146160
.toolNames(fromOptions.getToolNames())
147161
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
148162
.toolContext(fromOptions.getToolContext())
163+
.responseFormat(fromOptions.getResponseFormat())
164+
.thinking(fromOptions.getThinking())
149165
.build();
150166
}
151167

@@ -244,6 +260,24 @@ public void setDoSample(Boolean doSample) {
244260
this.doSample = doSample;
245261
}
246262

263+
public ChatCompletionRequest.ResponseFormat getResponseFormat() {
264+
return this.responseFormat;
265+
}
266+
267+
public ZhiPuAiChatOptions setResponseFormat(ChatCompletionRequest.ResponseFormat responseFormat) {
268+
this.responseFormat = responseFormat;
269+
return this;
270+
}
271+
272+
public ChatCompletionRequest.Thinking getThinking() {
273+
return this.thinking;
274+
}
275+
276+
public ZhiPuAiChatOptions setThinking(ChatCompletionRequest.Thinking thinking) {
277+
this.thinking = thinking;
278+
return this;
279+
}
280+
247281
@Override
248282
@JsonIgnore
249283
public Double getFrequencyPenalty() {
@@ -311,138 +345,53 @@ public Map<String, Object> getToolContext() {
311345

312346
@Override
313347
public void setToolContext(Map<String, Object> toolContext) {
348+
Assert.notNull(toolContext, "toolContext cannot be null");
314349
this.toolContext = toolContext;
315350
}
316351

352+
@Override
353+
public final boolean equals(Object o) {
354+
if (!(o instanceof ZhiPuAiChatOptions that)) {
355+
return false;
356+
}
357+
358+
return Objects.equals(this.model, that.model) && Objects.equals(this.maxTokens, that.maxTokens)
359+
&& Objects.equals(this.stop, that.stop) && Objects.equals(this.temperature, that.temperature)
360+
&& Objects.equals(this.topP, that.topP) && Objects.equals(this.tools, that.tools)
361+
&& Objects.equals(this.toolChoice, that.toolChoice) && Objects.equals(this.user, that.user)
362+
&& Objects.equals(this.requestId, that.requestId) && Objects.equals(this.doSample, that.doSample)
363+
&& Objects.equals(this.responseFormat, that.responseFormat)
364+
&& Objects.equals(this.thinking, that.thinking)
365+
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
366+
&& Objects.equals(this.toolNames, that.toolNames)
367+
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
368+
&& Objects.equals(this.toolContext, that.toolContext);
369+
}
370+
317371
@Override
318372
public int hashCode() {
319-
final int prime = 31;
320-
int result = 1;
321-
result = prime * result + ((this.model == null) ? 0 : this.model.hashCode());
322-
result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode());
323-
result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode());
324-
result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode());
325-
result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode());
326-
result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode());
327-
result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode());
328-
result = prime * result + ((this.user == null) ? 0 : this.user.hashCode());
329-
result = prime * result
330-
+ ((this.internalToolExecutionEnabled == null) ? 0 : this.internalToolExecutionEnabled.hashCode());
331-
result = prime * result + ((this.toolCallbacks == null) ? 0 : this.toolCallbacks.hashCode());
332-
result = prime * result + ((this.toolNames == null) ? 0 : this.toolNames.hashCode());
333-
result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode());
373+
int result = Objects.hashCode(this.model);
374+
result = 31 * result + Objects.hashCode(this.maxTokens);
375+
result = 31 * result + Objects.hashCode(this.stop);
376+
result = 31 * result + Objects.hashCode(this.temperature);
377+
result = 31 * result + Objects.hashCode(this.topP);
378+
result = 31 * result + Objects.hashCode(this.tools);
379+
result = 31 * result + Objects.hashCode(this.toolChoice);
380+
result = 31 * result + Objects.hashCode(this.user);
381+
result = 31 * result + Objects.hashCode(this.requestId);
382+
result = 31 * result + Objects.hashCode(this.doSample);
383+
result = 31 * result + Objects.hashCode(this.responseFormat);
384+
result = 31 * result + Objects.hashCode(this.thinking);
385+
result = 31 * result + Objects.hashCode(this.toolCallbacks);
386+
result = 31 * result + Objects.hashCode(this.toolNames);
387+
result = 31 * result + Objects.hashCode(this.internalToolExecutionEnabled);
388+
result = 31 * result + Objects.hashCode(this.toolContext);
334389
return result;
335390
}
336391

337392
@Override
338-
public boolean equals(Object obj) {
339-
if (this == obj) {
340-
return true;
341-
}
342-
if (obj == null) {
343-
return false;
344-
}
345-
if (getClass() != obj.getClass()) {
346-
return false;
347-
}
348-
ZhiPuAiChatOptions other = (ZhiPuAiChatOptions) obj;
349-
if (this.model == null) {
350-
if (other.model != null) {
351-
return false;
352-
}
353-
}
354-
else if (!this.model.equals(other.model)) {
355-
return false;
356-
}
357-
if (this.maxTokens == null) {
358-
if (other.maxTokens != null) {
359-
return false;
360-
}
361-
}
362-
else if (!this.maxTokens.equals(other.maxTokens)) {
363-
return false;
364-
}
365-
if (this.stop == null) {
366-
if (other.stop != null) {
367-
return false;
368-
}
369-
}
370-
else if (!this.stop.equals(other.stop)) {
371-
return false;
372-
}
373-
if (this.temperature == null) {
374-
if (other.temperature != null) {
375-
return false;
376-
}
377-
}
378-
else if (!this.temperature.equals(other.temperature)) {
379-
return false;
380-
}
381-
if (this.topP == null) {
382-
if (other.topP != null) {
383-
return false;
384-
}
385-
}
386-
else if (!this.topP.equals(other.topP)) {
387-
return false;
388-
}
389-
if (this.tools == null) {
390-
if (other.tools != null) {
391-
return false;
392-
}
393-
}
394-
else if (!this.tools.equals(other.tools)) {
395-
return false;
396-
}
397-
if (this.toolChoice == null) {
398-
if (other.toolChoice != null) {
399-
return false;
400-
}
401-
}
402-
else if (!this.toolChoice.equals(other.toolChoice)) {
403-
return false;
404-
}
405-
if (this.user == null) {
406-
if (other.user != null) {
407-
return false;
408-
}
409-
}
410-
else if (!this.user.equals(other.user)) {
411-
return false;
412-
}
413-
if (this.requestId == null) {
414-
if (other.requestId != null) {
415-
return false;
416-
}
417-
}
418-
else if (!this.requestId.equals(other.requestId)) {
419-
return false;
420-
}
421-
if (this.doSample == null) {
422-
if (other.doSample != null) {
423-
return false;
424-
}
425-
}
426-
else if (!this.doSample.equals(other.doSample)) {
427-
return false;
428-
}
429-
if (this.internalToolExecutionEnabled == null) {
430-
if (other.internalToolExecutionEnabled != null) {
431-
return false;
432-
}
433-
}
434-
else if (!this.internalToolExecutionEnabled.equals(other.internalToolExecutionEnabled)) {
435-
return false;
436-
}
437-
if (this.toolContext == null) {
438-
if (other.toolContext != null) {
439-
return false;
440-
}
441-
}
442-
else if (!this.toolContext.equals(other.toolContext)) {
443-
return false;
444-
}
445-
return true;
393+
public String toString() {
394+
return "ZhiPuAiChatOptions: " + ModelOptionsUtils.toJsonString(this);
446395
}
447396

448397
@Override
@@ -610,6 +559,16 @@ public Builder toolContext(Map<String, Object> toolContext) {
610559
return this;
611560
}
612561

562+
public Builder responseFormat(ChatCompletionRequest.ResponseFormat responseFormat) {
563+
this.options.responseFormat = responseFormat;
564+
return this;
565+
}
566+
567+
public Builder thinking(ChatCompletionRequest.Thinking thinking) {
568+
this.options.thinking = thinking;
569+
return this;
570+
}
571+
613572
public ZhiPuAiChatOptions build() {
614573
return this.options;
615574
}

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,9 @@ public void setJsonSchema(String jsonSchema) {
672672
* logged and can be used for debugging purposes.
673673
* @param doSample If set, the model will use sampling to generate the next token. If
674674
* not set, the model will use greedy decoding to generate the next token.
675+
* @param responseFormat Control the format of the model output. Set to `json_object`
676+
* to ensure the message is a valid JSON object.
677+
* @param thinking Control whether to enable the large model's chain of thought.
675678
*/
676679
@JsonInclude(Include.NON_NULL)
677680
public record ChatCompletionRequest(// @formatter:off
@@ -684,9 +687,11 @@ public record ChatCompletionRequest(// @formatter:off
684687
@JsonProperty("top_p") Double topP,
685688
@JsonProperty("tools") List<FunctionTool> tools,
686689
@JsonProperty("tool_choice") Object toolChoice,
687-
@JsonProperty("user") String user,
690+
@JsonProperty("user_id") String user,
688691
@JsonProperty("request_id") String requestId,
689-
@JsonProperty("do_sample") Boolean doSample) { // @formatter:on
692+
@JsonProperty("do_sample") Boolean doSample,
693+
@JsonProperty("response_format") ResponseFormat responseFormat,
694+
@JsonProperty("thinking") Thinking thinking) { // @formatter:on
690695

691696
/**
692697
* Shortcut constructor for a chat completion request with the given messages and
@@ -696,7 +701,7 @@ public record ChatCompletionRequest(// @formatter:off
696701
* @param temperature What sampling temperature to use, between 0 and 1.
697702
*/
698703
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
699-
this(messages, model, null, null, false, temperature, null, null, null, null, null, null);
704+
this(messages, model, null, null, false, temperature, null, null, null, null, null, null, null, null);
700705
}
701706

702707
/**
@@ -711,7 +716,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
711716
*/
712717
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature,
713718
boolean stream) {
714-
this(messages, model, null, null, stream, temperature, null, null, null, null, null, null);
719+
this(messages, model, null, null, stream, temperature, null, null, null, null, null, null, null, null);
715720
}
716721

717722
/**
@@ -726,7 +731,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
726731
*/
727732
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, List<FunctionTool> tools,
728733
Object toolChoice) {
729-
this(messages, model, null, null, false, 0.8, null, tools, toolChoice, null, null, null);
734+
this(messages, model, null, null, false, 0.8, null, tools, toolChoice, null, null, null, null, null);
730735
}
731736

732737
/**
@@ -739,7 +744,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
739744
* terminated by a data: [DONE] message.
740745
*/
741746
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
742-
this(messages, null, null, null, stream, null, null, null, null, null, null, null);
747+
this(messages, null, null, null, stream, null, null, null, null, null, null, null, null, null);
743748
}
744749

745750
/**
@@ -774,7 +779,32 @@ public static Object function(String functionName) {
774779
*/
775780
@JsonInclude(Include.NON_NULL)
776781
public record ResponseFormat(@JsonProperty("type") String type) {
782+
783+
public static ResponseFormat text() {
784+
return new ResponseFormat("text");
785+
}
786+
787+
public static ResponseFormat jsonObject() {
788+
return new ResponseFormat("json_object");
789+
}
790+
}
791+
792+
/**
793+
* Control whether to enable the large model's chain of thought
794+
*
795+
* @param type Available options: (default) enabled, disabled
796+
*/
797+
@JsonInclude(Include.NON_NULL)
798+
public record Thinking(@JsonProperty("type") String type) {
799+
public static Thinking enabled() {
800+
return new Thinking("enabled");
801+
}
802+
803+
public static Thinking disabled() {
804+
return new Thinking("disabled");
805+
}
777806
}
807+
778808
}
779809

780810
/**

0 commit comments

Comments
 (0)