Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions genai/snippets/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@
<scope>test</scope>
<version>4.13.2</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>5.19.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.truth</groupId>
<artifactId>truth</artifactId>
Expand Down
77 changes: 77 additions & 0 deletions genai/snippets/src/main/java/genai/tools/ToolsVaisWithText.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package genai.tools;

// [START googlegenaisdk_tools_vais_with_txt]

import com.google.genai.Client;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.GenerateContentResponse;
import com.google.genai.types.HttpOptions;
import com.google.genai.types.Retrieval;
import com.google.genai.types.Tool;
import com.google.genai.types.VertexAISearch;

public class ToolsVaisWithText {

public static void main(String[] args) {
// TODO(developer): Replace these variables before running the sample.
String modelId = "gemini-2.5-flash";
// Load Data Store ID from Vertex AI Search
// E.g datastoreId =
// "projects/project-id/locations/global/collections/default_collection/dataStores/datastore-id"
String datastoreId = "your-datastore";
generateContent(modelId, datastoreId);
}

// Generates text with Vertex AI Search tool
public static String generateContent(String modelId, String datastoreId) {
// Initialize client that will be used to send requests. This client only needs to be created
// once, and can be reused for multiple requests.
try (Client client =
Client.builder()
.location("global")
.vertexAI(true)
.httpOptions(HttpOptions.builder().apiVersion("v1").build())
.build()) {

// Set the VertexAI Search tool and the datastore that the model can use to retrieve data from
Tool vaisSearchTool =
Tool.builder()
.retrieval(
Retrieval.builder()
.vertexAiSearch(VertexAISearch.builder().datastore(datastoreId).build())
.build())
.build();

// Create a GenerateContentConfig and set the Vertex AI Search tool
GenerateContentConfig contentConfig =
GenerateContentConfig.builder().tools(vaisSearchTool).build();

GenerateContentResponse response =
client.models.generateContent(
modelId, "How do I make an appointment to renew my driver's license?", contentConfig);

System.out.print(response.text());
// Example response:
// The process for making an appointment to renew your driver's license varies depending
// on your location. To provide you with the most accurate instructions...
return response.text();
}
}
}
// [END googlegenaisdk_tools_vais_with_txt]
55 changes: 54 additions & 1 deletion genai/snippets/src/test/java/genai/tools/ToolsIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,31 @@

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.RETURNS_SELF;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.genai.Client;
import com.google.genai.Models;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.GenerateContentResponse;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.lang.reflect.Field;
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.MockedStatic;


@RunWith(JUnit4.class)
public class ToolsIT {
Expand Down Expand Up @@ -105,4 +120,42 @@ public void testToolsGoogleSearchWithText() {
assertThat(response).isNotEmpty();
}

@Test
public void testToolsVaisWithText() throws NoSuchFieldException, IllegalAccessException {
String response = "The process for making an appointment to renew your driver's license"
+ " varies depending on your location.";

String datastore =
String.format(
"projects/%s/locations/global/collections/default_collection/"
+ "dataStores/grounding-test-datastore",
PROJECT_ID);

Client.Builder mockedBuilder = mock(Client.Builder.class, RETURNS_SELF);
Client mockedClient = mock(Client.class);
Models mockedModels = mock(Models.class);
GenerateContentResponse mockedResponse = mock(GenerateContentResponse.class);

try (MockedStatic<Client> mockedStatic = mockStatic(Client.class)) {
mockedStatic.when(Client::builder).thenReturn(mockedBuilder);
when(mockedBuilder.build()).thenReturn(mockedClient);

// Using reflection because 'models' is a final field and cannot be mockable directly
Field field = Client.class.getDeclaredField("models");
field.setAccessible(true);
field.set(mockedClient, mockedModels);

when(mockedClient.models.generateContent(
anyString(), anyString(), any(GenerateContentConfig.class)))
.thenReturn(mockedResponse);
when(mockedResponse.text()).thenReturn(response);

String generatedResponse = ToolsVaisWithText.generateContent(GEMINI_FLASH, datastore);

verify(mockedClient.models, times(1))
.generateContent(anyString(), anyString(), any(GenerateContentConfig.class));
assertThat(generatedResponse).isNotEmpty();
assertThat(response).isEqualTo(generatedResponse);
}
}
}