Skip to content

Commit be79b89

Browse files
filiphrilayaperumalg
authored andcommitted
Resolve OpenAI ApiKey for every request
Signed-off-by: Filip Hrisafov <[email protected]>
1 parent cd364e3 commit be79b89

File tree

8 files changed

+752
-20
lines changed

8 files changed

+752
-20
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
* @author Thomas Vitale
6363
* @author David Frizelle
6464
* @author Alexandros Pappas
65+
* @author Filip Hrisafov
6566
*/
6667
public class OpenAiApi {
6768

@@ -128,22 +129,28 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> he
128129

129130
// @formatter:off
130131
Consumer<HttpHeaders> finalHeaders = h -> {
131-
if (!(apiKey instanceof NoopApiKey)) {
132-
h.setBearerAuth(apiKey.getValue());
133-
}
134-
135132
h.setContentType(MediaType.APPLICATION_JSON);
136133
h.addAll(headers);
137134
};
138135
this.restClient = restClientBuilder.clone()
139136
.baseUrl(baseUrl)
140137
.defaultHeaders(finalHeaders)
141138
.defaultStatusHandler(responseErrorHandler)
139+
.defaultRequest(requestHeadersSpec -> {
140+
if (!(apiKey instanceof NoopApiKey)) {
141+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
142+
}
143+
})
142144
.build();
143145

144146
this.webClient = webClientBuilder.clone()
145147
.baseUrl(baseUrl)
146148
.defaultHeaders(finalHeaders)
149+
.defaultRequest(requestHeadersSpec -> {
150+
if (!(apiKey instanceof NoopApiKey)) {
151+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
152+
}
153+
})
147154
.build(); // @formatter:on
148155
}
149156

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
* @author Christian Tzolov
5050
* @author Ilayaperumal Gopinathan
5151
* @author Jonghoon Park
52+
* @author Filip Hrisafov
5253
* @since 0.8.1
5354
*/
5455
public class OpenAiAudioApi {
@@ -71,20 +72,30 @@ public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
7172
ResponseErrorHandler responseErrorHandler) {
7273

7374
Consumer<HttpHeaders> authHeaders = h -> {
74-
if (!(apiKey instanceof NoopApiKey)) {
75-
h.setBearerAuth(apiKey.getValue());
76-
}
7775
h.addAll(headers);
78-
// h.setContentType(MediaType.APPLICATION_JSON);
7976
};
8077

78+
// @formatter:off
8179
this.restClient = restClientBuilder.clone()
8280
.baseUrl(baseUrl)
8381
.defaultHeaders(authHeaders)
8482
.defaultStatusHandler(responseErrorHandler)
83+
.defaultRequest(requestHeadersSpec -> {
84+
if (!(apiKey instanceof NoopApiKey)) {
85+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
86+
}
87+
})
8588
.build();
8689

87-
this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(authHeaders).build();
90+
this.webClient = webClientBuilder.clone()
91+
.baseUrl(baseUrl)
92+
.defaultHeaders(authHeaders)
93+
.defaultRequest(requestHeadersSpec -> {
94+
if (!(apiKey instanceof NoopApiKey)) {
95+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
96+
}
97+
})
98+
.build(); // @formatter:on
8899
}
89100

90101
public static Builder builder() {

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.model.SimpleApiKey;
2828
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
2929
import org.springframework.ai.retry.RetryUtils;
30+
import org.springframework.http.HttpHeaders;
3031
import org.springframework.http.MediaType;
3132
import org.springframework.http.ResponseEntity;
3233
import org.springframework.util.Assert;
@@ -40,6 +41,7 @@
4041
*
4142
* @see <a href= "https://platform.openai.com/docs/api-reference/images">Images</a>
4243
* @author lambochen
44+
* @author Filip Hrisafov
4345
*/
4446
public class OpenAiImageApi {
4547

@@ -62,15 +64,18 @@ public OpenAiImageApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
6264
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
6365

6466
// @formatter:off
65-
this.restClient = restClientBuilder.baseUrl(baseUrl)
67+
this.restClient = restClientBuilder.clone()
68+
.baseUrl(baseUrl)
6669
.defaultHeaders(h -> {
67-
if (!(apiKey instanceof NoopApiKey)) {
68-
h.setBearerAuth(apiKey.getValue());
69-
}
7070
h.setContentType(MediaType.APPLICATION_JSON);
7171
h.addAll(headers);
7272
})
7373
.defaultStatusHandler(responseErrorHandler)
74+
.defaultRequest(requestHeadersSpec -> {
75+
if (!(apiKey instanceof NoopApiKey)) {
76+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
77+
}
78+
})
7479
.build();
7580
// @formatter:on
7681

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.model.SimpleApiKey;
2828
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
2929
import org.springframework.ai.retry.RetryUtils;
30+
import org.springframework.http.HttpHeaders;
3031
import org.springframework.http.MediaType;
3132
import org.springframework.http.ResponseEntity;
3233
import org.springframework.util.Assert;
@@ -40,6 +41,7 @@
4041
*
4142
* @author Ahmed Yousri
4243
* @author Ilayaperumal Gopinathan
44+
* @author Filip Hrisafov
4345
* @see <a href=
4446
* "https://platform.openai.com/docs/api-reference/moderations">https://platform.openai.com/docs/api-reference/moderations</a>
4547
*/
@@ -64,13 +66,20 @@ public OpenAiModerationApi(String baseUrl, ApiKey apiKey, MultiValueMap<String,
6466

6567
this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
6668

67-
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(h -> {
68-
if (!(apiKey instanceof NoopApiKey)) {
69-
h.setBearerAuth(apiKey.getValue());
70-
}
71-
h.setContentType(MediaType.APPLICATION_JSON);
72-
h.addAll(headers);
73-
}).defaultStatusHandler(responseErrorHandler).build();
69+
// @formatter:off
70+
this.restClient = restClientBuilder.clone()
71+
.baseUrl(baseUrl)
72+
.defaultHeaders(h -> {
73+
h.setContentType(MediaType.APPLICATION_JSON);
74+
h.addAll(headers);
75+
})
76+
.defaultStatusHandler(responseErrorHandler)
77+
.defaultRequest(requestHeadersSpec -> {
78+
if (!(apiKey instanceof NoopApiKey)) {
79+
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
80+
}
81+
})
82+
.build(); // @formatter:on
7483
}
7584

7685
public ResponseEntity<OpenAiModerationResponse> createModeration(OpenAiModerationRequest openAiModerationRequest) {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,27 @@
1616

1717
package org.springframework.ai.openai.api;
1818

19+
import java.io.IOException;
20+
import java.util.LinkedList;
21+
import java.util.List;
22+
import java.util.Objects;
23+
import java.util.Queue;
24+
25+
import okhttp3.mockwebserver.MockResponse;
26+
import okhttp3.mockwebserver.MockWebServer;
27+
import okhttp3.mockwebserver.RecordedRequest;
28+
29+
import org.junit.jupiter.api.AfterEach;
30+
import org.junit.jupiter.api.BeforeEach;
31+
import org.junit.jupiter.api.Nested;
1932
import org.junit.jupiter.api.Test;
2033

2134
import org.springframework.ai.model.ApiKey;
2235
import org.springframework.ai.model.SimpleApiKey;
36+
import org.springframework.http.HttpHeaders;
37+
import org.springframework.http.HttpStatus;
38+
import org.springframework.http.MediaType;
39+
import org.springframework.http.ResponseEntity;
2340
import org.springframework.util.LinkedMultiValueMap;
2441
import org.springframework.util.MultiValueMap;
2542
import org.springframework.web.client.ResponseErrorHandler;
@@ -142,4 +159,126 @@ void testInvalidResponseErrorHandler() {
142159
.hasMessageContaining("responseErrorHandler cannot be null");
143160
}
144161

162+
@Nested
163+
class MockRequests {
164+
165+
MockWebServer mockWebServer;
166+
167+
@BeforeEach
168+
void setUp() throws IOException {
169+
mockWebServer = new MockWebServer();
170+
mockWebServer.start();
171+
}
172+
173+
@AfterEach
174+
void tearDown() throws IOException {
175+
mockWebServer.shutdown();
176+
}
177+
178+
@Test
179+
void dynamicApiKeyRestClient() throws InterruptedException {
180+
Queue<ApiKey> apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2")));
181+
OpenAiApi api = OpenAiApi.builder()
182+
.apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue())
183+
.baseUrl(mockWebServer.url("/").toString())
184+
.build();
185+
186+
MockResponse mockResponse = new MockResponse().setResponseCode(200)
187+
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
188+
.setBody("""
189+
{
190+
"id": "chatcmpl-12345",
191+
"object": "chat.completion",
192+
"created": 1677858242,
193+
"model": "gpt-3.5-turbo",
194+
"choices": [
195+
{
196+
"index": 0,
197+
"message": {
198+
"role": "assistant",
199+
"content": "Hello world"
200+
},
201+
"finish_reason": "stop"
202+
}
203+
],
204+
"usage": {
205+
"prompt_tokens": 10,
206+
"completion_tokens": 5,
207+
"total_tokens": 15
208+
}
209+
}
210+
""");
211+
mockWebServer.enqueue(mockResponse);
212+
mockWebServer.enqueue(mockResponse);
213+
214+
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world",
215+
OpenAiApi.ChatCompletionMessage.Role.USER);
216+
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(
217+
List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false);
218+
ResponseEntity<OpenAiApi.ChatCompletion> response = api.chatCompletionEntity(request);
219+
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
220+
RecordedRequest recordedRequest = mockWebServer.takeRequest();
221+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1");
222+
223+
response = api.chatCompletionEntity(request);
224+
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
225+
226+
recordedRequest = mockWebServer.takeRequest();
227+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");
228+
}
229+
230+
@Test
231+
void dynamicApiKeyWebClient() throws InterruptedException {
232+
Queue<ApiKey> apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2")));
233+
OpenAiApi api = OpenAiApi.builder()
234+
.apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue())
235+
.baseUrl(mockWebServer.url("/").toString())
236+
.build();
237+
238+
MockResponse mockResponse = new MockResponse().setResponseCode(200)
239+
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
240+
.setBody("""
241+
{
242+
"id": "chatcmpl-12345",
243+
"object": "chat.completion",
244+
"created": 1677858242,
245+
"model": "gpt-3.5-turbo",
246+
"choices": [
247+
{
248+
"index": 0,
249+
"message": {
250+
"role": "assistant",
251+
"content": "Hello world"
252+
},
253+
"finish_reason": "stop"
254+
}
255+
],
256+
"usage": {
257+
"prompt_tokens": 10,
258+
"completion_tokens": 5,
259+
"total_tokens": 15
260+
}
261+
}
262+
""".replace("\n", ""));
263+
mockWebServer.enqueue(mockResponse);
264+
mockWebServer.enqueue(mockResponse);
265+
266+
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world",
267+
OpenAiApi.ChatCompletionMessage.Role.USER);
268+
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(
269+
List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true);
270+
List<OpenAiApi.ChatCompletionChunk> response = api.chatCompletionStream(request).collectList().block();
271+
assertThat(response).hasSize(1);
272+
RecordedRequest recordedRequest = mockWebServer.takeRequest();
273+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1");
274+
275+
response = api.chatCompletionStream(request).collectList().block();
276+
assertThat(response).hasSize(1);
277+
278+
recordedRequest = mockWebServer.takeRequest();
279+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");
280+
}
281+
282+
}
283+
145284
}

0 commit comments

Comments
 (0)