|
22 | 22 | import java.util.HashSet; |
23 | 23 | import java.util.List; |
24 | 24 | import java.util.Map; |
| 25 | +import java.util.Objects; |
25 | 26 | import java.util.Set; |
26 | 27 |
|
27 | 28 | import com.fasterxml.jackson.annotation.JsonIgnore; |
|
30 | 31 | import com.fasterxml.jackson.annotation.JsonProperty; |
31 | 32 |
|
32 | 33 | import org.springframework.ai.chat.prompt.ChatOptions; |
| 34 | +import org.springframework.ai.model.ModelOptionsUtils; |
33 | 35 | import org.springframework.ai.model.tool.ToolCallingChatOptions; |
34 | 36 | import org.springframework.ai.tool.ToolCallback; |
35 | 37 | import org.springframework.ai.zhipuai.api.ZhiPuAiApi; |
| 38 | +import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest; |
36 | 39 | import org.springframework.lang.Nullable; |
37 | 40 | import org.springframework.util.Assert; |
38 | 41 |
|
|
42 | 45 | * @author Geng Rong |
43 | 46 | * @author Thomas Vitale |
44 | 47 | * @author Ilayaperumal Gopinathan |
| 48 | + * @author YunKui Lu |
45 | 49 | * @since 1.0.0 M1 |
46 | 50 | */ |
47 | 51 | @JsonInclude(Include.NON_NULL) |
@@ -104,6 +108,16 @@ public class ZhiPuAiChatOptions implements ToolCallingChatOptions { |
104 | 108 | */ |
105 | 109 | private @JsonProperty("do_sample") Boolean doSample; |
106 | 110 |
|
| 111 | + /** |
| 112 | + * Control the format of the model output. Set to `json_object` to ensure the message is a valid JSON object. |
| 113 | + */ |
| 114 | + private @JsonProperty("response_format") ChatCompletionRequest.ResponseFormat responseFormat; |
| 115 | + |
| 116 | + /** |
| 117 | + * Control whether to enable the large model's chain of thought. Available options: (default) enabled, disabled. |
| 118 | + */ |
| 119 | + private @JsonProperty("thinking") ChatCompletionRequest.Thinking thinking; |
| 120 | + |
107 | 121 | /** |
108 | 122 | * Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests. |
109 | 123 | */ |
@@ -146,6 +160,8 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { |
146 | 160 | .toolNames(fromOptions.getToolNames()) |
147 | 161 | .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) |
148 | 162 | .toolContext(fromOptions.getToolContext()) |
| 163 | + .responseFormat(fromOptions.getResponseFormat()) |
| 164 | + .thinking(fromOptions.getThinking()) |
149 | 165 | .build(); |
150 | 166 | } |
151 | 167 |
|
@@ -244,6 +260,24 @@ public void setDoSample(Boolean doSample) { |
244 | 260 | this.doSample = doSample; |
245 | 261 | } |
246 | 262 |
|
| 263 | + public ChatCompletionRequest.ResponseFormat getResponseFormat() { |
| 264 | + return this.responseFormat; |
| 265 | + } |
| 266 | + |
| 267 | + public ZhiPuAiChatOptions setResponseFormat(ChatCompletionRequest.ResponseFormat responseFormat) { |
| 268 | + this.responseFormat = responseFormat; |
| 269 | + return this; |
| 270 | + } |
| 271 | + |
| 272 | + public ChatCompletionRequest.Thinking getThinking() { |
| 273 | + return this.thinking; |
| 274 | + } |
| 275 | + |
| 276 | + public ZhiPuAiChatOptions setThinking(ChatCompletionRequest.Thinking thinking) { |
| 277 | + this.thinking = thinking; |
| 278 | + return this; |
| 279 | + } |
| 280 | + |
247 | 281 | @Override |
248 | 282 | @JsonIgnore |
249 | 283 | public Double getFrequencyPenalty() { |
@@ -311,138 +345,53 @@ public Map<String, Object> getToolContext() { |
311 | 345 |
|
312 | 346 | @Override |
313 | 347 | public void setToolContext(Map<String, Object> toolContext) { |
| 348 | + Assert.notNull(toolContext, "toolContext cannot be null"); |
314 | 349 | this.toolContext = toolContext; |
315 | 350 | } |
316 | 351 |
|
| 352 | + @Override |
| 353 | + public final boolean equals(Object o) { |
| 354 | + if (!(o instanceof ZhiPuAiChatOptions that)) { |
| 355 | + return false; |
| 356 | + } |
| 357 | + |
| 358 | + return Objects.equals(this.model, that.model) && Objects.equals(this.maxTokens, that.maxTokens) |
| 359 | + && Objects.equals(this.stop, that.stop) && Objects.equals(this.temperature, that.temperature) |
| 360 | + && Objects.equals(this.topP, that.topP) && Objects.equals(this.tools, that.tools) |
| 361 | + && Objects.equals(this.toolChoice, that.toolChoice) && Objects.equals(this.user, that.user) |
| 362 | + && Objects.equals(this.requestId, that.requestId) && Objects.equals(this.doSample, that.doSample) |
| 363 | + && Objects.equals(this.responseFormat, that.responseFormat) |
| 364 | + && Objects.equals(this.thinking, that.thinking) |
| 365 | + && Objects.equals(this.toolCallbacks, that.toolCallbacks) |
| 366 | + && Objects.equals(this.toolNames, that.toolNames) |
| 367 | + && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) |
| 368 | + && Objects.equals(this.toolContext, that.toolContext); |
| 369 | + } |
| 370 | + |
317 | 371 | @Override |
318 | 372 | public int hashCode() { |
319 | | - final int prime = 31; |
320 | | - int result = 1; |
321 | | - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); |
322 | | - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); |
323 | | - result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); |
324 | | - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); |
325 | | - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); |
326 | | - result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode()); |
327 | | - result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode()); |
328 | | - result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); |
329 | | - result = prime * result |
330 | | - + ((this.internalToolExecutionEnabled == null) ? 0 : this.internalToolExecutionEnabled.hashCode()); |
331 | | - result = prime * result + ((this.toolCallbacks == null) ? 0 : this.toolCallbacks.hashCode()); |
332 | | - result = prime * result + ((this.toolNames == null) ? 0 : this.toolNames.hashCode()); |
333 | | - result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); |
| 373 | + int result = Objects.hashCode(this.model); |
| 374 | + result = 31 * result + Objects.hashCode(this.maxTokens); |
| 375 | + result = 31 * result + Objects.hashCode(this.stop); |
| 376 | + result = 31 * result + Objects.hashCode(this.temperature); |
| 377 | + result = 31 * result + Objects.hashCode(this.topP); |
| 378 | + result = 31 * result + Objects.hashCode(this.tools); |
| 379 | + result = 31 * result + Objects.hashCode(this.toolChoice); |
| 380 | + result = 31 * result + Objects.hashCode(this.user); |
| 381 | + result = 31 * result + Objects.hashCode(this.requestId); |
| 382 | + result = 31 * result + Objects.hashCode(this.doSample); |
| 383 | + result = 31 * result + Objects.hashCode(this.responseFormat); |
| 384 | + result = 31 * result + Objects.hashCode(this.thinking); |
| 385 | + result = 31 * result + Objects.hashCode(this.toolCallbacks); |
| 386 | + result = 31 * result + Objects.hashCode(this.toolNames); |
| 387 | + result = 31 * result + Objects.hashCode(this.internalToolExecutionEnabled); |
| 388 | + result = 31 * result + Objects.hashCode(this.toolContext); |
334 | 389 | return result; |
335 | 390 | } |
336 | 391 |
|
337 | 392 | @Override |
338 | | - public boolean equals(Object obj) { |
339 | | - if (this == obj) { |
340 | | - return true; |
341 | | - } |
342 | | - if (obj == null) { |
343 | | - return false; |
344 | | - } |
345 | | - if (getClass() != obj.getClass()) { |
346 | | - return false; |
347 | | - } |
348 | | - ZhiPuAiChatOptions other = (ZhiPuAiChatOptions) obj; |
349 | | - if (this.model == null) { |
350 | | - if (other.model != null) { |
351 | | - return false; |
352 | | - } |
353 | | - } |
354 | | - else if (!this.model.equals(other.model)) { |
355 | | - return false; |
356 | | - } |
357 | | - if (this.maxTokens == null) { |
358 | | - if (other.maxTokens != null) { |
359 | | - return false; |
360 | | - } |
361 | | - } |
362 | | - else if (!this.maxTokens.equals(other.maxTokens)) { |
363 | | - return false; |
364 | | - } |
365 | | - if (this.stop == null) { |
366 | | - if (other.stop != null) { |
367 | | - return false; |
368 | | - } |
369 | | - } |
370 | | - else if (!this.stop.equals(other.stop)) { |
371 | | - return false; |
372 | | - } |
373 | | - if (this.temperature == null) { |
374 | | - if (other.temperature != null) { |
375 | | - return false; |
376 | | - } |
377 | | - } |
378 | | - else if (!this.temperature.equals(other.temperature)) { |
379 | | - return false; |
380 | | - } |
381 | | - if (this.topP == null) { |
382 | | - if (other.topP != null) { |
383 | | - return false; |
384 | | - } |
385 | | - } |
386 | | - else if (!this.topP.equals(other.topP)) { |
387 | | - return false; |
388 | | - } |
389 | | - if (this.tools == null) { |
390 | | - if (other.tools != null) { |
391 | | - return false; |
392 | | - } |
393 | | - } |
394 | | - else if (!this.tools.equals(other.tools)) { |
395 | | - return false; |
396 | | - } |
397 | | - if (this.toolChoice == null) { |
398 | | - if (other.toolChoice != null) { |
399 | | - return false; |
400 | | - } |
401 | | - } |
402 | | - else if (!this.toolChoice.equals(other.toolChoice)) { |
403 | | - return false; |
404 | | - } |
405 | | - if (this.user == null) { |
406 | | - if (other.user != null) { |
407 | | - return false; |
408 | | - } |
409 | | - } |
410 | | - else if (!this.user.equals(other.user)) { |
411 | | - return false; |
412 | | - } |
413 | | - if (this.requestId == null) { |
414 | | - if (other.requestId != null) { |
415 | | - return false; |
416 | | - } |
417 | | - } |
418 | | - else if (!this.requestId.equals(other.requestId)) { |
419 | | - return false; |
420 | | - } |
421 | | - if (this.doSample == null) { |
422 | | - if (other.doSample != null) { |
423 | | - return false; |
424 | | - } |
425 | | - } |
426 | | - else if (!this.doSample.equals(other.doSample)) { |
427 | | - return false; |
428 | | - } |
429 | | - if (this.internalToolExecutionEnabled == null) { |
430 | | - if (other.internalToolExecutionEnabled != null) { |
431 | | - return false; |
432 | | - } |
433 | | - } |
434 | | - else if (!this.internalToolExecutionEnabled.equals(other.internalToolExecutionEnabled)) { |
435 | | - return false; |
436 | | - } |
437 | | - if (this.toolContext == null) { |
438 | | - if (other.toolContext != null) { |
439 | | - return false; |
440 | | - } |
441 | | - } |
442 | | - else if (!this.toolContext.equals(other.toolContext)) { |
443 | | - return false; |
444 | | - } |
445 | | - return true; |
| 393 | + public String toString() { |
| 394 | + return "ZhiPuAiChatOptions: " + ModelOptionsUtils.toJsonString(this); |
446 | 395 | } |
447 | 396 |
|
448 | 397 | @Override |
@@ -610,6 +559,16 @@ public Builder toolContext(Map<String, Object> toolContext) { |
610 | 559 | return this; |
611 | 560 | } |
612 | 561 |
|
| 562 | + public Builder responseFormat(ChatCompletionRequest.ResponseFormat responseFormat) { |
| 563 | + this.options.responseFormat = responseFormat; |
| 564 | + return this; |
| 565 | + } |
| 566 | + |
| 567 | + public Builder thinking(ChatCompletionRequest.Thinking thinking) { |
| 568 | + this.options.thinking = thinking; |
| 569 | + return this; |
| 570 | + } |
| 571 | + |
613 | 572 | public ZhiPuAiChatOptions build() { |
614 | 573 | return this.options; |
615 | 574 | } |
|
0 commit comments