diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 916198b658dd..de657443fd26 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -15,3 +15,7 @@ **/firestore/ @googleapis/toolbox-firestore-team **/looker/ @googleapis/toolbox-looker-team **/spanner/ @googleapis/toolbox-spanner-team + +# Docs +/docs/ @googleapis/senseai-eco-team +/*.md @googleapis/senseai-eco-team diff --git a/.github/workflows/deploy_versioned_docs.yaml b/.github/workflows/deploy_versioned_docs.yaml index 8da91c95be5c..adc4bfedf6ff 100644 --- a/.github/workflows/deploy_versioned_docs.yaml +++ b/.github/workflows/deploy_versioned_docs.yaml @@ -35,9 +35,10 @@ jobs: ref: ${{ github.event.release.tag_name }} - name: Get Version from Release Tag - run: echo "VERSION=${GITHUB_EVENT_RELEASE_TAG_NAME}" >> $GITHUB_ENV + id: get_version env: - GITHUB_EVENT_RELEASE_TAG_NAME: ${{ github.event.release.tag_name }} + RELEASE_TAG: ${{ github.event.release.tag_name }} + run: echo "VERSION=${RELEASE_TAG}" >> "$GITHUB_OUTPUT" - name: Setup Hugo uses: peaceiris/actions-hugo@75d2e84710de30f6ff7268e08f310b60ef14033f # v3 @@ -58,7 +59,7 @@ jobs: run: hugo --minify working-directory: .hugo env: - HUGO_BASEURL: https://${{ github.repository_owner }}.github.io/${{ github.event.repository.name }}/${{ env.VERSION }}/ + HUGO_BASEURL: https://${{ github.repository_owner }}.github.io/${{ github.event.repository.name }}/${{ steps.get_version.outputs.VERSION }}/ HUGO_RELATIVEURLS: false - name: Deploy @@ -67,9 +68,9 @@ jobs: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: .hugo/public publish_branch: versioned-gh-pages - destination_dir: ./${{ env.VERSION }} + destination_dir: ./${{ steps.get_version.outputs.VERSION }} keep_files: true - commit_message: "deploy: docs for ${{ env.VERSION }}" + commit_message: "deploy: docs for ${{ steps.get_version.outputs.VERSION }}" - name: Clean Build Directory run: rm -rf .hugo/public @@ -89,4 +90,4 @@ jobs: publish_branch: versioned-gh-pages keep_files: true allow_empty_commit: true - commit_message: "deploy: docs to root for ${{ env.VERSION }}" + commit_message: "deploy: docs to root for ${{ steps.get_version.outputs.VERSION }}" diff --git a/.github/workflows/link_checker.yaml b/.github/workflows/link_checker.yaml index ef835d8b6d78..b535a21fc498 100644 --- a/.github/workflows/link_checker.yaml +++ b/.github/workflows/link_checker.yaml @@ -20,7 +20,6 @@ permissions: pull-requests: write issues: write - jobs: link-check: runs-on: ubuntu-latest @@ -29,6 +28,7 @@ jobs: uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: fetch-depth: 0 + - name: Identify Changed Files id: changed-files shell: bash @@ -38,25 +38,23 @@ jobs: if [ -z "$CHANGED_FILES" ]; then echo "No markdown files changed. Skipping checks." - echo "HAS_CHANGES=false" >> $GITHUB_ENV + echo "HAS_CHANGES=false" >> "$GITHUB_OUTPUT" else echo "--- Changed Files to Scan ---" echo "$CHANGED_FILES" echo "-----------------------------" - # FIX: Wrap filenames in quotes to handle spaces FILES_QUOTED=$(echo "$CHANGED_FILES" | sed 's/^/"/;s/$/"/' | tr '\n' ' ') - # Write to env using EOF pattern - echo "CHECK_FILES<> $GITHUB_ENV - echo "$FILES_QUOTED" >> $GITHUB_ENV - echo "EOF" >> $GITHUB_ENV - echo "HAS_CHANGES=true" >> $GITHUB_ENV + # Use EOF to write multiline or long strings to GITHUB_OUTPUT + echo "HAS_CHANGES=true" >> "$GITHUB_OUTPUT" + echo "CHECK_FILES<> "$GITHUB_OUTPUT" + echo "$FILES_QUOTED" >> "$GITHUB_OUTPUT" + echo "EOF" >> "$GITHUB_OUTPUT" fi - - name: Restore lychee cache - if: env.HAS_CHANGES == 'true' + if: steps.changed-files.outputs.HAS_CHANGES == 'true' uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5 with: path: .lycheecache @@ -65,7 +63,7 @@ jobs: - name: Link Checker id: lychee-check - if: env.HAS_CHANGES == 'true' + if: steps.changed-files.outputs.HAS_CHANGES == 'true' uses: lycheeverse/lychee-action@a8c4c7cb88f0c7386610c35eb25108e448569cb0 # v2 continue-on-error: true with: @@ -75,7 +73,7 @@ jobs: --cache --max-cache-age 1d --exclude '^neo4j\+.*' --exclude '^bolt://.*' - ${{ env.CHECK_FILES }} + ${{ steps.changed-files.outputs.CHECK_FILES }} output: lychee-report.md format: markdown fail: true @@ -83,6 +81,7 @@ jobs: debug: false env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Find comment uses: peter-evans/find-comment@b30e6a3c0ed37e7c023ccd3f1db5c6c0b0c23aad # v4 id: find-comment @@ -103,18 +102,16 @@ jobs: STEPS_FIND_COMMENT_OUTPUTS_COMMENT_ID: ${{ steps.find-comment.outputs.comment-id }} - name: Prepare Report - if: env.HAS_CHANGES == 'true' && steps.lychee-check.outcome == 'failure' + if: steps.changed-files.outputs.HAS_CHANGES == 'true' && steps.lychee-check.outcome == 'failure' run: | echo "## Link Resolution Note" > full-report.md - - - echo "Local links and directory changes work differently on GitHub than on the docsite.You must ensure fixes pass the **GitHub check** and also work with **\`hugo server\`**." >> full-report.md + echo "Local links and directory changes work differently on GitHub than on the docsite. You must ensure fixes pass the **GitHub check** and also work with **\`hugo server\`**." >> full-report.md echo "See [Link Checking and Fixing with Lychee](https://github.com/googleapis/genai-toolbox/blob/main/DEVELOPER.md#link-checking-and-fixing-with-lychee) for more details." >> full-report.md echo "" >> full-report.md sed -E '/(Redirect|Redirects per input)/d' lychee-report.md >> full-report.md - name: Create PR Comment - if: env.HAS_CHANGES == 'true' && steps.lychee-check.outcome == 'failure' + if: steps.changed-files.outputs.HAS_CHANGES == 'true' && steps.lychee-check.outcome == 'failure' uses: peter-evans/create-or-update-comment@e8674b075228eee787fea43ef493e45ece1004c9 # v5 with: comment-id: ${{ steps.find-comment.outputs.comment-id }} diff --git a/.hugo/hugo.toml b/.hugo/hugo.toml index 76e253dc0a32..d56e5478d1cf 100644 --- a/.hugo/hugo.toml +++ b/.hugo/hugo.toml @@ -51,6 +51,10 @@ ignoreFiles = ["quickstart/shared", "quickstart/python", "quickstart/js", "quick # Add a new version block here before every release # The order of versions in this file is mirrored into the dropdown +[[params.versions]] + version = "v0.28.0" + url = "https://googleapis.github.io/genai-toolbox/v0.28.0/" + [[params.versions]] version = "v0.27.0" url = "https://googleapis.github.io/genai-toolbox/v0.27.0/" diff --git a/cmd/internal/imports.go b/cmd/internal/imports.go index a0dac175e400..a3f0640ae861 100644 --- a/cmd/internal/imports.go +++ b/cmd/internal/imports.go @@ -107,6 +107,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectdirectory" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateviewfromtable" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectdirectory" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdevmode" @@ -120,6 +121,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetfilters" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetlookmltests" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetlooks" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmeasures" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmodels" @@ -138,6 +140,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerqueryurl" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrundashboard" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrunlook" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrunlookmltests" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerupdateprojectfile" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookervalidateproject" _ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbexecutesql" diff --git a/cmd/internal/tools_file_test.go b/cmd/internal/tools_file_test.go index 31e61392402b..5e445e2935db 100644 --- a/cmd/internal/tools_file_test.go +++ b/cmd/internal/tools_file_test.go @@ -1778,7 +1778,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "looker_tools": tools.ToolsetConfig{ Name: "looker_tools", - ToolNames: []string{"get_models", "get_explores", "get_dimensions", "get_measures", "get_filters", "get_parameters", "query", "query_sql", "query_url", "get_looks", "run_look", "make_look", "get_dashboards", "run_dashboard", "make_dashboard", "add_dashboard_element", "add_dashboard_filter", "generate_embed_url", "health_pulse", "health_analyze", "health_vacuum", "dev_mode", "get_projects", "get_project_files", "get_project_file", "create_project_file", "update_project_file", "delete_project_file", "get_project_directories", "create_project_directory", "delete_project_directory", "validate_project", "get_connections", "get_connection_schemas", "get_connection_databases", "get_connection_tables", "get_connection_table_columns"}, + ToolNames: []string{"get_models", "get_explores", "get_dimensions", "get_measures", "get_filters", "get_parameters", "query", "query_sql", "query_url", "get_looks", "run_look", "make_look", "get_dashboards", "run_dashboard", "make_dashboard", "add_dashboard_element", "add_dashboard_filter", "generate_embed_url", "health_pulse", "health_analyze", "health_vacuum", "dev_mode", "get_projects", "get_project_files", "get_project_file", "create_project_file", "update_project_file", "delete_project_file", "get_project_directories", "create_project_directory", "delete_project_directory", "validate_project", "get_connections", "get_connection_schemas", "get_connection_databases", "get_connection_tables", "get_connection_table_columns", "get_lookml_tests", "run_lookml_tests", "create_view_from_table"}, }, }, }, diff --git a/docs/en/reference/prebuilt-tools.md b/docs/en/reference/prebuilt-tools.md index 5330d1735d48..4e36ebfa2c41 100644 --- a/docs/en/reference/prebuilt-tools.md +++ b/docs/en/reference/prebuilt-tools.md @@ -552,6 +552,9 @@ See [Usage Examples](../reference/cli.md#examples). * `get_connection_databases`: Get the available databases in a connection. * `get_connection_tables`: Get the available tables in a connection. * `get_connection_table_columns`: Get the available columns for a table. + * `get_lookml_tests`: Retrieves a list of available LookML tests for a project. + * `run_lookml_tests`: Executes specific LookML tests within a project. + * `create_view_from_table`: Generates boilerplate LookML views directly from the database schema. ## Looker Conversational Analytics diff --git a/docs/en/resources/tools/looker/looker-create-view-from-table.md b/docs/en/resources/tools/looker/looker-create-view-from-table.md new file mode 100644 index 000000000000..d40117826e89 --- /dev/null +++ b/docs/en/resources/tools/looker/looker-create-view-from-table.md @@ -0,0 +1,50 @@ +--- +title: "looker-create-view-from-table" +type: docs +weight: 1 +description: > + This tool generates boilerplate LookML views directly from the database schema. +aliases: +- /resources/tools/looker-create-view-from-table +--- + +## About + +A "looker-create-view-from-table" tool triggers the automatic generation of LookML view files based on database tables. + +It's compatible with the following sources: + +- [looker](../../sources/looker.md) + +`looker-create-view-from-table` accepts project_id, connection, tables, and folder_name parameters. + +## Example + +```yaml +tools: + create_view_from_table: + kind: looker-create-view-from-table + source: looker-source + description: | + This tool generates boilerplate LookML views directly from the database schema. + It does not create model or explore files, only view files in the specified folder. + + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. + + Parameters: + - project_id (required): The unique ID of the LookML project. + - connection (required): The database connection name. + - tables (required): A list of objects to generate views for. Each object must contain `schema` and `table_name` (note: table names are case-sensitive). Optional fields include `primary_key`, `base_view`, and `columns` (array of objects with `column_name`). + - folder_name (optional): The folder to place the view files in (defaults to 'views/'). + + Output: + A confirmation message upon successful view generation, or an error message if the operation fails. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:--------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "looker-create-view-from-table". | +| source | string | true | Name of the source Looker instance. | +| description | string | true | Description of the tool that is passed to the LLM. | diff --git a/docs/en/resources/tools/looker/looker-get-lookml-tests.md b/docs/en/resources/tools/looker/looker-get-lookml-tests.md new file mode 100644 index 000000000000..7148034ef2d1 --- /dev/null +++ b/docs/en/resources/tools/looker/looker-get-lookml-tests.md @@ -0,0 +1,53 @@ +--- +title: "looker-get-lookml-tests" +type: docs +weight: 1 +description: > + Returns a list of tests which can be run to validate a project's LookML code and/or the underlying data, optionally filtered by the file id. +aliases: +- /resources/tools/looker-get-lookml-tests +--- + +## About + +A "looker-get-lookml-tests" tool retrieves a list of available LookML tests for a project. + +It's compatible with the following sources: + +- [looker](../../sources/looker.md) + +`looker-get-lookml-tests` accepts project_id and file_id parameters. + +## Example + +```yaml +tools: + get_lookml_tests: + kind: looker-get-lookml-tests + source: looker-source + description: | + Returns a list of tests which can be run to validate a project's LookML code and/or the underlying data, optionally filtered by the file id. + + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. + + Parameters: + - project_id (required): The unique ID of the LookML project. + - file_id (optional): The ID of the file to filter tests by. This must be the complete file path from the project root (e.g., `models/my_model.model.lkml` or `views/my_view.view.lkml`). + + Output: + A JSON array of LookML test objects, each containing: + - model_name: The name of the model. + - name: The name of the test. + - explore_name: The name of the explore being tested. + - query_url_params: The query parameters used for the test. + - file: The file path where the test is defined. + - line: The line number where the test is defined. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:--------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "looker-get-lookml-tests". | +| source | string | true | Name of the source Looker instance. | +| description | string | true | Description of the tool that is passed to the LLM. | diff --git a/docs/en/resources/tools/looker/looker-run-lookml-tests.md b/docs/en/resources/tools/looker/looker-run-lookml-tests.md new file mode 100644 index 000000000000..94c80c433b87 --- /dev/null +++ b/docs/en/resources/tools/looker/looker-run-lookml-tests.md @@ -0,0 +1,56 @@ +--- +title: "looker-run-lookml-tests" +type: docs +weight: 1 +description: > + This tool runs LookML tests in the project, filtered by file, test, and/or model. +aliases: +- /resources/tools/looker-run-lookml-tests +--- + +## About + +A "looker-run-lookml-tests" tool executes specific LookML tests within a project. + +It's compatible with the following sources: + +- [looker](../../sources/looker.md) + +`looker-run-lookml-tests` accepts project_id, file_id, test, and model parameters. + +## Example + +```yaml +tools: + run_lookml_tests: + kind: looker-run-lookml-tests + source: looker-source + description: | + This tool runs LookML tests in the project, filtered by file, test, and/or model. These filters work in conjunction (logical AND). + + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. + + Parameters: + - project_id (required): The unique ID of the project to run LookML tests for. + - file_id (optional): The ID of the file to run tests for. This must be the complete file path from the project root (e.g., `models/my_model.model.lkml` or `views/my_view.view.lkml`). + - test (optional): The name of the test to run. + - model (optional): The name of the model to run tests for. + + Output: + A JSON array containing the results of the executed tests, where each object includes: + - model_name: Name of the model tested. + - test_name: Name of the test. + - assertions_count: Total number of assertions in the test. + - assertions_failed: Number of assertions that failed. + - success: Boolean indicating if the test passed. + - errors: Array of error objects (if any), containing details like `message`, `file_path`, `line_number`, and `severity`. + - warnings: Array of warning messages (if any). +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:--------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "looker-run-lookml-tests". | +| source | string | true | Name of the source Looker instance. | +| description | string | true | Description of the tool that is passed to the LLM. | diff --git a/go.mod b/go.mod index 9df36dbc2071..1c84e0c11092 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( cloud.google.com/go/dataplex v1.28.0 cloud.google.com/go/dataproc/v2 v2.15.0 cloud.google.com/go/firestore v1.20.0 - cloud.google.com/go/geminidataanalytics v0.3.0 + cloud.google.com/go/geminidataanalytics v0.5.0 cloud.google.com/go/logging v1.13.1 cloud.google.com/go/longrunning v0.7.0 cloud.google.com/go/spanner v1.86.1 diff --git a/go.sum b/go.sum index b1676bac848a..622acb84f943 100644 --- a/go.sum +++ b/go.sum @@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2 cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w= cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM= cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0= -cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI= -cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg= +cloud.google.com/go/geminidataanalytics v0.5.0 h1:+1usY81Cb+hE8BokpqCM7EgJtRCKzUKx7FvrHbT5hCA= +cloud.google.com/go/geminidataanalytics v0.5.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg= cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60= cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo= cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg= diff --git a/internal/prebuiltconfigs/tools/looker.yaml b/internal/prebuiltconfigs/tools/looker.yaml index 2f3f62f670eb..b751de822a3d 100644 --- a/internal/prebuiltconfigs/tools/looker.yaml +++ b/internal/prebuiltconfigs/tools/looker.yaml @@ -1098,6 +1098,68 @@ tools: A JSON array of objects, where each object represents a column and contains details such as `table_name`, `column_name`, `data_type`, and `is_nullable`. + get_lookml_tests: + kind: looker-get-lookml-tests + source: looker-source + description: | + Returns a list of tests which can be run to validate a project's LookML code and/or the underlying data, optionally filtered by the file id. + + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. + + Parameters: + - project_id (required): The unique ID of the LookML project. + - file_id (optional): The ID of the file to filter tests by. This must be the complete file path from the project root (e.g., `models/my_model.model.lkml` or `views/my_view.view.lkml`). + + Output: + A JSON array of LookML test objects, each containing: + - model_name: The name of the model. + - name: The name of the test. + - explore_name: The name of the explore being tested. + - query_url_params: The query parameters used for the test. + - file: The file path where the test is defined. + - line: The line number where the test is defined. + + run_lookml_tests: + kind: looker-run-lookml-tests + source: looker-source + description: | + This tool runs LookML tests in the project, filtered by file, test, and/or model. These filters work in conjunction (logical AND). + + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. + + Parameters: + - project_id (required): The unique ID of the project to run LookML tests for. + - file_id (optional): The ID of the file to run tests for. This must be the complete file path from the project root (e.g., `models/my_model.model.lkml` or `views/my_view.view.lkml`). + - test (optional): The name of the test to run. + - model (optional): The name of the model to run tests for. + + Output: + A JSON array containing the results of the executed tests, where each object includes: + - model_name: Name of the model tested. + - test_name: Name of the test. + - assertions_count: Total number of assertions in the test. + - assertions_failed: Number of assertions that failed. + - success: Boolean indicating if the test passed. + - errors: Array of error objects (if any), containing details like `message`, `file_path`, `line_number`, and `severity`. + - warnings: Array of warning messages (if any). + + create_view_from_table: + kind: looker-create-view-from-table + source: looker-source + description: | + This tool generates boilerplate LookML views directly from the database schema. + It does not create model or explore files, only view files in the specified folder. + + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. + + Parameters: + - project_id (required): The unique ID of the LookML project. + - connection (required): The database connection name. + - tables (required): A list of objects to generate views for. Each object must contain `schema` and `table_name` (note: table names are case-sensitive). Optional fields include `primary_key`, `base_view`, and `columns` (array of objects with `column_name`). + - folder_name (optional): The folder to place the view files in (defaults to 'views/'). + + Output: + A confirmation message upon successful view generation, or an error message if the operation fails. toolsets: looker_tools: @@ -1138,3 +1200,6 @@ toolsets: - get_connection_databases - get_connection_tables - get_connection_table_columns + - get_lookml_tests + - run_lookml_tests + - create_view_from_table diff --git a/internal/sources/cloudgda/cloud_gda.go b/internal/sources/cloudgda/cloud_gda.go index 80e8df431c7c..4c977418c6a9 100644 --- a/internal/sources/cloudgda/cloud_gda.go +++ b/internal/sources/cloudgda/cloud_gda.go @@ -14,23 +14,23 @@ package cloudgda import ( - "bytes" "context" - "encoding/json" "fmt" - "io" - "net/http" + geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta" + "cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" "go.opentelemetry.io/otel/trace" "golang.org/x/oauth2" - "golang.org/x/oauth2/google" + "google.golang.org/api/option" ) const SourceType string = "cloud-gemini-data-analytics" -const Endpoint string = "https://geminidataanalytics.googleapis.com" + +// NewDataChatClient can be overridden for testing. +var NewDataChatClient = geminidataanalytics.NewDataChatClient // validate interface var _ sources.SourceConfig = Config{} @@ -67,29 +67,19 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("error in User Agent retrieval: %s", err) } - var client *http.Client - if r.UseClientOAuth { - client = &http.Client{ - Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport), - } - } else { - // Use Application Default Credentials - // Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA - creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform") - if err != nil { - return nil, fmt.Errorf("failed to find default credentials: %w", err) - } - baseClient := oauth2.NewClient(ctx, creds.TokenSource) - baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport) - client = baseClient - } - s := &Source{ Config: r, - Client: client, - BaseURL: Endpoint, userAgent: ua, } + + if !r.UseClientOAuth { + client, err := NewDataChatClient(ctx, option.WithUserAgent(ua)) + if err != nil { + return nil, fmt.Errorf("failed to create DataChatClient: %w", err) + } + s.Client = client + } + return s, nil } @@ -97,8 +87,7 @@ var _ sources.Source = &Source{} type Source struct { Config - Client *http.Client - BaseURL string + Client *geminidataanalytics.DataChatClient userAgent string } @@ -114,63 +103,36 @@ func (s *Source) GetProjectID() string { return s.ProjectID } -func (s *Source) GetBaseURL() string { - return s.BaseURL -} - -func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) { - if s.UseClientOAuth { - if accessToken == "" { - return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided") - } - token := &oauth2.Token{AccessToken: accessToken} - baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)) - baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport) - return baseClient, nil - } - return s.Client, nil -} - func (s *Source) UseClientAuthorization() bool { return s.UseClientOAuth } -func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) { - // The API endpoint itself always uses the "global" location. - apiLocation := "global" - apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation) - apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent) - - client, err := s.GetClient(ctx, tokenStr) - if err != nil { - return nil, fmt.Errorf("failed to get HTTP client: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") +func (s *Source) GetClient(ctx context.Context, tokenStr string) (*geminidataanalytics.DataChatClient, func(), error) { + if s.UseClientOAuth { + if tokenStr == "" { + return nil, nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided") + } + token := &oauth2.Token{AccessToken: tokenStr} + opts := []option.ClientOption{ + option.WithUserAgent(s.userAgent), + option.WithTokenSource(oauth2.StaticTokenSource(token)), + } - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to execute request: %w", err) + client, err := NewDataChatClient(ctx, opts...) + if err != nil { + return nil, nil, fmt.Errorf("failed to create per-request DataChatClient: %w", err) + } + return client, func() { client.Close() }, nil } - defer resp.Body.Close() + return s.Client, func() {}, nil +} - respBody, err := io.ReadAll(resp.Body) +func (s *Source) RunQuery(ctx context.Context, tokenStr string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) { + client, cleanup, err := s.GetClient(ctx, tokenStr) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody)) - } - - var result map[string]any - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) + return nil, err } + defer cleanup() - return result, nil + return client.QueryData(ctx, req) } diff --git a/internal/sources/cloudgda/cloud_gda_test.go b/internal/sources/cloudgda/cloud_gda_test.go index 6ec771f60120..b081d84753c3 100644 --- a/internal/sources/cloudgda/cloud_gda_test.go +++ b/internal/sources/cloudgda/cloud_gda_test.go @@ -172,11 +172,9 @@ func TestInitialize(t *testing.T) { if gdaSrc.Client == nil && !tc.wantClientOAuth { t.Fatal("expected non-nil HTTP client for ADC, got nil") } - // When client OAuth is true, the source's client should be initialized with a base HTTP client - // that includes the user agent round tripper, but not the OAuth token. The token-aware - // client is created by GetClient. - if gdaSrc.Client == nil && tc.wantClientOAuth { - t.Fatal("expected non-nil HTTP client for client OAuth config, got nil") + // When client OAuth is true, the source's client should be nil. + if gdaSrc.Client != nil && tc.wantClientOAuth { + t.Fatal("expected nil HTTP client for client OAuth config, got non-nil") } // Test UseClientAuthorization method @@ -186,15 +184,16 @@ func TestInitialize(t *testing.T) { // Test GetClient with accessToken for client OAuth scenarios if tc.wantClientOAuth { - client, err := gdaSrc.GetClient(ctx, "dummy-token") + client, cleanup, err := gdaSrc.GetClient(ctx, "dummy-token") if err != nil { t.Fatalf("GetClient with token failed: %v", err) } + defer cleanup() if client == nil { t.Fatal("expected non-nil HTTP client from GetClient with token, got nil") } // Ensure passing empty token with UseClientOAuth enabled returns error - _, err = gdaSrc.GetClient(ctx, "") + _, _, err = gdaSrc.GetClient(ctx, "") if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" { t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err) } diff --git a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go index 54d29d160582..226cca46039d 100644 --- a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go +++ b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go @@ -482,30 +482,33 @@ func handleTextResponse(resp *TextResponse) map[string]any { } func handleSchemaResponse(resp *SchemaResponse) map[string]any { + res := make(map[string]any) if resp.Query != nil { - return map[string]any{"Question": resp.Query.Question} + res["Question"] = resp.Query.Question } if resp.Result != nil { var formattedSources []map[string]any for _, ds := range resp.Result.Datasources { formattedSources = append(formattedSources, formatDatasourceAsDict(&ds)) } - return map[string]any{"Schema Resolved": formattedSources} + res["Schema Resolved"] = formattedSources } - return nil + if len(res) == 0 { + return nil + } + return res } func handleDataResponse(resp *DataResponse, maxRows int) map[string]any { + res := make(map[string]any) if resp.Query != nil { - return map[string]any{ - "Retrieval Query": map[string]any{ - "Query Name": resp.Query.Name, - "Question": resp.Query.Question, - }, + res["Retrieval Query"] = map[string]any{ + "Query Name": resp.Query.Name, + "Question": resp.Query.Question, } } if resp.GeneratedSQL != "" { - return map[string]any{"SQL Generated": resp.GeneratedSQL} + res["SQL Generated"] = resp.GeneratedSQL } if resp.Result != nil { var headers []string @@ -533,15 +536,16 @@ func handleDataResponse(resp *DataResponse, maxRows int) map[string]any { summary = fmt.Sprintf("Showing the first %d of %d total rows.", numRowsToDisplay, totalRows) } - return map[string]any{ - "Data Retrieved": map[string]any{ - "headers": headers, - "rows": compactRows, - "summary": summary, - }, + res["Data Retrieved"] = map[string]any{ + "headers": headers, + "rows": compactRows, + "summary": summary, } } - return nil + if len(res) == 0 { + return nil + } + return res } func handleError(resp *ErrorResponse) map[string]any { @@ -557,9 +561,17 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s if newMessage == nil { return messages } - if len(messages) > 0 { - if _, ok := messages[len(messages)-1]["Data Retrieved"]; ok { - messages = messages[:len(messages)-1] + + if _, hasData := newMessage["Data Retrieved"]; hasData { + // Only keep the last data result while preserving SQL and other metadata. + for i := len(messages) - 1; i >= 0; i-- { + if _, ok := messages[i]["Data Retrieved"]; ok { + delete(messages[i], "Data Retrieved") + if len(messages[i]) == 0 { + messages = append(messages[:i], messages[i+1:]...) + } + break + } } } return append(messages, newMessage) diff --git a/internal/tools/cloudgda/cloudgda.go b/internal/tools/cloudgda/cloudgda.go index 14862909b43f..be351a1e1ca3 100644 --- a/internal/tools/cloudgda/cloudgda.go +++ b/internal/tools/cloudgda/cloudgda.go @@ -20,12 +20,14 @@ import ( "fmt" "net/http" + "cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/protobuf/encoding/protojson" ) const resourceType string = "cloud-gemini-data-analytics-query" @@ -62,7 +64,49 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetProjectID() string UseClientAuthorization() bool - RunQuery(context.Context, string, []byte) (any, error) + RunQuery(context.Context, string, *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) +} + +// QueryDataContext wraps geminidataanalyticspb.QueryDataContext to support YAML decoding via protojson. +type QueryDataContext struct { + *geminidataanalyticspb.QueryDataContext +} + +func (q *QueryDataContext) UnmarshalYAML(b []byte) error { + var raw map[string]any + if err := yaml.Unmarshal(b, &raw); err != nil { + return fmt.Errorf("failed to unmarshal context from yaml: %w", err) + } + jsonBytes, err := json.Marshal(raw) + if err != nil { + return fmt.Errorf("failed to marshal context map: %w", err) + } + q.QueryDataContext = &geminidataanalyticspb.QueryDataContext{} + if err := protojson.Unmarshal(jsonBytes, q.QueryDataContext); err != nil { + return fmt.Errorf("failed to unmarshal context to proto: %w", err) + } + return nil +} + +// GenerationOptions wraps geminidataanalyticspb.GenerationOptions to support YAML decoding via protojson. +type GenerationOptions struct { + *geminidataanalyticspb.GenerationOptions +} + +func (g *GenerationOptions) UnmarshalYAML(b []byte) error { + var raw map[string]any + if err := yaml.Unmarshal(b, &raw); err != nil { + return fmt.Errorf("failed to unmarshal generation options from yaml: %w", err) + } + jsonBytes, err := json.Marshal(raw) + if err != nil { + return fmt.Errorf("failed to marshal generation options map: %w", err) + } + g.GenerationOptions = &geminidataanalyticspb.GenerationOptions{} + if err := protojson.Unmarshal(jsonBytes, g.GenerationOptions); err != nil { + return fmt.Errorf("failed to unmarshal generation options to proto: %w", err) + } + return nil } type Config struct { @@ -99,12 +143,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) - return Tool{ + t := Tool{ Config: cfg, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, - }, nil + } + + return t, nil } // validate interface @@ -146,19 +192,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // The parent in the request payload uses the tool's configured location. payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location) - payload := &QueryDataRequest{ - Parent: payloadParent, - Prompt: query, - Context: t.Context, - GenerationOptions: t.GenerationOptions, + req := &geminidataanalyticspb.QueryDataRequest{ + Parent: payloadParent, + Prompt: query, } - bodyBytes, err := json.Marshal(payload) - if err != nil { - return nil, util.NewClientServerError("failed to marshal request payload", http.StatusInternalServerError, err) + if t.Context != nil { + req.Context = t.Context.QueryDataContext + } + + if t.GenerationOptions != nil { + req.GenerationOptions = t.GenerationOptions.GenerationOptions } - resp, err := source.RunQuery(ctx, tokenStr, bodyBytes) + resp, err := source.RunQuery(ctx, tokenStr, req) if err != nil { return nil, util.ProcessGcpError(err) } diff --git a/internal/tools/cloudgda/cloudgda_test.go b/internal/tools/cloudgda/cloudgda_test.go index d5e73658ea3c..73b29ccf174b 100644 --- a/internal/tools/cloudgda/cloudgda_test.go +++ b/internal/tools/cloudgda/cloudgda_test.go @@ -16,18 +16,15 @@ package cloudgda_test import ( "context" - "encoding/json" "fmt" - "io" - "net/http" - "net/http/httptest" "testing" + "cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/sources" - cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/tools" cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" @@ -77,23 +74,29 @@ func TestParseFromYaml(t *testing.T) { Location: "us-central1", AuthRequired: []string{}, Context: &cloudgdatool.QueryDataContext{ - DatasourceReferences: &cloudgdatool.DatasourceReferences{ - SpannerReference: &cloudgdatool.SpannerReference{ - DatabaseReference: &cloudgdatool.SpannerDatabaseReference{ - ProjectID: "cloud-db-nl2sql", - Region: "us-central1", - InstanceID: "evalbench", - DatabaseID: "financial", - Engine: cloudgdatool.SpannerEngineGoogleSQL, - }, - AgentContextReference: &cloudgdatool.AgentContextReference{ - ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + QueryDataContext: &geminidataanalyticspb.QueryDataContext{ + DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{ + References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{ + SpannerReference: &geminidataanalyticspb.SpannerReference{ + DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{ + ProjectId: "cloud-db-nl2sql", + Region: "us-central1", + InstanceId: "evalbench", + DatabaseId: "financial", + Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL, + }, + AgentContextReference: &geminidataanalyticspb.AgentContextReference{ + ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + }, + }, }, }, }, }, GenerationOptions: &cloudgdatool.GenerationOptions{ - GenerateQueryResult: true, + GenerationOptions: &geminidataanalyticspb.GenerationOptions{ + GenerateQueryResult: true, + }, }, }, }, @@ -107,68 +110,63 @@ func TestParseFromYaml(t *testing.T) { if err != nil { t.Fatalf("unable to unmarshal: %s", err) } - if !cmp.Equal(tc.want, got) { + if !cmp.Equal(tc.want, got, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataContext{}, geminidataanalyticspb.DatasourceReferences{}, geminidataanalyticspb.SpannerReference{}, geminidataanalyticspb.SpannerDatabaseReference{}, geminidataanalyticspb.AgentContextReference{}, geminidataanalyticspb.GenerationOptions{}, geminidataanalyticspb.DatasourceReferences_SpannerReference{})) { t.Fatalf("incorrect parse: want %v, got %v", tc.want, got) } }) } } -// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header. -type authRoundTripper struct { - Token string - Next http.RoundTripper +// fakeSource implements the compatibleSource interface for testing. +type fakeSource struct { + projectID string + useClientOAuth bool + expectedQuery string + expectedParent string + response *geminidataanalyticspb.QueryDataResponse } -func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - newReq := *req - newReq.Header = make(http.Header) - for k, v := range req.Header { - newReq.Header[k] = v - } - newReq.Header.Set("Authorization", rt.Token) - if rt.Next == nil { - return http.DefaultTransport.RoundTrip(&newReq) - } - return rt.Next.RoundTrip(&newReq) +func (f *fakeSource) GetProjectID() string { + return f.projectID +} + +func (f *fakeSource) UseClientAuthorization() bool { + return f.useClientOAuth +} + +func (f *fakeSource) SourceType() string { + return "cloud-gemini-data-analytics" } -type mockSource struct { - Type string - client *http.Client // Can be used to inject a specific client - baseURL string // BaseURL is needed to implement sources.Source.BaseURL - config cloudgdasrc.Config // to return from ToConfig +func (f *fakeSource) ToConfig() sources.SourceConfig { + return nil } -func (m *mockSource) SourceType() string { return m.Type } -func (m *mockSource) ToConfig() sources.SourceConfig { return m.config } -func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) { - if m.client != nil { - return m.client, nil +func (f *fakeSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) { + return f, nil +} + +func (f *fakeSource) RunQuery(ctx context.Context, token string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) { + if req.Prompt != f.expectedQuery { + return nil, fmt.Errorf("unexpected query: got %q, want %q", req.Prompt, f.expectedQuery) } - // Default client for testing if not explicitly set - transport := &http.Transport{} - authTransport := &authRoundTripper{ - Token: "Bearer test-access-token", // Dummy token - Next: transport, + if req.Parent != f.expectedParent { + return nil, fmt.Errorf("unexpected parent: got %q, want %q", req.Parent, f.expectedParent) } - return &http.Client{Transport: authTransport}, nil -} -func (m *mockSource) UseClientAuthorization() bool { return false } -func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) { - return m, nil + // Basic validation of context/options could be added here if needed, + // but the test case mainly checks if they are passed correctly via successful invocation. + + return f.response, nil } -func (m *mockSource) BaseURL() string { return m.baseURL } func TestInitialize(t *testing.T) { t.Parallel() + // Minimal fake source + fake := &fakeSource{projectID: "test-project"} + srcs := map[string]sources.Source{ - "gda-api-source": &cloudgdasrc.Source{ - Config: cloudgdasrc.Config{Name: "gda-api-source", Type: cloudgdasrc.SourceType, ProjectID: "test-project"}, - Client: &http.Client{}, - BaseURL: cloudgdasrc.Endpoint, - }, + "gda-api-source": fake, } tcs := []struct { @@ -187,9 +185,7 @@ func TestInitialize(t *testing.T) { }, } - // Add an incompatible source for testing - srcs["incompatible-source"] = &mockSource{Type: "another-type"} - + // No incompatible source for testing needed with fakeSource for _, tc := range tcs { tc := tc t.Run(tc.desc, func(t *testing.T) { @@ -206,92 +202,27 @@ func TestInitialize(t *testing.T) { func TestInvoke(t *testing.T) { t.Parallel() - // Mock the HTTP client and server for Invoke testing - serverMux := http.NewServeMux() - // Update expected URL path to include the location "us-central1" - serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - t.Errorf("expected POST method, got %s", r.Method) - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - if r.Header.Get("Content-Type") != "application/json" { - t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) - http.Error(w, "Bad request", http.StatusBadRequest) - return - } - - // Read and unmarshal the request body - bodyBytes, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("failed to read request body: %v", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - var reqPayload cloudgdatool.QueryDataRequest - if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil { - t.Errorf("failed to unmarshal request payload: %v", err) - http.Error(w, "Bad request", http.StatusBadRequest) - return - } - // Verify expected fields - if r.Header.Get("Authorization") == "" { - t.Errorf("expected Authorization header, got empty") - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" { - t.Errorf("unexpected prompt: %s", reqPayload.Prompt) - } + projectID := "test-project" + location := "us-central1" + query := "How many accounts who have region in Prague are eligible for loans?" + expectedParent := fmt.Sprintf("projects/%s/locations/%s", projectID, location) - // Verify payload's parent uses the tool's configured location - if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") { - t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1")) - } - - // Verify context from config - if reqPayload.Context == nil || - reqPayload.Context.DatasourceReferences == nil || - reqPayload.Context.DatasourceReferences.SpannerReference == nil || - reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil || - reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" { - t.Errorf("unexpected context: %v", reqPayload.Context) - } - - // Verify generation options from config - if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult { - t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions) - } - - // Simulate a successful response - resp := map[string]any{ - "queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;", - "naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.", - } - _ = json.NewEncoder(w).Encode(resp) - }) - - mockServer := httptest.NewServer(serverMux) - defer mockServer.Close() - - ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent") - - // Create an authenticated client that uses the mock server - authTransport := &authRoundTripper{ - Token: "Bearer test-access-token", - Next: mockServer.Client().Transport, + // Prepare expected response + expectedResp := &geminidataanalyticspb.QueryDataResponse{ + GeneratedQuery: "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;", + NaturalLanguageAnswer: "There are 5 accounts in Prague eligible for loans.", } - authClient := &http.Client{Transport: authTransport} - // Create a real cloudgdasrc.Source but inject the authenticated client - mockGdaSource := &cloudgdasrc.Source{ - Config: cloudgdasrc.Config{Name: "mock-gda-source", Type: cloudgdasrc.SourceType, ProjectID: "test-project"}, - Client: authClient, - BaseURL: mockServer.URL, + fake := &fakeSource{ + projectID: projectID, + expectedQuery: query, + expectedParent: expectedParent, + response: expectedResp, } + srcs := map[string]sources.Source{ - "mock-gda-source": mockGdaSource, + "mock-gda-source": fake, } // Initialize the tool config with context @@ -300,25 +231,31 @@ func TestInvoke(t *testing.T) { Type: "cloud-gemini-data-analytics-query", Source: "mock-gda-source", Description: "Query Gemini Data Analytics", - Location: "us-central1", // Set location for the test + Location: location, Context: &cloudgdatool.QueryDataContext{ - DatasourceReferences: &cloudgdatool.DatasourceReferences{ - SpannerReference: &cloudgdatool.SpannerReference{ - DatabaseReference: &cloudgdatool.SpannerDatabaseReference{ - ProjectID: "cloud-db-nl2sql", - Region: "us-central1", - InstanceID: "evalbench", - DatabaseID: "financial", - Engine: cloudgdatool.SpannerEngineGoogleSQL, - }, - AgentContextReference: &cloudgdatool.AgentContextReference{ - ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + QueryDataContext: &geminidataanalyticspb.QueryDataContext{ + DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{ + References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{ + SpannerReference: &geminidataanalyticspb.SpannerReference{ + DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{ + ProjectId: "cloud-db-nl2sql", + Region: "us-central1", + InstanceId: "evalbench", + DatabaseId: "financial", + Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL, + }, + AgentContextReference: &geminidataanalyticspb.AgentContextReference{ + ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + }, + }, }, }, }, }, GenerationOptions: &cloudgdatool.GenerationOptions{ - GenerateQueryResult: true, + GenerationOptions: &geminidataanalyticspb.GenerationOptions{ + GenerateQueryResult: true, + }, }, } @@ -329,24 +266,25 @@ func TestInvoke(t *testing.T) { // Prepare parameters for invocation - ONLY query params := parameters.ParamValues{ - {Name: "query", Value: "How many accounts who have region in Prague are eligible for loans?"}, + {Name: "query", Value: query}, } resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil) + ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent") + // Invoke the tool - result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client + result, err := tool.Invoke(ctx, resourceMgr, params, "") if err != nil { t.Fatalf("tool invocation failed: %v", err) } - // Validate the result - expectedResult := map[string]any{ - "queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;", - "naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.", + gotResp, ok := result.(*geminidataanalyticspb.QueryDataResponse) + if !ok { + t.Fatalf("expected result type *geminidataanalyticspb.QueryDataResponse, got %T", result) } - if !cmp.Equal(expectedResult, result) { - t.Errorf("unexpected result: got %v, want %v", result, expectedResult) + if diff := cmp.Diff(expectedResp, gotResp, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataResponse{})); diff != "" { + t.Errorf("unexpected result mismatch (-want +got):\n%s", diff) } } diff --git a/internal/tools/cloudgda/types.go b/internal/tools/cloudgda/types.go deleted file mode 100644 index 8e82cb50c226..000000000000 --- a/internal/tools/cloudgda/types.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cloudgda - -// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto - -// QueryDataRequest represents the JSON body for the queryData API -type QueryDataRequest struct { - Parent string `json:"parent"` - Prompt string `json:"prompt"` - Context *QueryDataContext `json:"context,omitempty"` - GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"` -} - -// QueryDataContext reflects the proto definition for the query context. -type QueryDataContext struct { - DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"` -} - -// DatasourceReferences reflects the proto definition for datasource references, using a oneof. -type DatasourceReferences struct { - SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"` - AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"` - CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"` -} - -// SpannerReference reflects the proto definition for Spanner database reference. -type SpannerReference struct { - DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` - AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` -} - -// SpannerDatabaseReference reflects the proto definition for a Spanner database reference. -type SpannerDatabaseReference struct { - Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"` - ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` - Region string `json:"region,omitempty" yaml:"region,omitempty"` - InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` - DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` - TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` -} - -// SpannerEngine represents the engine of the Spanner instance. -type SpannerEngine string - -const ( - SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED" - SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL" - SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL" -) - -// AlloyDBReference reflects the proto definition for an AlloyDB database reference. -type AlloyDBReference struct { - DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` - AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` -} - -// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference. -type AlloyDBDatabaseReference struct { - ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` - Region string `json:"region,omitempty" yaml:"region,omitempty"` - ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"` - InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` - DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` - TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` -} - -// CloudSQLReference reflects the proto definition for a Cloud SQL database reference. -type CloudSQLReference struct { - DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` - AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` -} - -// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference. -type CloudSQLDatabaseReference struct { - Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"` - ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` - Region string `json:"region,omitempty" yaml:"region,omitempty"` - InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` - DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` - TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` -} - -// CloudSQLEngine represents the engine of the Cloud SQL instance. -type CloudSQLEngine string - -const ( - CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED" - CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL" - CloudSQLEngineMySQL CloudSQLEngine = "MYSQL" -) - -// AgentContextReference reflects the proto definition for agent context. -type AgentContextReference struct { - ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"` -} - -// GenerationOptions reflects the proto definition for generation options. -type GenerationOptions struct { - GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"` - GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"` - GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"` - GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"` -} diff --git a/internal/tools/looker/lookercommon/lookercommon.go b/internal/tools/looker/lookercommon/lookercommon.go index 1749a459bd0d..683b82c01cd9 100644 --- a/internal/tools/looker/lookercommon/lookercommon.go +++ b/internal/tools/looker/lookercommon/lookercommon.go @@ -330,3 +330,45 @@ func DeleteProjectDirectory(l *v4.LookerSDK, projectId string, directoryPath str path := fmt.Sprintf("/projects/%s/directories", url.PathEscape(projectId)) return l.AuthSession.Do(&result, "DELETE", "/4.0", path, query, nil, options) } + +type ProjectGeneratorColumn struct { + ColumnName string `json:"column_name"` +} + +type ProjectGeneratorTable struct { + Schema string `json:"schema"` + TableName string `json:"table_name"` + PrimaryKey *string `json:"primary_key,omitempty"` + BaseView *bool `json:"base_view,omitempty"` + Columns []ProjectGeneratorColumn `json:"columns,omitempty"` +} + +type ProjectGeneratorRequestBody struct { + Tables []ProjectGeneratorTable `json:"tables"` +} + +type ProjectGeneratorQueryParams struct { + Connection string `json:"connection"` + FileTypeForExplores string `json:"file_type_for_explores"` + FolderName string `json:"folder_name,omitempty"` +} + +func CreateViewsFromTables(ctx context.Context, l *v4.LookerSDK, projectId string, queryParams ProjectGeneratorQueryParams, reqBody ProjectGeneratorRequestBody, options *rtl.ApiSettings) error { + path := fmt.Sprintf("/projects/%s/generate", url.PathEscape(projectId)) + + // Construct query parameter map + query := map[string]any{ + "connection": queryParams.Connection, + "file_type_for_explores": queryParams.FileTypeForExplores, + "folder_name": queryParams.FolderName, + } + + // Pass the Tables slice directly as the body, not the wrapped struct. + // The API spec defines `tables` as `body_param ... array: true`, + // which means the body itself should be the array. + err := l.AuthSession.Do(nil, "POST", "/4.0", path, query, reqBody.Tables, options) + + logger, _ := util.LoggerFromContext(ctx) + logger.DebugContext(ctx, fmt.Sprintf("generating views with request: query=%v body=%v error=%v", query, reqBody.Tables, err)) + return err +} diff --git a/internal/tools/looker/lookercreateviewfromtable/lookercreateviewfromtable.go b/internal/tools/looker/lookercreateviewfromtable/lookercreateviewfromtable.go new file mode 100644 index 000000000000..08c6d75c15b1 --- /dev/null +++ b/internal/tools/looker/lookercreateviewfromtable/lookercreateviewfromtable.go @@ -0,0 +1,272 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package lookercreateviewfromtable + +import ( + "context" + "fmt" + "net/http" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + + "github.com/looker-open-source/sdk-codegen/go/rtl" + v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4" +) + +const resourceType string = "looker-create-view-from-table" + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerApiSettings() *rtl.ApiSettings + GetLookerSDK(string) (*v4.LookerSDK, error) +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` + Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project to create the view in.") + connectionParameter := parameters.NewStringParameter("connection", "The database connection name.") + + tableDef := parameters.NewMapParameter("table", "Table definition.", "") + tablesParameter := parameters.NewArrayParameter("tables", `The tables to generate views for. + Each item must be a map with: + - schema (string, required) + - table_name (string, required) + - primary_key (string, optional) + - base_view (boolean, optional) + - columns (array of objects, optional): Each object must have 'column_name' (string).`, tableDef) + + folderNameParameter := parameters.NewStringParameterWithDefault("folder_name", "views", "The folder to place the view files in (e.g., 'views').") + + params := parameters.Parameters{projectIdParameter, connectionParameter, tablesParameter, folderNameParameter} + + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := false + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) + + // finish tool setup + return Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{ + Description: cfg.Description, + Parameters: params.Manifest(), + AuthRequired: cfg.AuthRequired, + }, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, util.NewClientServerError(fmt.Sprintf("error getting logger from context: %s", err), http.StatusInternalServerError, err) + } + + sdk, err := source.GetLookerSDK(string(accessToken)) + if err != nil { + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) + } + + mapParams := params.AsMap() + projectId, ok := mapParams["project_id"].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("'project_id' must be a string, got %T", mapParams["project_id"]), nil) + } + connection, ok := mapParams["connection"].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("'connection' must be a string, got %T", mapParams["connection"]), nil) + } + folderName, ok := mapParams["folder_name"].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("'folder_name' must be a string, got %T", mapParams["folder_name"]), nil) + } + + tablesSlice, ok := mapParams["tables"].([]any) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("'tables' must be an array, got %T", mapParams["tables"]), nil) + } + + logger.DebugContext(ctx, "generating views with request", "tables", tablesSlice) + + var generatorTables []lookercommon.ProjectGeneratorTable + for _, tRaw := range tablesSlice { + t, ok := tRaw.(map[string]any) + if !ok { + return nil, util.NewClientServerError(fmt.Sprintf("expected map in tables list, got %T", tRaw), http.StatusInternalServerError, nil) + } + + var schema, tableName string + var primaryKey *string + var baseView *bool + var columns []lookercommon.ProjectGeneratorColumn + + if s, ok := t["schema"].(string); ok { + schema = s + } + if tn, ok := t["table_name"].(string); ok { + tableName = tn + } + // Enforce required fields for map input + if schema == "" || tableName == "" { + return nil, util.NewClientServerError("schema and table_name are required in table map", http.StatusInternalServerError, nil) + } + + if pk, ok := t["primary_key"].(string); ok { + primaryKey = &pk + } + if bv, ok := t["base_view"].(bool); ok { + baseView = &bv + } + if colsRaw, ok := t["columns"].([]any); ok { + for _, cRaw := range colsRaw { + if cMap, ok := cRaw.(map[string]any); ok { + if cName, ok := cMap["column_name"].(string); ok { + columns = append(columns, lookercommon.ProjectGeneratorColumn{ColumnName: cName}) + } + } + } + } + + if tableName == "" { + continue // Skip invalid entries + } + + generatorTables = append(generatorTables, lookercommon.ProjectGeneratorTable{ + Schema: schema, + TableName: tableName, + PrimaryKey: primaryKey, + BaseView: baseView, + Columns: columns, + }) + } + + queryParams := lookercommon.ProjectGeneratorQueryParams{ + Connection: connection, + FileTypeForExplores: "none", + FolderName: folderName, + } + + reqBody := lookercommon.ProjectGeneratorRequestBody{ + Tables: generatorTables, + } + + logger.DebugContext(ctx, "generating views with request", "query", queryParams, "body", reqBody) + + err = lookercommon.CreateViewsFromTables(ctx, sdk, projectId, queryParams, reqBody, source.LookerApiSettings()) + if err != nil { + return nil, util.NewClientServerError(fmt.Sprintf("error generating views: %s", err), http.StatusInternalServerError, err) + } + + return map[string]string{ + "status": "success", + "message": fmt.Sprintf("Triggered view generation for project %s in folder %s", projectId, folderName), + }, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.Parameters, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.Parameters +} diff --git a/internal/tools/looker/lookercreateviewfromtable/lookercreateviewfromtable_test.go b/internal/tools/looker/lookercreateviewfromtable/lookercreateviewfromtable_test.go new file mode 100644 index 000000000000..265345442f32 --- /dev/null +++ b/internal/tools/looker/lookercreateviewfromtable/lookercreateviewfromtable_test.go @@ -0,0 +1,109 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lookercreateviewfromtable_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + lkr "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateviewfromtable" +) + +func TestParseFromYamlLookerCreateViewFromTable(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tools + name: example_tool + type: looker-create-view-from-table + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": lkr.Config{ + Name: "example_tool", + Type: "looker-create-view-from-table", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} + +func TestFailParseFromYamlLookerCreateViewFromTable(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "Invalid method", + in: ` + kind: tools + name: example_tool + type: looker-create-view-from-table + source: my-instance + method: GOT + description: some description + `, + err: "error unmarshaling tools: unable to parse tool \"example_tool\" as type \"looker-create-view-from-table\": [3:1] unknown field \"method\"", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, _, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if !strings.Contains(errStr, tc.err) { + t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err) + } + }) + } + +} diff --git a/internal/tools/looker/lookergetlookmltests/lookergetlookmltests.go b/internal/tools/looker/lookergetlookmltests/lookergetlookmltests.go new file mode 100644 index 000000000000..d1e9df65700d --- /dev/null +++ b/internal/tools/looker/lookergetlookmltests/lookergetlookmltests.go @@ -0,0 +1,178 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lookergetlookmltests + +import ( + "context" + "fmt" + "net/http" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + + "github.com/looker-open-source/sdk-codegen/go/rtl" + v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4" +) + +const resourceType string = "looker-get-lookml-tests" + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerApiSettings() *rtl.ApiSettings + GetLookerSDK(string) (*v4.LookerSDK, error) +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` + Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + projectIdParameter := parameters.NewStringParameter("project_id", "The unique ID of the LookML project.") + fileIdParameter := parameters.NewStringParameterWithRequired("file_id", "Optional ID of the file to filter tests by. This must be the complete file path from the project root (e.g., 'models/my_model.model.lkml').", false) + params := parameters.Parameters{projectIdParameter, fileIdParameter} + + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) + + // finish tool setup + return Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{ + Description: cfg.Description, + Parameters: params.Manifest(), + AuthRequired: cfg.AuthRequired, + }, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + sdk, err := source.GetLookerSDK(string(accessToken)) + if err != nil { + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) + } + + mapParams := params.AsMap() + projectId, ok := mapParams["project_id"].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("'project_id' must be a string, got %T", mapParams["project_id"]), nil) + } + + var fileId string + if val, ok := mapParams["file_id"].(string); ok { + fileId = val + } + + resp, err := sdk.AllLookmlTests(projectId, fileId, source.LookerApiSettings()) + if err != nil { + return nil, util.NewClientServerError(fmt.Sprintf("error retrieving lookml tests: %s", err), http.StatusInternalServerError, err) + } + + return resp, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.Parameters, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.Parameters +} diff --git a/internal/tools/looker/lookergetlookmltests/lookergetlookmltests_test.go b/internal/tools/looker/lookergetlookmltests/lookergetlookmltests_test.go new file mode 100644 index 000000000000..e00b2bdac226 --- /dev/null +++ b/internal/tools/looker/lookergetlookmltests/lookergetlookmltests_test.go @@ -0,0 +1,109 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lookergetlookmltests_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + lkr "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetlookmltests" +) + +func TestParseFromYamlLookerGetLookmlTests(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tools + name: example_tool + type: looker-get-lookml-tests + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": lkr.Config{ + Name: "example_tool", + Type: "looker-get-lookml-tests", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} + +func TestFailParseFromYamlLookerGetAllLookmlTests(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "Invalid method", + in: ` + kind: tools + name: example_tool + type: looker-get-lookml-tests + source: my-instance + method: GOT + description: some description + `, + err: "error unmarshaling tools: unable to parse tool \"example_tool\" as type \"looker-get-lookml-tests\": [3:1] unknown field \"method\"", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, _, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if !strings.Contains(errStr, tc.err) { + t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err) + } + }) + } + +} diff --git a/internal/tools/looker/lookerrunlookmltests/lookerrunlookmltests.go b/internal/tools/looker/lookerrunlookmltests/lookerrunlookmltests.go new file mode 100644 index 000000000000..43f242203a7d --- /dev/null +++ b/internal/tools/looker/lookerrunlookmltests/lookerrunlookmltests.go @@ -0,0 +1,199 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package lookerrunlookmltests + +import ( + "context" + "fmt" + "net/http" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + + "github.com/looker-open-source/sdk-codegen/go/rtl" + v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4" +) + +const resourceType string = "looker-run-lookml-tests" + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerApiSettings() *rtl.ApiSettings + GetLookerSDK(string) (*v4.LookerSDK, error) +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` + Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project to run LookML tests for.") + fileIdParameter := parameters.NewStringParameterWithRequired("file_id", "Optional id of the file to run tests for.", false) + testParameter := parameters.NewStringParameterWithRequired("test", "Optional name of the test to run.", false) + modelParameter := parameters.NewStringParameterWithRequired("model", "Optional name of the model to run tests for.", false) + params := parameters.Parameters{projectIdParameter, fileIdParameter, testParameter, modelParameter} + + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) + + // finish tool setup + return Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{ + Description: cfg.Description, + Parameters: params.Manifest(), + AuthRequired: cfg.AuthRequired, + }, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + sdk, err := source.GetLookerSDK(string(accessToken)) + if err != nil { + return nil, util.NewClientServerError(fmt.Sprintf("error getting sdk: %v", err), http.StatusInternalServerError, err) + } + + mapParams := params.AsMap() + projectId, ok := mapParams["project_id"].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("'project_id' must be a string, got %T", mapParams["project_id"]), nil) + } + + var fileId *string + if val, ok := mapParams["file_id"].(string); ok && val != "" { + fileId = &val + } + + var test *string + if val, ok := mapParams["test"].(string); ok && val != "" { + test = &val + } + + var model *string + if val, ok := mapParams["model"].(string); ok && val != "" { + model = &val + } + + req := v4.RequestRunLookmlTest{ + ProjectId: projectId, + FileId: fileId, + Test: test, + Model: model, + } + + resp, err := sdk.RunLookmlTest(req, source.LookerApiSettings()) + if err != nil { + return nil, util.NewClientServerError(fmt.Sprintf("error running lookml tests: %s", err), http.StatusInternalServerError, err) + } + + // Filter out pointer fields for better JSON marshaling in basic map if needed, + // but the SDK struct usually has JSON tags. + // Returning directly as it should marshal correctly. + return resp, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.Parameters, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.Parameters +} diff --git a/internal/tools/looker/lookerrunlookmltests/lookerrunlookmltests_test.go b/internal/tools/looker/lookerrunlookmltests/lookerrunlookmltests_test.go new file mode 100644 index 000000000000..6352fb4257f9 --- /dev/null +++ b/internal/tools/looker/lookerrunlookmltests/lookerrunlookmltests_test.go @@ -0,0 +1,109 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lookerrunlookmltests_test + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + lkr "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrunlookmltests" +) + +func TestParseFromYamlLookerRunLookmlTests(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tools + name: example_tool + type: looker-run-lookml-tests + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": lkr.Config{ + Name: "example_tool", + Type: "looker-run-lookml-tests", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} + +func TestFailParseFromYamlLookerRunLookmlTests(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "Invalid method", + in: ` + kind: tools + name: example_tool + type: looker-run-lookml-tests + source: my-instance + method: GOT + description: some description + `, + err: "error unmarshaling tools: unable to parse tool \"example_tool\" as type \"looker-run-lookml-tests\": [3:1] unknown field \"method\"", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, _, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if !strings.Contains(errStr, tc.err) { + t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err) + } + }) + } + +} diff --git a/internal/util/parameters/parameters.go b/internal/util/parameters/parameters.go index 7c991f61be61..a249cc77a282 100644 --- a/internal/util/parameters/parameters.go +++ b/internal/util/parameters/parameters.go @@ -144,7 +144,7 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st // parse non auth-required parameter var ok bool v, ok = data[name] - if !ok { + if !ok || v == nil { v = p.GetDefault() // if the parameter is required and no value given, throw an error if CheckParamRequired(p.GetRequired(), v) { diff --git a/internal/util/parameters/parameters_test.go b/internal/util/parameters/parameters_test.go index 624a05ab03ef..7542bccfa134 100644 --- a/internal/util/parameters/parameters_test.go +++ b/internal/util/parameters/parameters_test.go @@ -1362,6 +1362,25 @@ func TestParametersParse(t *testing.T) { } }) } + t.Run("CheckNullForRequiredParam", func(t *testing.T) { + // Define a required string parameter + params := parameters.Parameters{ + parameters.NewStringParameter("required_param", "this is required"), + } + + // Input map with explicit nil + input := map[string]any{ + "required_param": nil, + } + + // Call ParseParams + _, err := parameters.ParseParams(params, input, nil) + + // Expect an error because the parameter is required + if err == nil { + t.Errorf("ParseParams allowed explicit nil for required parameter, expected error") + } + }) } func TestAuthParametersParse(t *testing.T) { diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 30307296f1e4..575c101b7580 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -173,7 +173,7 @@ func TestBigQueryToolEndpoints(t *testing.T) { datasetInfoWant := "\"Location\":\"US\",\"DefaultTableExpiration\":0,\"Labels\":null,\"Access\":" tableInfoWant := "{\"Name\":\"\",\"Location\":\"US\",\"Description\":\"\",\"Schema\":[{\"Name\":\"id\"" ddlWant := `"Query executed successfully and returned no content."` - dataInsightsWant := `(?s)Schema Resolved.*Retrieval Query.*SQL Generated.*Answer` + dataInsightsWant := `(?s)(Schema Resolved.*)?(Retrieval Query.*)?SQL Generated.*Data Retrieved.*Answer` // Partial message; the full error message is too long. mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"error processing GCP request: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]` mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"f0_\":1}"}]}}` @@ -2393,7 +2393,7 @@ func runBigQueryConversationalAnalyticsInvokeTest(t *testing.T, datasetName, tab `{"user_query_with_context": "What are the names in the table?", "table_references": %q}`, tableRefsJSON, ))), - want: "[{\"f0_\":1}]", + want: dataInsightsWant, isErr: false, }, { diff --git a/tests/cloudgda/cloud_gda_integration_test.go b/tests/cloudgda/cloud_gda_integration_test.go index 24c0cab1cbe1..557f80bdd93e 100644 --- a/tests/cloudgda/cloud_gda_integration_test.go +++ b/tests/cloudgda/cloud_gda_integration_test.go @@ -18,78 +18,75 @@ import ( "bytes" "context" "encoding/json" + "fmt" + "net" "net/http" - "net/http/httptest" - "net/url" "regexp" "strings" "testing" "time" + geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta" + "cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" + source "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" "github.com/googleapis/genai-toolbox/tests" + "google.golang.org/api/option" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" ) var ( cloudGdaToolType = "cloud-gemini-data-analytics-query" ) -type cloudGdaTransport struct { - transport http.RoundTripper - url *url.URL -} - -func (t *cloudGdaTransport) RoundTrip(req *http.Request) (*http.Response, error) { - if strings.HasPrefix(req.URL.String(), "https://geminidataanalytics.googleapis.com") { - req.URL.Scheme = t.url.Scheme - req.URL.Host = t.url.Host - } - return t.transport.RoundTrip(req) -} - -type masterHandler struct { +type mockDataChatServer struct { + geminidataanalyticspb.UnimplementedDataChatServiceServer t *testing.T } -func (h *masterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if !strings.Contains(r.UserAgent(), "genai-toolbox/") { - h.t.Errorf("User-Agent header not found") - } - - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - // Verify URL structure - // Expected: /v1beta/projects/{project}/locations/global:queryData - if !strings.Contains(r.URL.Path, ":queryData") || !strings.Contains(r.URL.Path, "locations/global") { - h.t.Errorf("unexpected URL path: %s", r.URL.Path) - http.Error(w, "Not found", http.StatusNotFound) - return - } - - var reqBody cloudgda.QueryDataRequest - if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { - h.t.Fatalf("failed to decode request body: %v", err) - } - - if reqBody.Prompt == "" { - http.Error(w, "missing prompt", http.StatusBadRequest) - return +func (s *mockDataChatServer) QueryData(ctx context.Context, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) { + if req.Prompt == "" { + s.t.Errorf("missing prompt") + return nil, fmt.Errorf("missing prompt") } - response := map[string]any{ - "queryResult": "SELECT * FROM table;", - "naturalLanguageAnswer": "Here is the answer.", - } + return &geminidataanalyticspb.QueryDataResponse{ + GeneratedQuery: "SELECT * FROM table;", + NaturalLanguageAnswer: "Here is the answer.", + }, nil +} - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(response); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) +func getCloudGdaToolsConfig() map[string]any { + return map[string]any{ + "sources": map[string]any{ + "my-gda-source": map[string]any{ + "type": "cloud-gemini-data-analytics", + "projectId": "test-project", + }, + }, + "tools": map[string]any{ + "cloud-gda-query": map[string]any{ + "type": cloudGdaToolType, + "source": "my-gda-source", + "description": "Test GDA Tool", + "location": "us-central1", + "context": map[string]any{ + "datasourceReferences": map[string]any{ + "spannerReference": map[string]any{ + "databaseReference": map[string]any{ + "projectId": "test-project", + "instanceId": "test-instance", + "databaseId": "test-db", + "engine": "GOOGLE_SQL", + }, + }, + }, + }, + }, + }, } } @@ -97,27 +94,38 @@ func TestCloudGdaToolEndpoints(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - handler := &masterHandler{t: t} - server := httptest.NewServer(handler) - defer server.Close() - - serverURL, err := url.Parse(server.URL) + // Start a gRPC server + lis, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { - t.Fatalf("failed to parse server URL: %v", err) + t.Fatalf("failed to listen: %v", err) + } + s := grpc.NewServer() + geminidataanalyticspb.RegisterDataChatServiceServer(s, &mockDataChatServer{t: t}) + go func() { + if err := s.Serve(lis); err != nil { + // This might happen on strict shutdown, log if unexpected + t.Logf("server executed: %v", err) + } + }() + defer s.Stop() + + // Configure toolbox to use the gRPC server + endpoint := lis.Addr().String() + + // Override client creation + origFunc := source.NewDataChatClient + defer func() { + source.NewDataChatClient = origFunc + }() + + source.NewDataChatClient = func(ctx context.Context, opts ...option.ClientOption) (*geminidataanalytics.DataChatClient, error) { + opts = append(opts, + option.WithEndpoint(endpoint), + option.WithoutAuthentication(), + option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials()))) + return origFunc(ctx, opts...) } - originalTransport := http.DefaultClient.Transport - if originalTransport == nil { - originalTransport = http.DefaultTransport - } - http.DefaultClient.Transport = &cloudGdaTransport{ - transport: originalTransport, - url: serverURL, - } - t.Cleanup(func() { - http.DefaultClient.Transport = originalTransport - }) - var args []string toolsFile := getCloudGdaToolsConfig() cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) @@ -156,7 +164,7 @@ func TestCloudGdaToolEndpoints(t *testing.T) { // 2. RunToolInvokeParametersTest params := []byte(`{"query": "test question"}`) - tests.RunToolInvokeParametersTest(t, toolName, params, "\"queryResult\":\"SELECT * FROM table;\"") + tests.RunToolInvokeParametersTest(t, toolName, params, "\"generated_query\":\"SELECT * FROM table;\"") // 3. Manual MCP Tool Call Test // Initialize MCP session @@ -196,38 +204,3 @@ func TestCloudGdaToolEndpoints(t *testing.T) { t.Errorf("MCP response does not contain expected query result: %s", respStr) } } - -func getCloudGdaToolsConfig() map[string]any { - // Mocked responses and a dummy `projectId` are used in this integration - // test due to limited project-specific allowlisting. API functionality is - // verified via internal monitoring; this test specifically validates the - // integration flow between the source and the tool. - return map[string]any{ - "sources": map[string]any{ - "my-gda-source": map[string]any{ - "type": "cloud-gemini-data-analytics", - "projectId": "test-project", - }, - }, - "tools": map[string]any{ - "cloud-gda-query": map[string]any{ - "type": cloudGdaToolType, - "source": "my-gda-source", - "description": "Test GDA Tool", - "location": "us-central1", - "context": map[string]any{ - "datasourceReferences": map[string]any{ - "spannerReference": map[string]any{ - "databaseReference": map[string]any{ - "projectId": "test-project", - "instanceId": "test-instance", - "databaseId": "test-db", - "engine": "GOOGLE_SQL", - }, - }, - }, - }, - }, - }, - } -} diff --git a/tests/http/http_integration_test.go b/tests/http/http_integration_test.go index 6ec82f4ce745..b767bb31e786 100644 --- a/tests/http/http_integration_test.go +++ b/tests/http/http_integration_test.go @@ -362,7 +362,7 @@ func runQueryParamInvokeTest(t *testing.T) { name: "invoke query-param-tool (required param nil)", api: "http://127.0.0.1:5000/api/tool/my-query-param-tool/invoke", requestBody: bytes.NewBuffer([]byte(`{"reqId": null, "page": "1"}`)), - want: `"page=1\u0026reqId="`, // reqId becomes "", + want: `{"error":"parameter \"reqId\" is required"}`, }, } for _, tc := range invokeTcs { diff --git a/tests/looker/looker_integration_test.go b/tests/looker/looker_integration_test.go index 33688899626c..5c275333ed14 100644 --- a/tests/looker/looker_integration_test.go +++ b/tests/looker/looker_integration_test.go @@ -272,6 +272,21 @@ func TestLooker(t *testing.T) { "source": "my-instance", "description": "Simple tool to test end to end functionality.", }, + "get_lookml_tests": map[string]any{ + "type": "looker-get-lookml-tests", + "source": "my-instance", + "description": "Simple tool to test end to end functionality.", + }, + "run_lookml_tests": map[string]any{ + "type": "looker-run-lookml-tests", + "source": "my-instance", + "description": "Simple tool to test end to end functionality.", + }, + "create_view_from_table": map[string]any{ + "type": "looker-create-view-from-table", + "source": "my-instance", + "description": "Simple tool to test end to end functionality.", + }, }, } @@ -696,6 +711,115 @@ func TestLooker(t *testing.T) { }, }, ) + tests.RunToolGetTestByName(t, "get_lookml_tests", + map[string]any{ + "get_lookml_tests": map[string]any{ + "description": "Simple tool to test end to end functionality.", + "authRequired": []any{}, + "parameters": []any{ + map[string]any{ + "authSources": []any{}, + "description": "The unique ID of the LookML project.", + "name": "project_id", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "Optional ID of the file to filter tests by. This must be the complete file path from the project root (e.g., 'models/my_model.model.lkml').", + "name": "file_id", + "required": false, + "type": "string", + }, + }, + }, + }, + ) + tests.RunToolGetTestByName(t, "run_lookml_tests", + map[string]any{ + "run_lookml_tests": map[string]any{ + "description": "Simple tool to test end to end functionality.", + "authRequired": []any{}, + "parameters": []any{ + map[string]any{ + "authSources": []any{}, + "description": "The id of the project to run LookML tests for.", + "name": "project_id", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "Optional id of the file to run tests for.", + "name": "file_id", + "required": false, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "Optional name of the test to run.", + "name": "test", + "required": false, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "Optional name of the model to run tests for.", + "name": "model", + "required": false, + "type": "string", + }, + }, + }, + }, + ) + tests.RunToolGetTestByName(t, "create_view_from_table", + map[string]any{ + "create_view_from_table": map[string]any{ + "description": "Simple tool to test end to end functionality.", + "authRequired": []any{}, + "parameters": []any{ + map[string]any{ + "authSources": []any{}, + "description": "The id of the project to create the view in.", + "name": "project_id", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The database connection name.", + "name": "connection", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The tables to generate views for.\n\t\tEach item must be a map with:\n\t\t- schema (string, required)\n\t\t- table_name (string, required)\n\t\t- primary_key (string, optional)\n\t\t- base_view (boolean, optional)\n\t\t- columns (array of objects, optional): Each object must have 'column_name' (string).", + "items": map[string]any{ + "additionalProperties": true, + "authSources": []any{}, + "description": "Table definition.", + "name": "table", + "required": true, + "type": "object", + }, + "name": "tables", + "required": true, + "type": "array", + }, + map[string]any{ + "authSources": []any{}, + "default": "views", + "description": "The folder to place the view files in (e.g., 'views').", + "name": "folder_name", + "required": false, + "type": "string", + }, + }, + }, + }, + ) tests.RunToolGetTestByName(t, "get_looks", map[string]any{ "get_looks": map[string]any{ @@ -1768,17 +1892,27 @@ func TestLooker(t *testing.T) { tests.RunToolInvokeParametersTest(t, "delete_project_file", []byte(`{"project_id": "the_look", "file_path": "foo.view.lkml"}`), wantResult) wantResult = "Created" - tests.RunToolInvokeParametersTest(t, "create_project_directory", []byte(`{"project_id": "the_look", "directory_path": "foo_dir"}`), wantResult) + tests.RunToolInvokeParametersTest(t, "create_project_directory", []byte(`{"project_id": "the_look", "directory_path": "views"}`), wantResult) - wantResult = "foo_dir" + wantResult = "views" tests.RunToolInvokeParametersTest(t, "get_project_directories", []byte(`{"project_id": "the_look"}`), wantResult) + // Add test back when infrastructure for testing supports it. + // wantResult = "{\"status\": \"success\", \"message\": \"Triggered view generation for project the_look in folder views\"}" + // tests.RunToolInvokeParametersTest(t, "create_view_from_table", []byte(`{"project_id": "the_look", "connection": "thelook", "tables": [{"schema": "demo_db", "table_name": "Employees"}]}`), wantResult) + wantResult = "Deleted" - tests.RunToolInvokeParametersTest(t, "delete_project_directory", []byte(`{"project_id": "the_look", "directory_path": "foo_dir"}`), wantResult) + tests.RunToolInvokeParametersTest(t, "delete_project_directory", []byte(`{"project_id": "the_look", "directory_path": "views"}`), wantResult) wantResult = "\"errors\":[]" tests.RunToolInvokeParametersTest(t, "validate_project", []byte(`{"project_id": "the_look"}`), wantResult) + wantResult = "[]" + tests.RunToolInvokeParametersTest(t, "get_lookml_tests", []byte(`{"project_id": "the_look"}`), wantResult) + + wantResult = "[]" + tests.RunToolInvokeParametersTest(t, "run_lookml_tests", []byte(`{"project_id": "the_look"}`), wantResult) + wantResult = "production" tests.RunToolInvokeParametersTest(t, "dev_mode", []byte(`{"devMode": false}`), wantResult) diff --git a/tests/tool.go b/tests/tool.go index cd79e606ded4..591c5fb9646f 100644 --- a/tests/tool.go +++ b/tests/tool.go @@ -1377,13 +1377,14 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool wantStatusCode: http.StatusOK, want: []map[string]any{wantSchema}, }, - { - name: "invoke list_schemas with owner name", - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"owner": "%s"}`, owner))), - wantStatusCode: http.StatusOK, - want: []map[string]any{wantSchema}, - compareSubset: true, - }, + // TODO: Re-enable this test case after this issue is fixed: https://github.com/googleapis/genai-toolbox/issues/2562 + // { + // name: "invoke list_schemas with owner name", + // requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"owner": "%s"}`, owner))), + // wantStatusCode: http.StatusOK, + // want: []map[string]any{wantSchema}, + // compareSubset: true, + // }, { name: "invoke list_schemas with limit 1", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"schema_name": "%s","limit": 1}`, schemaName))),