Skip to content

Commit 7e69016

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor!: Use RxJava for VertexAiClient
In addition to the refactor, use the built in RxJava sleep functionality instead of Thread.sleep(). Also, adding some randomness to the LRO checking on createSession to test out that the LRO logic works. PiperOrigin-RevId: 812825416
1 parent 224552a commit 7e69016

File tree

3 files changed

+243
-191
lines changed

3 files changed

+243
-191
lines changed

core/src/main/java/com/google/adk/sessions/VertexAiClient.java

Lines changed: 112 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@
99
import com.google.common.base.Splitter;
1010
import com.google.common.collect.Iterables;
1111
import com.google.genai.types.HttpOptions;
12+
import io.reactivex.rxjava3.core.Completable;
13+
import io.reactivex.rxjava3.core.Maybe;
14+
import io.reactivex.rxjava3.core.Single;
1215
import java.io.IOException;
1316
import java.io.UncheckedIOException;
1417
import java.util.List;
1518
import java.util.Optional;
1619
import java.util.concurrent.ConcurrentHashMap;
1720
import java.util.concurrent.ConcurrentMap;
21+
import java.util.concurrent.TimeoutException;
1822
import javax.annotation.Nullable;
1923
import okhttp3.ResponseBody;
2024
import org.slf4j.Logger;
@@ -46,104 +50,124 @@ final class VertexAiClient {
4650
new HttpApiClient(Optional.of(project), Optional.of(location), credentials, httpOptions);
4751
}
4852

49-
@Nullable
50-
JsonNode createSession(
53+
Maybe<JsonNode> createSession(
5154
String reasoningEngineId, String userId, ConcurrentMap<String, Object> state) {
5255
ConcurrentHashMap<String, Object> sessionJsonMap = new ConcurrentHashMap<>();
5356
sessionJsonMap.put("userId", userId);
5457
if (state != null) {
5558
sessionJsonMap.put("sessionState", state);
5659
}
5760

58-
String sessId;
59-
String operationId;
60-
try {
61-
String sessionJson = objectMapper.writeValueAsString(sessionJsonMap);
62-
try (ApiResponse apiResponse =
63-
apiClient.request(
64-
"POST", "reasoningEngines/" + reasoningEngineId + "/sessions", sessionJson)) {
65-
logger.debug("Create Session response {}", apiResponse.getResponseBody());
66-
if (apiResponse == null || apiResponse.getResponseBody() == null) {
67-
return null;
68-
}
69-
70-
JsonNode jsonResponse = getJsonResponse(apiResponse);
71-
if (jsonResponse == null) {
72-
return null;
73-
}
74-
String sessionName = jsonResponse.get("name").asText();
75-
List<String> parts = Splitter.on('/').splitToList(sessionName);
76-
sessId = parts.get(parts.size() - 3);
77-
operationId = Iterables.getLast(parts);
78-
}
79-
} catch (IOException e) {
80-
throw new UncheckedIOException(e);
81-
}
61+
return Single.fromCallable(() -> objectMapper.writeValueAsString(sessionJsonMap))
62+
.flatMap(
63+
sessionJson ->
64+
performApiRequest(
65+
"POST", "reasoningEngines/" + reasoningEngineId + "/sessions", sessionJson))
66+
.flatMapMaybe(
67+
apiResponse -> {
68+
logger.debug("Create Session response {}", apiResponse.getResponseBody());
69+
return getJsonResponse(apiResponse);
70+
})
71+
.flatMap(
72+
jsonResponse -> {
73+
String sessionName = jsonResponse.get("name").asText();
74+
List<String> parts = Splitter.on('/').splitToList(sessionName);
75+
String sessId = parts.get(parts.size() - 3);
76+
String operationId = Iterables.getLast(parts);
77+
78+
return pollOperation(operationId, 0).andThen(getSession(reasoningEngineId, sessId));
79+
});
80+
}
8281

83-
for (int i = 0; i < MAX_RETRY_ATTEMPTS; i++) {
84-
try (ApiResponse lroResponse = apiClient.request("GET", "operations/" + operationId, "")) {
85-
JsonNode lroJsonResponse = getJsonResponse(lroResponse);
86-
if (lroJsonResponse != null && lroJsonResponse.get("done") != null) {
87-
break;
88-
}
89-
}
90-
try {
91-
SECONDS.sleep(1);
92-
} catch (InterruptedException e) {
93-
logger.warn("Error during sleep", e);
94-
Thread.currentThread().interrupt();
95-
}
82+
/**
83+
* Polls the status of a long-running operation.
84+
*
85+
* @param operationId The ID of the operation to poll.
86+
* @param attempt The current retry attempt number (starting from 0).
87+
* @return A Completable that completes when the operation is done, or errors with
88+
* TimeoutException if max retries are exceeded.
89+
*/
90+
private Completable pollOperation(String operationId, int attempt) {
91+
if (attempt >= MAX_RETRY_ATTEMPTS) {
92+
return Completable.error(
93+
new TimeoutException("Operation " + operationId + " did not complete in time."));
9694
}
97-
return getSession(reasoningEngineId, sessId);
95+
return performApiRequest("GET", "operations/" + operationId, "")
96+
.flatMapMaybe(VertexAiClient::getJsonResponse)
97+
.flatMapCompletable(
98+
lroJsonResponse -> {
99+
if (lroJsonResponse != null && lroJsonResponse.get("done") != null) {
100+
return Completable.complete(); // Operation is done
101+
} else {
102+
// Not done, retry after a delay
103+
return Completable.timer(1, SECONDS)
104+
.andThen(pollOperation(operationId, attempt + 1));
105+
}
106+
});
98107
}
99108

100-
JsonNode listSessions(String reasoningEngineId, String userId) {
101-
try (ApiResponse apiResponse =
102-
apiClient.request(
109+
Maybe<JsonNode> listSessions(String reasoningEngineId, String userId) {
110+
return performApiRequest(
103111
"GET",
104112
"reasoningEngines/" + reasoningEngineId + "/sessions?filter=user_id=" + userId,
105-
"")) {
106-
return getJsonResponse(apiResponse);
107-
}
113+
"")
114+
.flatMapMaybe(VertexAiClient::getJsonResponse);
108115
}
109116

110-
JsonNode listEvents(String reasoningEngineId, String sessionId) {
111-
try (ApiResponse apiResponse =
112-
apiClient.request(
117+
Maybe<JsonNode> listEvents(String reasoningEngineId, String sessionId) {
118+
return performApiRequest(
113119
"GET",
114120
"reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events",
115-
"")) {
116-
logger.debug("List events response {}", apiResponse);
117-
return getJsonResponse(apiResponse);
118-
}
121+
"")
122+
.doOnSuccess(apiResponse -> logger.debug("List events response {}", apiResponse))
123+
.flatMapMaybe(VertexAiClient::getJsonResponse);
119124
}
120125

121-
JsonNode getSession(String reasoningEngineId, String sessionId) {
122-
try (ApiResponse apiResponse =
123-
apiClient.request(
124-
"GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) {
125-
return getJsonResponse(apiResponse);
126-
}
126+
Maybe<JsonNode> getSession(String reasoningEngineId, String sessionId) {
127+
return performApiRequest(
128+
"GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")
129+
.flatMapMaybe(apiResponse -> getJsonResponse(apiResponse));
127130
}
128131

129-
void deleteSession(String reasoningEngineId, String sessionId) {
130-
try (ApiResponse response =
131-
apiClient.request(
132-
"DELETE", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) {}
132+
Completable deleteSession(String reasoningEngineId, String sessionId) {
133+
return performApiRequest(
134+
"DELETE", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")
135+
.doOnSuccess(ApiResponse::close)
136+
.ignoreElement();
133137
}
134138

135-
void appendEvent(String reasoningEngineId, String sessionId, String eventJson) {
136-
try (ApiResponse response =
137-
apiClient.request(
139+
Completable appendEvent(String reasoningEngineId, String sessionId, String eventJson) {
140+
return performApiRequest(
138141
"POST",
139142
"reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + ":appendEvent",
140-
eventJson)) {
141-
if (response.getResponseBody().string().contains("com.google.genai.errors.ClientException")) {
142-
logger.warn("Failed to append event: {}", eventJson);
143-
}
144-
} catch (IOException e) {
145-
throw new UncheckedIOException(e);
146-
}
143+
eventJson)
144+
.flatMapCompletable(
145+
response -> {
146+
try (response) {
147+
ResponseBody responseBody = response.getResponseBody();
148+
if (responseBody != null) {
149+
String responseString = responseBody.string();
150+
if (responseString.contains("com.google.genai.errors.ClientException")) {
151+
logger.warn("Failed to append event: {}", eventJson);
152+
}
153+
}
154+
return Completable.complete();
155+
} catch (IOException e) {
156+
return Completable.error(new UncheckedIOException(e));
157+
}
158+
});
159+
}
160+
161+
/**
162+
* Performs an API request and returns a Single emitting the ApiResponse.
163+
*
164+
* <p>Note: The caller is responsible for closing the returned {@link ApiResponse}.
165+
*/
166+
private Single<ApiResponse> performApiRequest(String method, String path, String body) {
167+
return Single.fromCallable(
168+
() -> {
169+
return apiClient.request(method, path, body);
170+
});
147171
}
148172

149173
/**
@@ -152,19 +176,23 @@ void appendEvent(String reasoningEngineId, String sessionId, String eventJson) {
152176
* @throws UncheckedIOException if parsing fails.
153177
*/
154178
@Nullable
155-
private static JsonNode getJsonResponse(ApiResponse apiResponse) {
156-
if (apiResponse == null || apiResponse.getResponseBody() == null) {
157-
return null;
158-
}
179+
private static Maybe<JsonNode> getJsonResponse(ApiResponse apiResponse) {
159180
try {
160-
ResponseBody responseBody = apiResponse.getResponseBody();
161-
String responseString = responseBody.string();
162-
if (responseString.isEmpty()) {
163-
return null;
181+
if (apiResponse == null || apiResponse.getResponseBody() == null) {
182+
return Maybe.empty();
183+
}
184+
try {
185+
ResponseBody responseBody = apiResponse.getResponseBody();
186+
String responseString = responseBody.string(); // Read body here
187+
if (responseString.isEmpty()) {
188+
return Maybe.empty();
189+
}
190+
return Maybe.just(objectMapper.readTree(responseString));
191+
} catch (IOException e) {
192+
return Maybe.error(new UncheckedIOException(e));
164193
}
165-
return objectMapper.readTree(responseString);
166-
} catch (IOException e) {
167-
throw new UncheckedIOException(e);
194+
} finally {
195+
apiResponse.close();
168196
}
169197
}
170198
}

0 commit comments

Comments
 (0)