diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index 72de4cbdef19..58f0308620aa 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -214,6 +214,28 @@ steps: dataform \ dataform + - id: "healthcare" + name: golang:1 + waitFor: ["compile-test-binary"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "HEALTHCARE_PROJECT=$PROJECT_ID" + - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" + - "HEALTHCARE_REGION=$_REGION" + - "HEALTHCARE_DATASET=$_HEALTHCARE_DATASET" + secretEnv: ["CLIENT_ID"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + .ci/test_with_coverage.sh \ + "Healthcare" \ + healthcare \ + healthcare + - id: "postgres" name: golang:1 waitFor: ["compile-test-binary"] @@ -813,6 +835,7 @@ substitutions: _ALLOYDB_AI_NL_CLUSTER: "alloydb-ai-nl-testing" _ALLOYDB_AI_NL_INSTANCE: "alloydb-ai-nl-testing-instance" _BIGTABLE_INSTANCE: "bigtable-testing-instance" + _HEALTHCARE_DATASET: "test-dataset" _POSTGRES_HOST: 127.0.0.1 _POSTGRES_PORT: "5432" _SPANNER_INSTANCE: "spanner-testing" diff --git a/cmd/root.go b/cmd/root.go index 1ac289c62b61..d1d5425799bb 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -96,6 +96,9 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequerycollection" _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreupdatedocument" _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules" + _ "github.com/googleapis/genai-toolbox/internal/tools/healthcare/gethealthcaredataset" + _ "github.com/googleapis/genai-toolbox/internal/tools/healthcare/listdicomstores" + _ "github.com/googleapis/genai-toolbox/internal/tools/healthcare/listfhirstores" _ "github.com/googleapis/genai-toolbox/internal/tools/http" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics" @@ -177,6 +180,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/sources/dgraph" _ "github.com/googleapis/genai-toolbox/internal/sources/firebird" _ "github.com/googleapis/genai-toolbox/internal/sources/firestore" + _ "github.com/googleapis/genai-toolbox/internal/sources/healthcare" _ "github.com/googleapis/genai-toolbox/internal/sources/http" _ "github.com/googleapis/genai-toolbox/internal/sources/looker" _ "github.com/googleapis/genai-toolbox/internal/sources/mongodb" diff --git a/docs/en/resources/sources/healthcare.md b/docs/en/resources/sources/healthcare.md new file mode 100644 index 000000000000..755414504016 --- /dev/null +++ b/docs/en/resources/sources/healthcare.md @@ -0,0 +1,128 @@ +--- +title: "Cloud Healthcare API" +linkTitle: "Healthcare" +type: docs +weight: 1 +description: > + The Cloud Healthcare API provides a managed solution for storing and + accessing healthcare data in Google Cloud, providing a critical bridge + between existing care systems and applications hosted on Google Cloud. +--- + +## About + +The [Cloud Healthcare API][healthcare-docs] provides a managed solution +for storing and accessing healthcare data in Google Cloud, providing a +critical bridge between existing care systems and applications hosted on +Google Cloud. It supports healthcare data standards such as HL7® FHIR®, +HL7® v2, and DICOM®. It provides a fully managed, highly scalable, +enterprise-grade development environment for building clinical and analytics +solutions securely on Google Cloud. + +A dataset is a container in your Google Cloud project that holds modality-specific +healthcare data. Datasets contain other data stores, such as FHIR stores and DICOM +stores, which in turn hold their own types of healthcare data. + +A single dataset can contain one or many data stores, and those stores can all service +the same modality or different modalities as application needs dictate. Using multiple +stores in the same dataset might be appropriate in various situations. + +If you are new to the Healthcare API, you can try to +[create and view datasets and stores using curl][healthcare-quickstart-curl]. + +[healthcare-docs]: https://cloud.google.com/healthcare/docs +[healthcare-quickstart-curl]: + https://cloud.google.com/healthcare-api/docs/store-healthcare-data-rest + +## Available Tools + +- [`get-healthcare-dataset`](../tools/healthcare/get-healthcare-dataset.md) + Retrieves a dataset’s details. + +- [`list-fhir-stores`](../tools/healthcare/list-fhir-stores.md) + Lists the available FHIR stores in the healthcare dataset. + +- [`list-dicom-stores`](../tools/healthcare/list-dicom-stores.md) + Lists the available DICOM stores in the healthcare dataset. + +## Requirements + +### IAM Permissions + +The Healthcare API uses [Identity and Access Management (IAM)][iam-overview] to control +user and group access to Healthcare resources like projects, datasets, and stores. + +### Authentication via Application Default Credentials (ADC) + +By **default**, Toolbox will use your [Application Default Credentials +(ADC)][adc] to authorize and authenticate when interacting with the +[Healthcare API][healthcare-docs]. + +When using this method, you need to ensure the IAM identity associated with your +ADC (such as a service account) has the correct permissions for the queries you +intend to run. Common roles include `roles/healthcare.fhirResourceReader` (which includes +permissions to read and search for FHIR resources) or `roles/healthcare.dicomViewer` (for +retrieving DICOM images). +Follow this [guide][set-adc] to set up your ADC. + +### Authentication via User's OAuth Access Token + +If the `useClientOAuth` parameter is set to `true`, Toolbox will instead use the +OAuth access token for authentication. This token is parsed from the +`Authorization` header passed in with the tool invocation request. This method +allows Toolbox to make queries to the [Healthcare API][healthcare-docs] on behalf of the +client or the end-user. + +When using this on-behalf-of authentication, you must ensure that the +identity used has been granted the correct IAM permissions. + +[iam-overview]: +[adc]: +[set-adc]: + +## Example + +Initialize a Healthcare source that uses ADC: + +```yaml +sources: + my-healthcare-source: + kind: "healthcare" + project: "my-project-id" + region: "us-central1" + dataset: "my-healthcare-dataset-id" + # allowedFhirStores: # Optional: Restricts tool access to a specific list of FHIR store IDs. + # - "my_fhir_store_1" + # allowedDicomStores: # Optional: Restricts tool access to a specific list of DICOM store IDs. + # - "my_dicom_store_1" + # - "my_dicom_store_2" +``` + +Initialize a Healthcare source that uses the client's access token: + +```yaml +sources: + my-healthcare-client-auth-source: + kind: "healthcare" + project: "my-project-id" + region: "us-central1" + dataset: "my-healthcare-dataset-id" + useClientOAuth: true + # allowedFhirStores: # Optional: Restricts tool access to a specific list of FHIR store IDs. + # - "my_fhir_store_1" + # allowedDicomStores: # Optional: Restricts tool access to a specific list of DICOM store IDs. + # - "my_dicom_store_1" + # - "my_dicom_store_2" +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|--------------------|:--------:|:------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| kind | string | true | Must be "healthcare". | +| project | string | true | ID of the GCP project that the dataset lives in. | +| region | string | true | Specifies the region (e.g., 'us', 'asia-northeast1') of the healthcare dataset. [Learn More](https://cloud.google.com/healthcare-api/docs/regions) | +| dataset | string | true | ID of the healthcare dataset. | +| allowedFhirStores | []string | false | An optional list of FHIR store IDs that tools using this source are allowed to access. If provided, any tool operation attempting to access a store not in this list will be rejected. If a single store is provided, it will be treated as the default for prebuilt tools. | +| allowedDicomStores | []string | false | An optional list of DICOM store IDs that tools using this source are allowed to access. If provided, any tool operation attempting to access a store not in this list will be rejected. If a single store is provided, it will be treated as the default for prebuilt tools. | +| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. | diff --git a/docs/en/resources/tools/healthcare/_index.md b/docs/en/resources/tools/healthcare/_index.md new file mode 100644 index 000000000000..8781823bbeb1 --- /dev/null +++ b/docs/en/resources/tools/healthcare/_index.md @@ -0,0 +1,8 @@ +--- +title: "Cloud Healthcare API" +linkTitle: "Healthcare" +type: docs +weight: 1 +description: > + Tools that work with Healthcare Sources. +--- \ No newline at end of file diff --git a/docs/en/resources/tools/healthcare/get-healthcare-dataset.md b/docs/en/resources/tools/healthcare/get-healthcare-dataset.md new file mode 100644 index 000000000000..0d2c78165496 --- /dev/null +++ b/docs/en/resources/tools/healthcare/get-healthcare-dataset.md @@ -0,0 +1,38 @@ +--- +title: "get-healthcare-dataset" +linkTitle: "get-healthcare-dataset" +type: docs +weight: 1 +description: > + A "get-healthcare-dataset" tool retrieves metadata for the Healthcare dataset in the source. +aliases: +- /resources/tools/healthcare-get-healthcare-dataset +--- + +## About + +A `get-healthcare-dataset` tool retrieves metadata for a Healthcare dataset. +It's compatible with the following sources: + +- [healthcare](../../sources/healthcare.md) + +`get-healthcare-dataset` returns the metadata of the healthcare dataset +configured in the source. It takes no extra parameters. + +## Example + +```yaml +tools: + get_healthcare_dataset: + kind: get-healthcare-dataset + source: my-healthcare-source + description: Use this tool to get healthcare dataset metadata. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:------------------------------------------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "get-healthcare-dataset". | +| source | string | true | Name of the healthcare source. | +| description | string | true | Description of the tool that is passed to the LLM. | diff --git a/docs/en/resources/tools/healthcare/list-dicom-stores.md b/docs/en/resources/tools/healthcare/list-dicom-stores.md new file mode 100644 index 000000000000..e32703494097 --- /dev/null +++ b/docs/en/resources/tools/healthcare/list-dicom-stores.md @@ -0,0 +1,38 @@ +--- +title: "list-dicom-stores" +linkTitle: "list-dicom-stores" +type: docs +weight: 1 +description: > + A "list-dicom-stores" lists the available DICOM stores in the healthcare dataset. +aliases: +- /resources/tools/healthcare-list-dicom-stores +--- + +## About + +A `list-dicom-stores` lists the available DICOM stores in the healthcare dataset. +It's compatible with the following sources: + +- [healthcare](../../sources/healthcare.md) + +`list-dicom-stores` returns the details of the available DICOM stores in the +dataset of the healthcare source. It takes no extra parameters. + +## Example + +```yaml +tools: + list_dicom_stores: + kind: list-dicom-stores + source: my-healthcare-source + description: Use this tool to list DICOM stores in the healthcare dataset. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:------------------------------------------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "list-dicom-stores". | +| source | string | true | Name of the healthcare source. | +| description | string | true | Description of the tool that is passed to the LLM. | diff --git a/docs/en/resources/tools/healthcare/list-fhir-stores.md b/docs/en/resources/tools/healthcare/list-fhir-stores.md new file mode 100644 index 000000000000..e53368ad92fe --- /dev/null +++ b/docs/en/resources/tools/healthcare/list-fhir-stores.md @@ -0,0 +1,38 @@ +--- +title: "list-fhir-stores" +linkTitle: "list-fhir-stores" +type: docs +weight: 1 +description: > + A "list-fhir-stores" lists the available FHIR stores in the healthcare dataset. +aliases: +- /resources/tools/healthcare-list-fhir-stores +--- + +## About + +A `list-fhir-stores` lists the available FHIR stores in the healthcare dataset. +It's compatible with the following sources: + +- [healthcare](../../sources/healthcare.md) + +`list-fhir-stores` returns the details of the available FHIR stores in the +dataset of the healthcare source. It takes no extra parameters. + +## Example + +```yaml +tools: + list_fhir_stores: + kind: list-fhir-stores + source: my-healthcare-source + description: Use this tool to list FHIR stores in the healthcare dataset. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:------------------------------------------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "list-fhir-stores". | +| source | string | true | Name of the healthcare source. | +| description | string | true | Description of the tool that is passed to the LLM. | diff --git a/internal/sources/healthcare/healthcare.go b/internal/sources/healthcare/healthcare.go new file mode 100644 index 000000000000..8684a60ea951 --- /dev/null +++ b/internal/sources/healthcare/healthcare.go @@ -0,0 +1,261 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package healthcare + +import ( + "context" + "fmt" + "net/http" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" + "go.opentelemetry.io/otel/trace" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "google.golang.org/api/googleapi" + "google.golang.org/api/healthcare/v1" + "google.golang.org/api/option" +) + +const SourceKind string = "healthcare" + +// validate interface +var _ sources.SourceConfig = Config{} + +type HealthcareServiceCreator func(tokenString string) (*healthcare.Service, error) + +func init() { + if !sources.Register(SourceKind, newConfig) { + panic(fmt.Sprintf("source kind %q already registered", SourceKind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type Config struct { + // Healthcare configs + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Project string `yaml:"project" validate:"required"` + Region string `yaml:"region" validate:"required"` + Dataset string `yaml:"dataset" validate:"required"` + AllowedFHIRStores []string `yaml:"allowedFhirStores"` + AllowedDICOMStores []string `yaml:"allowedDicomStores"` + UseClientOAuth bool `yaml:"useClientOAuth"` +} + +func (c Config) SourceConfigKind() string { + return SourceKind +} + +func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { + var service *healthcare.Service + var serviceCreator HealthcareServiceCreator + var tokenSource oauth2.TokenSource + + svc, tok, err := initHealthcareConnection(ctx, tracer, c.Name) + if err != nil { + return nil, fmt.Errorf("error creating service from ADC: %w", err) + } + if c.UseClientOAuth { + serviceCreator, err = newHealthcareServiceCreator(ctx, tracer, c.Name) + if err != nil { + return nil, fmt.Errorf("error constructing service creator: %w", err) + } + } else { + service = svc + tokenSource = tok + } + + dsName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", c.Project, c.Region, c.Dataset) + if _, err = svc.Projects.Locations.Datasets.FhirStores.Get(dsName).Do(); err != nil { + if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound { + return nil, fmt.Errorf("dataset '%s' not found", dsName) + } + return nil, fmt.Errorf("failed to verify existence of dataset '%s': %w", dsName, err) + } + + allowedFHIRStores := make(map[string]struct{}) + for _, store := range c.AllowedFHIRStores { + name := fmt.Sprintf("%s/fhirStores/%s", dsName, store) + _, err := svc.Projects.Locations.Datasets.FhirStores.Get(name).Do() + if err != nil { + if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound { + return nil, fmt.Errorf("allowedFhirStore '%s' not found in dataset '%s'", store, dsName) + } + return nil, fmt.Errorf("failed to verify allowedFhirStore '%s' in datasest '%s': %w", store, dsName, err) + } + allowedFHIRStores[store] = struct{}{} + } + allowedDICOMStores := make(map[string]struct{}) + for _, store := range c.AllowedDICOMStores { + name := fmt.Sprintf("%s/dicomStores/%s", dsName, store) + _, err := svc.Projects.Locations.Datasets.DicomStores.Get(name).Do() + if err != nil { + if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound { + return nil, fmt.Errorf("allowedDicomStore '%s' not found in dataset '%s'", store, dsName) + } + return nil, fmt.Errorf("failed to verify allowedDicomFhirStore '%s' in datasest '%s': %w", store, dsName, err) + } + allowedDICOMStores[store] = struct{}{} + } + s := &Source{ + name: c.Name, + kind: SourceKind, + project: c.Project, + region: c.Region, + dataset: c.Dataset, + service: service, + serviceCreator: serviceCreator, + tokenSource: tokenSource, + allowedFHIRStores: allowedFHIRStores, + allowedDICOMStores: allowedDICOMStores, + useClientOAuth: c.UseClientOAuth, + } + return s, nil +} + +func newHealthcareServiceCreator(ctx context.Context, tracer trace.Tracer, name string) (func(string) (*healthcare.Service, error), error) { + userAgent, err := util.UserAgentFromContext(ctx) + if err != nil { + return nil, err + } + return func(tokenString string) (*healthcare.Service, error) { + return initHealthcareConnectionWithOAuthToken(ctx, tracer, name, userAgent, tokenString) + }, nil +} + +func initHealthcareConnectionWithOAuthToken(ctx context.Context, tracer trace.Tracer, name string, userAgent string, tokenString string) (*healthcare.Service, error) { + ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) + defer span.End() + // Construct token source + token := &oauth2.Token{ + AccessToken: string(tokenString), + } + ts := oauth2.StaticTokenSource(token) + + // Initialize the Healthcare service with tokenSource + service, err := healthcare.NewService(ctx, option.WithUserAgent(userAgent), option.WithTokenSource(ts)) + if err != nil { + return nil, fmt.Errorf("failed to create Healthcare service: %w", err) + } + return service, nil +} + +func initHealthcareConnection(ctx context.Context, tracer trace.Tracer, name string) (*healthcare.Service, oauth2.TokenSource, error) { + ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) + defer span.End() + + cred, err := google.FindDefaultCredentials(ctx, healthcare.CloudHealthcareScope) + if err != nil { + return nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", healthcare.CloudHealthcareScope, err) + } + + userAgent, err := util.UserAgentFromContext(ctx) + if err != nil { + return nil, nil, err + } + + service, err := healthcare.NewService(ctx, option.WithUserAgent(userAgent), option.WithCredentials(cred)) + if err != nil { + return nil, nil, fmt.Errorf("failed to create Healthcare service: %w", err) + } + return service, cred.TokenSource, nil +} + +var _ sources.Source = &Source{} + +type Source struct { + name string `yaml:"name"` + kind string `yaml:"kind"` + project string + region string + dataset string + service *healthcare.Service + serviceCreator HealthcareServiceCreator + tokenSource oauth2.TokenSource + allowedFHIRStores map[string]struct{} + allowedDICOMStores map[string]struct{} + useClientOAuth bool +} + +func (s *Source) SourceKind() string { + return SourceKind +} + +func (s *Source) Project() string { + return s.project +} + +func (s *Source) Region() string { + return s.region +} + +func (s *Source) DatasetID() string { + return s.dataset +} + +func (s *Source) Service() *healthcare.Service { + return s.service +} + +func (s *Source) ServiceCreator() HealthcareServiceCreator { + return s.serviceCreator +} + +func (s *Source) TokenSource() oauth2.TokenSource { + return s.tokenSource +} + +func (s *Source) AllowedFHIRStores() map[string]struct{} { + if len(s.allowedFHIRStores) == 0 { + return nil + } + return s.allowedFHIRStores +} + +func (s *Source) AllowedDICOMStores() map[string]struct{} { + if len(s.allowedDICOMStores) == 0 { + return nil + } + return s.allowedDICOMStores +} + +func (s *Source) IsFHIRStoreAllowed(storeID string) bool { + if len(s.allowedFHIRStores) == 0 { + return true + } + _, ok := s.allowedFHIRStores[storeID] + return ok +} + +func (s *Source) IsDICOMStoreAllowed(storeID string) bool { + if len(s.allowedDICOMStores) == 0 { + return true + } + _, ok := s.allowedDICOMStores[storeID] + return ok +} + +func (s *Source) UseClientAuthorization() bool { + return s.useClientOAuth +} diff --git a/internal/sources/healthcare/healthcare_test.go b/internal/sources/healthcare/healthcare_test.go new file mode 100644 index 000000000000..8e0e84b8fe05 --- /dev/null +++ b/internal/sources/healthcare/healthcare_test.go @@ -0,0 +1,168 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package healthcare_test + +import ( + "testing" + + "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/sources/healthcare" + "github.com/googleapis/genai-toolbox/internal/testutils" +) + +func TestParseFromYamlHealthcare(t *testing.T) { + tcs := []struct { + desc string + in string + want server.SourceConfigs + }{ + { + desc: "basic example", + in: ` + sources: + my-instance: + kind: healthcare + project: my-project + region: us-central1 + dataset: my-dataset + `, + want: server.SourceConfigs{ + "my-instance": healthcare.Config{ + Name: "my-instance", + Kind: healthcare.SourceKind, + Project: "my-project", + Region: "us-central1", + Dataset: "my-dataset", + UseClientOAuth: false, + }, + }, + }, + { + desc: "use client auth example", + in: ` + sources: + my-instance: + kind: healthcare + project: my-project + region: us + dataset: my-dataset + useClientOAuth: true + `, + want: server.SourceConfigs{ + "my-instance": healthcare.Config{ + Name: "my-instance", + Kind: healthcare.SourceKind, + Project: "my-project", + Region: "us", + Dataset: "my-dataset", + UseClientOAuth: true, + }, + }, + }, + { + desc: "with allowed stores example", + in: ` + sources: + my-instance: + kind: healthcare + project: my-project + region: us + dataset: my-dataset + allowedFhirStores: + - my-fhir-store + allowedDicomStores: + - my-dicom-store1 + - my-dicom-store2 + `, + want: server.SourceConfigs{ + "my-instance": healthcare.Config{ + Name: "my-instance", + Kind: healthcare.SourceKind, + Project: "my-project", + Region: "us", + Dataset: "my-dataset", + AllowedFHIRStores: []string{"my-fhir-store"}, + AllowedDICOMStores: []string{"my-dicom-store1", "my-dicom-store2"}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if !cmp.Equal(tc.want, got.Sources) { + t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources) + } + }) + } +} + +func TestFailParseFromYaml(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "extra field", + in: ` + sources: + my-instance: + kind: healthcare + project: my-project + region: us-central1 + dataset: my-dataset + foo: bar + `, + err: "unable to parse source \"my-instance\" as \"healthcare\": [2:1] unknown field \"foo\"\n 1 | dataset: my-dataset\n> 2 | foo: bar\n ^\n 3 | kind: healthcare\n 4 | project: my-project\n 5 | region: us-central1", + }, + { + desc: "missing required field", + in: ` + sources: + my-instance: + kind: healthcare + project: my-project + region: us-central1 + `, + err: `unable to parse source "my-instance" as "healthcare": Key: 'Config.Dataset' Error:Field validation for 'Dataset' failed on the 'required' tag`, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) + } + }) + } +} diff --git a/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset.go b/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset.go new file mode 100644 index 000000000000..e25cb8a672be --- /dev/null +++ b/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset.go @@ -0,0 +1,166 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gethealthcaredataset + +import ( + "context" + "fmt" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + healthcareds "github.com/googleapis/genai-toolbox/internal/sources/healthcare" + "github.com/googleapis/genai-toolbox/internal/tools" + "google.golang.org/api/healthcare/v1" +) + +const kind string = "get-healthcare-dataset" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Project() string + Region() string + DatasetID() string + Service() *healthcare.Service + ServiceCreator() healthcareds.HealthcareServiceCreator + UseClientAuthorization() bool +} + +// validate compatible sources are still compatible +var _ compatibleSource = &healthcareds.Source{} + +var compatibleSources = [...]string{healthcareds.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + s, ok := rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + parameters := tools.Parameters{} + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters) + + // finish tool setup + t := Tool{ + Name: cfg.Name, + Kind: kind, + Parameters: parameters, + AuthRequired: cfg.AuthRequired, + Project: s.Project(), + Region: s.Region(), + Dataset: s.DatasetID(), + UseClientOAuth: s.UseClientAuthorization(), + ServiceCreator: s.ServiceCreator(), + Service: s.Service(), + manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + AuthRequired []string `yaml:"authRequired"` + UseClientOAuth bool `yaml:"useClientOAuth"` + Parameters tools.Parameters `yaml:"parameters"` + + Project, Region, Dataset string + Service *healthcare.Service + ServiceCreator healthcareds.HealthcareServiceCreator + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { + svc := t.Service + var err error + + // Initialize new service if using user OAuth token + if t.UseClientOAuth { + tokenStr, err := accessToken.ParseBearerToken() + if err != nil { + return nil, fmt.Errorf("error parsing access token: %w", err) + } + svc, err = t.ServiceCreator(tokenStr) + if err != nil { + return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) + } + } + + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do() + if err != nil { + return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) + } + return dataset, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization() bool { + return t.UseClientOAuth +} diff --git a/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset_test.go b/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset_test.go new file mode 100644 index 000000000000..f3711a4202d6 --- /dev/null +++ b/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gethealthcaredataset_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + getdataset "github.com/googleapis/genai-toolbox/internal/tools/healthcare/gethealthcaredataset" +) + +func TestParseFromYamlGetHealthcareDataset(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: get-healthcare-dataset + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": getdataset.Config{ + Name: "example_tool", + Kind: "get-healthcare-dataset", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/healthcare/listdicomstores/listdicomstores.go b/internal/tools/healthcare/listdicomstores/listdicomstores.go new file mode 100644 index 000000000000..aec8b21e33a5 --- /dev/null +++ b/internal/tools/healthcare/listdicomstores/listdicomstores.go @@ -0,0 +1,184 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package listdicomstores + +import ( + "context" + "fmt" + "strings" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + healthcareds "github.com/googleapis/genai-toolbox/internal/sources/healthcare" + "github.com/googleapis/genai-toolbox/internal/tools" + "google.golang.org/api/healthcare/v1" +) + +const kind string = "list-dicom-stores" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Project() string + Region() string + DatasetID() string + AllowedDICOMStores() map[string]struct{} + Service() *healthcare.Service + ServiceCreator() healthcareds.HealthcareServiceCreator + UseClientAuthorization() bool +} + +// validate compatible sources are still compatible +var _ compatibleSource = &healthcareds.Source{} + +var compatibleSources = [...]string{healthcareds.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + s, ok := rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + parameters := tools.Parameters{} + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters) + + // finish tool setup + t := Tool{ + Name: cfg.Name, + Kind: kind, + Parameters: parameters, + AuthRequired: cfg.AuthRequired, + Project: s.Project(), + Region: s.Region(), + Dataset: s.DatasetID(), + AllowedStores: s.AllowedDICOMStores(), + UseClientOAuth: s.UseClientAuthorization(), + ServiceCreator: s.ServiceCreator(), + Service: s.Service(), + manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + AuthRequired []string `yaml:"authRequired"` + UseClientOAuth bool `yaml:"useClientOAuth"` + Parameters tools.Parameters `yaml:"parameters"` + + Project, Region, Dataset string + AllowedStores map[string]struct{} + Service *healthcare.Service + ServiceCreator healthcareds.HealthcareServiceCreator + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { + svc := t.Service + var err error + + // Initialize new service if using user OAuth token + if t.UseClientOAuth { + tokenStr, err := accessToken.ParseBearerToken() + if err != nil { + return nil, fmt.Errorf("error parsing access token: %w", err) + } + svc, err = t.ServiceCreator(tokenStr) + if err != nil { + return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) + } + } + + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do() + if err != nil { + return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) + } + var filtered []*healthcare.DicomStore + for _, store := range stores.DicomStores { + if len(t.AllowedStores) == 0 { + filtered = append(filtered, store) + continue + } + if len(store.Name) == 0 { + continue + } + parts := strings.Split(store.Name, "/") + if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok { + filtered = append(filtered, store) + } + } + return filtered, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization() bool { + return t.UseClientOAuth +} diff --git a/internal/tools/healthcare/listdicomstores/listdicomstores_test.go b/internal/tools/healthcare/listdicomstores/listdicomstores_test.go new file mode 100644 index 000000000000..707a5a16390e --- /dev/null +++ b/internal/tools/healthcare/listdicomstores/listdicomstores_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package listdicomstores_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/healthcare/listdicomstores" +) + +func TestParseFromYamlHealthcareListDICOMStores(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: list-dicom-stores + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": listdicomstores.Config{ + Name: "example_tool", + Kind: "list-dicom-stores", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/healthcare/listfhirstores/listfhirstores.go b/internal/tools/healthcare/listfhirstores/listfhirstores.go new file mode 100644 index 000000000000..e2b76ff89258 --- /dev/null +++ b/internal/tools/healthcare/listfhirstores/listfhirstores.go @@ -0,0 +1,184 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package listfhirstores + +import ( + "context" + "fmt" + "strings" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + healthcareds "github.com/googleapis/genai-toolbox/internal/sources/healthcare" + "github.com/googleapis/genai-toolbox/internal/tools" + "google.golang.org/api/healthcare/v1" +) + +const kind string = "list-fhir-stores" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Project() string + Region() string + DatasetID() string + AllowedFHIRStores() map[string]struct{} + Service() *healthcare.Service + ServiceCreator() healthcareds.HealthcareServiceCreator + UseClientAuthorization() bool +} + +// validate compatible sources are still compatible +var _ compatibleSource = &healthcareds.Source{} + +var compatibleSources = [...]string{healthcareds.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + s, ok := rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + parameters := tools.Parameters{} + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters) + + // finish tool setup + t := Tool{ + Name: cfg.Name, + Kind: kind, + Parameters: parameters, + AuthRequired: cfg.AuthRequired, + Project: s.Project(), + Region: s.Region(), + Dataset: s.DatasetID(), + AllowedStores: s.AllowedFHIRStores(), + UseClientOAuth: s.UseClientAuthorization(), + ServiceCreator: s.ServiceCreator(), + Service: s.Service(), + manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + AuthRequired []string `yaml:"authRequired"` + UseClientOAuth bool `yaml:"useClientOAuth"` + Parameters tools.Parameters `yaml:"parameters"` + + Project, Region, Dataset string + AllowedStores map[string]struct{} + Service *healthcare.Service + ServiceCreator healthcareds.HealthcareServiceCreator + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { + svc := t.Service + var err error + + // Initialize new service if using user OAuth token + if t.UseClientOAuth { + tokenStr, err := accessToken.ParseBearerToken() + if err != nil { + return nil, fmt.Errorf("error parsing access token: %w", err) + } + svc, err = t.ServiceCreator(tokenStr) + if err != nil { + return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) + } + } + + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do() + if err != nil { + return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) + } + var filtered []*healthcare.FhirStore + for _, store := range stores.FhirStores { + if len(t.AllowedStores) == 0 { + filtered = append(filtered, store) + continue + } + if len(store.Name) == 0 { + continue + } + parts := strings.Split(store.Name, "/") + if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok { + filtered = append(filtered, store) + } + } + return filtered, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization() bool { + return t.UseClientOAuth +} diff --git a/internal/tools/healthcare/listfhirstores/listfhirstores_test.go b/internal/tools/healthcare/listfhirstores/listfhirstores_test.go new file mode 100644 index 000000000000..fbb9463780e2 --- /dev/null +++ b/internal/tools/healthcare/listfhirstores/listfhirstores_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package listfhirstores_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/healthcare/listfhirstores" +) + +func TestParseFromYamlHealthcareListFHIRStores(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: list-fhir-stores + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": listfhirstores.Config{ + Name: "example_tool", + Kind: "list-fhir-stores", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/tests/healthcare/healthcare_integration_test.go b/tests/healthcare/healthcare_integration_test.go new file mode 100644 index 000000000000..c973e2f217bb --- /dev/null +++ b/tests/healthcare/healthcare_integration_test.go @@ -0,0 +1,613 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// To run these tests, set the following environment variables: +// HEALTHCARE_PROJECT: Google Cloud project ID for healthcare resources. +// HEALTHCARE_REGION: Google Cloud region for healthcare resources. + +package healthcare + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "regexp" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/tests" + "golang.org/x/oauth2/google" + "google.golang.org/api/healthcare/v1" + "google.golang.org/api/option" +) + +var ( + healthcareSourceKind = "healthcare" + getDatasetToolKind = "get-healthcare-dataset" + listFHIRStoresToolKind = "list-fhir-stores" + listDICOMStoresToolKind = "list-dicom-stores" + healthcareProject = os.Getenv("HEALTHCARE_PROJECT") + healthcareRegion = os.Getenv("HEALTHCARE_REGION") + healthcareDataset = os.Getenv("HEALTHCARE_DATASET") +) + +func verifyHealthcareVars(t *testing.T) { + switch "" { + case healthcareProject: + t.Fatal("'HEALTHCARE_PROJECT' not set") + case healthcareRegion: + t.Fatal("'HEALTHCARE_REGION' not set") + case healthcareDataset: + t.Fatal("'HEALTHCARE_DATASET' not set") + } +} + +func TestHealthcareToolEndpoints(t *testing.T) { + verifyHealthcareVars(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + healthcareService, err := newHealthcareService(ctx) + if err != nil { + t.Fatalf("failed to create healthcare service: %v", err) + } + + fhirStoreID := "fhir-store-" + uuid.New().String() + dicomStoreID := "dicom-store-" + uuid.New().String() + + teardown := setupHealthcareResources(t, ctx, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID) + defer teardown(t) + + sourceConfig := map[string]any{ + "kind": healthcareSourceKind, + "project": healthcareProject, + "region": healthcareRegion, + "dataset": healthcareDataset, + } + + toolsFile := getToolsConfig(sourceConfig) + toolsFile = addClientAuthSourceConfig(t, toolsFile, healthcareDataset) + + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: %s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + datasetWant := fmt.Sprintf(`"name":"projects/%s/locations/%s/datasets/%s"`, healthcareProject, healthcareRegion, healthcareDataset) + fhirStoreWant := fmt.Sprintf(`"name":"projects/%s/locations/%s/datasets/%s/fhirStores/%s"`, healthcareProject, healthcareRegion, healthcareDataset, fhirStoreID) + dicomStoreWant := fmt.Sprintf(`"name":"projects/%s/locations/%s/datasets/%s/dicomStores/%s"`, healthcareProject, healthcareRegion, healthcareDataset, dicomStoreID) + + runGetDatasetToolInvokeTest(t, datasetWant) + runListFHIRStoresToolInvokeTest(t, fhirStoreWant) + runListDICOMStoresToolInvokeTest(t, dicomStoreWant) +} + +func TestHealthcareToolWithStoreRestriction(t *testing.T) { + verifyHealthcareVars(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + healthcareService, err := newHealthcareService(ctx) + if err != nil { + t.Fatalf("failed to create healthcare service: %v", err) + } + + // Create stores + allowedFHIRStoreID := "fhir-store-allowed-" + uuid.New().String() + allowedDICOMStoreID := "dicom-store-allowed-" + uuid.New().String() + disallowedFHIRStoreID := "fhir-store-disallowed-" + uuid.New().String() + disallowedDICOMStoreID := "dicom-store-disallowed-" + uuid.New().String() + + teardownAllowedStores := setupHealthcareResources(t, ctx, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID) + defer teardownAllowedStores(t) + teardownDisallowedStores := setupHealthcareResources(t, ctx, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID) + defer teardownDisallowedStores(t) + + // Configure source with dataset restriction. + sourceConfig := map[string]any{ + "kind": healthcareSourceKind, + "project": healthcareProject, + "region": healthcareRegion, + "dataset": healthcareDataset, + "allowedFhirStores": []string{ + allowedFHIRStoreID, + }, + "allowedDicomStores": []string{ + allowedDICOMStoreID, + }, + } + + // Configure tool + toolsConfig := map[string]any{ + "list-fhir-stores-restricted": map[string]any{ + "kind": "list-fhir-stores", + "source": "my-instance", + "description": "Tool to list fhir stores", + }, + "list-dicom-stores-restricted": map[string]any{ + "kind": "list-dicom-stores", + "source": "my-instance", + "description": "Tool to list dicom stores", + }, + } + + // Create config file + config := map[string]any{ + "sources": map[string]any{ + "my-instance": sourceConfig, + }, + "tools": toolsConfig, + } + + // Start server + cmd, cleanup, err := tests.StartCmd(ctx, config) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Run tests + runListFHIRStoresWithRestriction(t, allowedFHIRStoreID, disallowedFHIRStoreID) + runListDICOMStoresWithRestriction(t, allowedDICOMStoreID, disallowedDICOMStoreID) +} + +func newHealthcareService(ctx context.Context) (*healthcare.Service, error) { + creds, err := google.FindDefaultCredentials(ctx, healthcare.CloudHealthcareScope) + if err != nil { + return nil, fmt.Errorf("failed to find default credentials: %w", err) + } + + healthcareService, err := healthcare.NewService(ctx, option.WithCredentials(creds)) + if err != nil { + return nil, fmt.Errorf("failed to create healthcare service: %w", err) + } + return healthcareService, nil +} + +func setupHealthcareResources(t *testing.T, ctx context.Context, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) func(*testing.T) { + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", healthcareProject, healthcareRegion, datasetID) + var err error + + // Create FHIR store + fhirStore := &healthcare.FhirStore{Version: "R4"} + if fhirStore, err = service.Projects.Locations.Datasets.FhirStores.Create(datasetName, fhirStore).FhirStoreId(fhirStoreID).Do(); err != nil { + t.Fatalf("failed to create fhir store: %v", err) + } + + // Create DICOM store + dicomStore := &healthcare.DicomStore{} + if dicomStore, err = service.Projects.Locations.Datasets.DicomStores.Create(datasetName, dicomStore).DicomStoreId(dicomStoreID).Do(); err != nil { + t.Fatalf("failed to create dicom store: %v", err) + } + + return func(t *testing.T) { + if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil { + t.Logf("failed to delete fhir store: %v", err) + } + if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil { + t.Logf("failed to delete dicom store: %v", err) + } + } +} + +func getToolsConfig(sourceConfig map[string]any) map[string]any { + config := map[string]any{ + "sources": map[string]any{ + "my-instance": sourceConfig, + }, + "tools": map[string]any{ + "my-get-dataset-tool": map[string]any{ + "kind": getDatasetToolKind, + "source": "my-instance", + "description": "Tool to get a healthcare dataset", + }, + "my-list-fhir-stores-tool": map[string]any{ + "kind": listFHIRStoresToolKind, + "source": "my-instance", + "description": "Tool to list FHIR stores", + }, + "my-list-dicom-stores-tool": map[string]any{ + "kind": listDICOMStoresToolKind, + "source": "my-instance", + "description": "Tool to list DICOM stores", + }, + "my-auth-get-dataset-tool": map[string]any{ + "kind": getDatasetToolKind, + "source": "my-instance", + "description": "Tool to get a healthcare dataset with auth", + "authRequired": []string{ + "my-google-auth", + }, + }, + "my-auth-list-fhir-stores-tool": map[string]any{ + "kind": listFHIRStoresToolKind, + "source": "my-instance", + "description": "Tool to list FHIR stores with auth", + "authRequired": []string{ + "my-google-auth", + }, + }, + "my-auth-list-dicom-stores-tool": map[string]any{ + "kind": listDICOMStoresToolKind, + "source": "my-instance", + "description": "Tool to list DICOM stores with auth", + "authRequired": []string{ + "my-google-auth", + }, + }, + }, + "authServices": map[string]any{ + "my-google-auth": map[string]any{ + "kind": "google", + "clientId": tests.ClientId, + }, + }, + } + return config +} + +func addClientAuthSourceConfig(t *testing.T, config map[string]any, datasetID string) map[string]any { + sources, ok := config["sources"].(map[string]any) + if !ok { + t.Fatalf("unable to get sources from config") + } + sources["my-client-auth-source"] = map[string]any{ + "kind": healthcareSourceKind, + "project": healthcareProject, + "region": healthcareRegion, + "dataset": datasetID, + "useClientOAuth": true, + } + config["sources"] = sources + return config +} + +func runGetDatasetToolInvokeTest(t *testing.T, want string) { + idToken, err := tests.GetGoogleIdToken(tests.ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + accessToken, err := sources.GetIAMAccessToken(t.Context()) + if err != nil { + t.Fatalf("error getting access token from ADC: %s", err) + } + accessToken = "Bearer " + accessToken + + invokeTcs := []struct { + name string + api string + requestHeader map[string]string + requestBody io.Reader + want string + isErr bool + }{ + { + name: "invoke my-get-dataset-tool", + api: "http://127.0.0.1:5000/api/tool/my-get-dataset-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-get-dataset-tool with auth", + api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-get-dataset-tool with client auth", + api: "http://127.0.0.1:5000/api/tool/my-get-dataset-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-get-dataset-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-auth-get-dataset-tool with invalid auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-tool/invoke", + requestHeader: map[string]string{"Authorization": "invalid-token"}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + runTest(t, tc.api, tc.requestHeader, tc.requestBody, tc.want, tc.isErr) + }) + } +} + +func runListFHIRStoresToolInvokeTest(t *testing.T, want string) { + idToken, err := tests.GetGoogleIdToken(tests.ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + accessToken, err := sources.GetIAMAccessToken(t.Context()) + if err != nil { + t.Fatalf("error getting access token from ADC: %s", err) + } + accessToken = "Bearer " + accessToken + + invokeTcs := []struct { + name string + api string + requestHeader map[string]string + requestBody io.Reader + want string + isErr bool + }{ + { + name: "invoke my-list-fhir-stores-tool", + api: "http://127.0.0.1:5000/api/tool/my-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-fhir-stores-tool with auth", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-fhir-stores-tool with client auth", + api: "http://127.0.0.1:5000/api/tool/my-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-fhir-stores-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-auth-list-fhir-stores-tool with invalid auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{"Authorization": "invalid-token"}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + runTest(t, tc.api, tc.requestHeader, tc.requestBody, tc.want, tc.isErr) + }) + } +} + +func runListDICOMStoresToolInvokeTest(t *testing.T, want string) { + idToken, err := tests.GetGoogleIdToken(tests.ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + accessToken, err := sources.GetIAMAccessToken(t.Context()) + if err != nil { + t.Fatalf("error getting access token from ADC: %s", err) + } + accessToken = "Bearer " + accessToken + + invokeTcs := []struct { + name string + api string + requestHeader map[string]string + requestBody io.Reader + want string + isErr bool + }{ + { + name: "invoke my-list-dicom-stores-tool", + api: "http://127.0.0.1:5000/api/tool/my-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-dicom-stores-tool with auth", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-dicom-stores-tool with client auth", + api: "http://127.0.0.1:5000/api/tool/my-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-dicom-stores-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-auth-list-dicom-stores-tool with invalid auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{"Authorization": "invalid-token"}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + runTest(t, tc.api, tc.requestHeader, tc.requestBody, tc.want, tc.isErr) + }) + } +} + +func runTest(t *testing.T, api string, requestHeader map[string]string, requestBody io.Reader, want string, isErr bool) { + req, err := http.NewRequest(http.MethodPost, api, requestBody) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + for k, v := range requestHeader { + req.Header.Add(k, v) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if isErr { + return + } + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + if !strings.Contains(got, want) { + t.Fatalf("expected %q to contain %q, but it did not", got, want) + } +} + +func runListFHIRStoresWithRestriction(t *testing.T, allowedFHIRStore, disallowedFHIRStore string) { + api := "http://127.0.0.1:5000/api/tool/list-fhir-stores-restricted/invoke" + req, err := http.NewRequest(http.MethodPost, api, bytes.NewBuffer([]byte(`{}`))) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + if !strings.Contains(got, allowedFHIRStore) { + t.Fatalf("expected %q to contain %q, but it did not", got, allowedFHIRStore) + } + if strings.Contains(got, disallowedFHIRStore) { + t.Fatalf("expected %q to NOT contain %q, but it did", got, disallowedFHIRStore) + } +} + +func runListDICOMStoresWithRestriction(t *testing.T, allowedDICOMStore, disallowedDICOMStore string) { + api := "http://127.0.0.1:5000/api/tool/list-dicom-stores-restricted/invoke" + req, err := http.NewRequest(http.MethodPost, api, bytes.NewBuffer([]byte(`{}`))) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + if !strings.Contains(got, allowedDICOMStore) { + t.Fatalf("expected %q to contain %q, but it did not", got, allowedDICOMStore) + } + if strings.Contains(got, disallowedDICOMStore) { + t.Fatalf("expected %q to NOT contain %q, but it did", got, disallowedDICOMStore) + } +}