diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index 6f9e6f48ae5a..e99b147c8130 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -214,6 +214,28 @@ steps: dataform \ dataform + - id: "healthcare" + name: golang:1 + waitFor: ["compile-test-binary"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "HEALTHCARE_PROJECT=$PROJECT_ID" + - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" + - "HEALTHCARE_REGION=$_REGION" + - "HEALTHCARE_DATASET=$_HEALTHCARE_DATASET" + secretEnv: ["CLIENT_ID"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + .ci/test_with_coverage.sh \ + "Healthcare" \ + healthcare \ + healthcare + - id: "postgres" name: golang:1 waitFor: ["compile-test-binary"] @@ -858,6 +880,7 @@ substitutions: _ALLOYDB_AI_NL_CLUSTER: "alloydb-ai-nl-testing" _ALLOYDB_AI_NL_INSTANCE: "alloydb-ai-nl-testing-instance" _BIGTABLE_INSTANCE: "bigtable-testing-instance" + _HEALTHCARE_DATASET: "test-dataset" _POSTGRES_HOST: 127.0.0.1 _POSTGRES_PORT: "5432" _SPANNER_INSTANCE: "spanner-testing" diff --git a/cmd/root.go b/cmd/root.go index 83162d042814..d4ce5c9c6624 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -96,6 +96,9 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequerycollection" _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreupdatedocument" _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules" + _ "github.com/googleapis/genai-toolbox/internal/tools/healthcare/gethealthcaredataset" + _ "github.com/googleapis/genai-toolbox/internal/tools/healthcare/listdicomstores" + _ "github.com/googleapis/genai-toolbox/internal/tools/healthcare/listfhirstores" _ "github.com/googleapis/genai-toolbox/internal/tools/http" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile" @@ -187,6 +190,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/sources/dgraph" _ "github.com/googleapis/genai-toolbox/internal/sources/firebird" _ "github.com/googleapis/genai-toolbox/internal/sources/firestore" + _ "github.com/googleapis/genai-toolbox/internal/sources/healthcare" _ "github.com/googleapis/genai-toolbox/internal/sources/http" _ "github.com/googleapis/genai-toolbox/internal/sources/looker" _ "github.com/googleapis/genai-toolbox/internal/sources/mongodb" diff --git a/docs/en/resources/sources/healthcare.md b/docs/en/resources/sources/healthcare.md index feb8491595e9..755414504016 100644 --- a/docs/en/resources/sources/healthcare.md +++ b/docs/en/resources/sources/healthcare.md @@ -34,6 +34,17 @@ If you are new to the Healthcare API, you can try to [healthcare-quickstart-curl]: https://cloud.google.com/healthcare-api/docs/store-healthcare-data-rest +## Available Tools + +- [`get-healthcare-dataset`](../tools/healthcare/get-healthcare-dataset.md) + Retrieves a dataset’s details. + +- [`list-fhir-stores`](../tools/healthcare/list-fhir-stores.md) + Lists the available FHIR stores in the healthcare dataset. + +- [`list-dicom-stores`](../tools/healthcare/list-dicom-stores.md) + Lists the available DICOM stores in the healthcare dataset. + ## Requirements ### IAM Permissions diff --git a/docs/en/resources/tools/healthcare/_index.md b/docs/en/resources/tools/healthcare/_index.md new file mode 100644 index 000000000000..8781823bbeb1 --- /dev/null +++ b/docs/en/resources/tools/healthcare/_index.md @@ -0,0 +1,8 @@ +--- +title: "Cloud Healthcare API" +linkTitle: "Healthcare" +type: docs +weight: 1 +description: > + Tools that work with Healthcare Sources. +--- \ No newline at end of file diff --git a/docs/en/resources/tools/healthcare/get-healthcare-dataset.md b/docs/en/resources/tools/healthcare/get-healthcare-dataset.md new file mode 100644 index 000000000000..0d2c78165496 --- /dev/null +++ b/docs/en/resources/tools/healthcare/get-healthcare-dataset.md @@ -0,0 +1,38 @@ +--- +title: "get-healthcare-dataset" +linkTitle: "get-healthcare-dataset" +type: docs +weight: 1 +description: > + A "get-healthcare-dataset" tool retrieves metadata for the Healthcare dataset in the source. +aliases: +- /resources/tools/healthcare-get-healthcare-dataset +--- + +## About + +A `get-healthcare-dataset` tool retrieves metadata for a Healthcare dataset. +It's compatible with the following sources: + +- [healthcare](../../sources/healthcare.md) + +`get-healthcare-dataset` returns the metadata of the healthcare dataset +configured in the source. It takes no extra parameters. + +## Example + +```yaml +tools: + get_healthcare_dataset: + kind: get-healthcare-dataset + source: my-healthcare-source + description: Use this tool to get healthcare dataset metadata. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:------------------------------------------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "get-healthcare-dataset". | +| source | string | true | Name of the healthcare source. | +| description | string | true | Description of the tool that is passed to the LLM. | diff --git a/docs/en/resources/tools/healthcare/list-dicom-stores.md b/docs/en/resources/tools/healthcare/list-dicom-stores.md new file mode 100644 index 000000000000..e32703494097 --- /dev/null +++ b/docs/en/resources/tools/healthcare/list-dicom-stores.md @@ -0,0 +1,38 @@ +--- +title: "list-dicom-stores" +linkTitle: "list-dicom-stores" +type: docs +weight: 1 +description: > + A "list-dicom-stores" lists the available DICOM stores in the healthcare dataset. +aliases: +- /resources/tools/healthcare-list-dicom-stores +--- + +## About + +A `list-dicom-stores` lists the available DICOM stores in the healthcare dataset. +It's compatible with the following sources: + +- [healthcare](../../sources/healthcare.md) + +`list-dicom-stores` returns the details of the available DICOM stores in the +dataset of the healthcare source. It takes no extra parameters. + +## Example + +```yaml +tools: + list_dicom_stores: + kind: list-dicom-stores + source: my-healthcare-source + description: Use this tool to list DICOM stores in the healthcare dataset. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:------------------------------------------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "list-dicom-stores". | +| source | string | true | Name of the healthcare source. | +| description | string | true | Description of the tool that is passed to the LLM. | diff --git a/docs/en/resources/tools/healthcare/list-fhir-stores.md b/docs/en/resources/tools/healthcare/list-fhir-stores.md new file mode 100644 index 000000000000..e53368ad92fe --- /dev/null +++ b/docs/en/resources/tools/healthcare/list-fhir-stores.md @@ -0,0 +1,38 @@ +--- +title: "list-fhir-stores" +linkTitle: "list-fhir-stores" +type: docs +weight: 1 +description: > + A "list-fhir-stores" lists the available FHIR stores in the healthcare dataset. +aliases: +- /resources/tools/healthcare-list-fhir-stores +--- + +## About + +A `list-fhir-stores` lists the available FHIR stores in the healthcare dataset. +It's compatible with the following sources: + +- [healthcare](../../sources/healthcare.md) + +`list-fhir-stores` returns the details of the available FHIR stores in the +dataset of the healthcare source. It takes no extra parameters. + +## Example + +```yaml +tools: + list_fhir_stores: + kind: list-fhir-stores + source: my-healthcare-source + description: Use this tool to list FHIR stores in the healthcare dataset. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:------------------------------------------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "list-fhir-stores". | +| source | string | true | Name of the healthcare source. | +| description | string | true | Description of the tool that is passed to the LLM. | diff --git a/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset.go b/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset.go new file mode 100644 index 000000000000..e25cb8a672be --- /dev/null +++ b/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset.go @@ -0,0 +1,166 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gethealthcaredataset + +import ( + "context" + "fmt" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + healthcareds "github.com/googleapis/genai-toolbox/internal/sources/healthcare" + "github.com/googleapis/genai-toolbox/internal/tools" + "google.golang.org/api/healthcare/v1" +) + +const kind string = "get-healthcare-dataset" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Project() string + Region() string + DatasetID() string + Service() *healthcare.Service + ServiceCreator() healthcareds.HealthcareServiceCreator + UseClientAuthorization() bool +} + +// validate compatible sources are still compatible +var _ compatibleSource = &healthcareds.Source{} + +var compatibleSources = [...]string{healthcareds.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + s, ok := rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + parameters := tools.Parameters{} + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters) + + // finish tool setup + t := Tool{ + Name: cfg.Name, + Kind: kind, + Parameters: parameters, + AuthRequired: cfg.AuthRequired, + Project: s.Project(), + Region: s.Region(), + Dataset: s.DatasetID(), + UseClientOAuth: s.UseClientAuthorization(), + ServiceCreator: s.ServiceCreator(), + Service: s.Service(), + manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + AuthRequired []string `yaml:"authRequired"` + UseClientOAuth bool `yaml:"useClientOAuth"` + Parameters tools.Parameters `yaml:"parameters"` + + Project, Region, Dataset string + Service *healthcare.Service + ServiceCreator healthcareds.HealthcareServiceCreator + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { + svc := t.Service + var err error + + // Initialize new service if using user OAuth token + if t.UseClientOAuth { + tokenStr, err := accessToken.ParseBearerToken() + if err != nil { + return nil, fmt.Errorf("error parsing access token: %w", err) + } + svc, err = t.ServiceCreator(tokenStr) + if err != nil { + return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) + } + } + + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do() + if err != nil { + return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) + } + return dataset, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization() bool { + return t.UseClientOAuth +} diff --git a/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset_test.go b/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset_test.go new file mode 100644 index 000000000000..f3711a4202d6 --- /dev/null +++ b/internal/tools/healthcare/gethealthcaredataset/gethealthcaredataset_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gethealthcaredataset_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + getdataset "github.com/googleapis/genai-toolbox/internal/tools/healthcare/gethealthcaredataset" +) + +func TestParseFromYamlGetHealthcareDataset(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: get-healthcare-dataset + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": getdataset.Config{ + Name: "example_tool", + Kind: "get-healthcare-dataset", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/healthcare/listdicomstores/listdicomstores.go b/internal/tools/healthcare/listdicomstores/listdicomstores.go new file mode 100644 index 000000000000..aec8b21e33a5 --- /dev/null +++ b/internal/tools/healthcare/listdicomstores/listdicomstores.go @@ -0,0 +1,184 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package listdicomstores + +import ( + "context" + "fmt" + "strings" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + healthcareds "github.com/googleapis/genai-toolbox/internal/sources/healthcare" + "github.com/googleapis/genai-toolbox/internal/tools" + "google.golang.org/api/healthcare/v1" +) + +const kind string = "list-dicom-stores" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Project() string + Region() string + DatasetID() string + AllowedDICOMStores() map[string]struct{} + Service() *healthcare.Service + ServiceCreator() healthcareds.HealthcareServiceCreator + UseClientAuthorization() bool +} + +// validate compatible sources are still compatible +var _ compatibleSource = &healthcareds.Source{} + +var compatibleSources = [...]string{healthcareds.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + s, ok := rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + parameters := tools.Parameters{} + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters) + + // finish tool setup + t := Tool{ + Name: cfg.Name, + Kind: kind, + Parameters: parameters, + AuthRequired: cfg.AuthRequired, + Project: s.Project(), + Region: s.Region(), + Dataset: s.DatasetID(), + AllowedStores: s.AllowedDICOMStores(), + UseClientOAuth: s.UseClientAuthorization(), + ServiceCreator: s.ServiceCreator(), + Service: s.Service(), + manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + AuthRequired []string `yaml:"authRequired"` + UseClientOAuth bool `yaml:"useClientOAuth"` + Parameters tools.Parameters `yaml:"parameters"` + + Project, Region, Dataset string + AllowedStores map[string]struct{} + Service *healthcare.Service + ServiceCreator healthcareds.HealthcareServiceCreator + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { + svc := t.Service + var err error + + // Initialize new service if using user OAuth token + if t.UseClientOAuth { + tokenStr, err := accessToken.ParseBearerToken() + if err != nil { + return nil, fmt.Errorf("error parsing access token: %w", err) + } + svc, err = t.ServiceCreator(tokenStr) + if err != nil { + return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) + } + } + + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do() + if err != nil { + return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) + } + var filtered []*healthcare.DicomStore + for _, store := range stores.DicomStores { + if len(t.AllowedStores) == 0 { + filtered = append(filtered, store) + continue + } + if len(store.Name) == 0 { + continue + } + parts := strings.Split(store.Name, "/") + if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok { + filtered = append(filtered, store) + } + } + return filtered, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization() bool { + return t.UseClientOAuth +} diff --git a/internal/tools/healthcare/listdicomstores/listdicomstores_test.go b/internal/tools/healthcare/listdicomstores/listdicomstores_test.go new file mode 100644 index 000000000000..707a5a16390e --- /dev/null +++ b/internal/tools/healthcare/listdicomstores/listdicomstores_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package listdicomstores_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/healthcare/listdicomstores" +) + +func TestParseFromYamlHealthcareListDICOMStores(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: list-dicom-stores + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": listdicomstores.Config{ + Name: "example_tool", + Kind: "list-dicom-stores", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/healthcare/listfhirstores/listfhirstores.go b/internal/tools/healthcare/listfhirstores/listfhirstores.go new file mode 100644 index 000000000000..e2b76ff89258 --- /dev/null +++ b/internal/tools/healthcare/listfhirstores/listfhirstores.go @@ -0,0 +1,184 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package listfhirstores + +import ( + "context" + "fmt" + "strings" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + healthcareds "github.com/googleapis/genai-toolbox/internal/sources/healthcare" + "github.com/googleapis/genai-toolbox/internal/tools" + "google.golang.org/api/healthcare/v1" +) + +const kind string = "list-fhir-stores" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Project() string + Region() string + DatasetID() string + AllowedFHIRStores() map[string]struct{} + Service() *healthcare.Service + ServiceCreator() healthcareds.HealthcareServiceCreator + UseClientAuthorization() bool +} + +// validate compatible sources are still compatible +var _ compatibleSource = &healthcareds.Source{} + +var compatibleSources = [...]string{healthcareds.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + // verify the source is compatible + s, ok := rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + parameters := tools.Parameters{} + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters) + + // finish tool setup + t := Tool{ + Name: cfg.Name, + Kind: kind, + Parameters: parameters, + AuthRequired: cfg.AuthRequired, + Project: s.Project(), + Region: s.Region(), + Dataset: s.DatasetID(), + AllowedStores: s.AllowedFHIRStores(), + UseClientOAuth: s.UseClientAuthorization(), + ServiceCreator: s.ServiceCreator(), + Service: s.Service(), + manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + AuthRequired []string `yaml:"authRequired"` + UseClientOAuth bool `yaml:"useClientOAuth"` + Parameters tools.Parameters `yaml:"parameters"` + + Project, Region, Dataset string + AllowedStores map[string]struct{} + Service *healthcare.Service + ServiceCreator healthcareds.HealthcareServiceCreator + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { + svc := t.Service + var err error + + // Initialize new service if using user OAuth token + if t.UseClientOAuth { + tokenStr, err := accessToken.ParseBearerToken() + if err != nil { + return nil, fmt.Errorf("error parsing access token: %w", err) + } + svc, err = t.ServiceCreator(tokenStr) + if err != nil { + return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) + } + } + + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do() + if err != nil { + return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) + } + var filtered []*healthcare.FhirStore + for _, store := range stores.FhirStores { + if len(t.AllowedStores) == 0 { + filtered = append(filtered, store) + continue + } + if len(store.Name) == 0 { + continue + } + parts := strings.Split(store.Name, "/") + if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok { + filtered = append(filtered, store) + } + } + return filtered, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization() bool { + return t.UseClientOAuth +} diff --git a/internal/tools/healthcare/listfhirstores/listfhirstores_test.go b/internal/tools/healthcare/listfhirstores/listfhirstores_test.go new file mode 100644 index 000000000000..fbb9463780e2 --- /dev/null +++ b/internal/tools/healthcare/listfhirstores/listfhirstores_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package listfhirstores_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/healthcare/listfhirstores" +) + +func TestParseFromYamlHealthcareListFHIRStores(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: list-fhir-stores + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": listfhirstores.Config{ + Name: "example_tool", + Kind: "list-fhir-stores", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/tests/healthcare/healthcare_integration_test.go b/tests/healthcare/healthcare_integration_test.go new file mode 100644 index 000000000000..44b35cfd903c --- /dev/null +++ b/tests/healthcare/healthcare_integration_test.go @@ -0,0 +1,652 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// To run these tests, set the following environment variables: +// HEALTHCARE_PROJECT: Google Cloud project ID for healthcare resources. +// HEALTHCARE_REGION: Google Cloud region for healthcare resources. + +package healthcare + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "regexp" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/tests" + "golang.org/x/oauth2/google" + "google.golang.org/api/healthcare/v1" + "google.golang.org/api/option" +) + +var ( + healthcareSourceKind = "healthcare" + getDatasetToolKind = "get-healthcare-dataset" + listFHIRStoresToolKind = "list-fhir-stores" + listDICOMStoresToolKind = "list-dicom-stores" + healthcareProject = os.Getenv("HEALTHCARE_PROJECT") + healthcareRegion = os.Getenv("HEALTHCARE_REGION") + healthcareDataset = os.Getenv("HEALTHCARE_DATASET") +) + +func getHealthcareVars(t *testing.T) map[string]any { + switch "" { + case healthcareProject: + t.Fatal("'HEALTHCARE_PROJECT' not set") + case healthcareRegion: + t.Fatal("'HEALTHCARE_REGION' not set") + case healthcareDataset: + t.Fatal("'HEALTHCARE_DATASET' not set") + } + return map[string]any{ + "kind": healthcareSourceKind, + "project": healthcareProject, + "region": healthcareRegion, + "dataset": healthcareDataset, + } +} + +func TestHealthcareToolEndpoints(t *testing.T) { + sourceConfig := getHealthcareVars(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + healthcareService, err := newHealthcareService(ctx) + if err != nil { + t.Fatalf("failed to create healthcare service: %v", err) + } + + fhirStoreID := "fhir-store-" + uuid.New().String() + dicomStoreID := "dicom-store-" + uuid.New().String() + + teardown := setupHealthcareResources(t, ctx, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID) + defer teardown(t) + + toolsFile := getToolsConfig(sourceConfig) + toolsFile = addClientAuthSourceConfig(t, toolsFile) + + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: %s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + datasetWant := fmt.Sprintf(`"name":"projects/%s/locations/%s/datasets/%s"`, healthcareProject, healthcareRegion, healthcareDataset) + fhirStoreWant := fmt.Sprintf(`"name":"projects/%s/locations/%s/datasets/%s/fhirStores/%s"`, healthcareProject, healthcareRegion, healthcareDataset, fhirStoreID) + dicomStoreWant := fmt.Sprintf(`"name":"projects/%s/locations/%s/datasets/%s/dicomStores/%s"`, healthcareProject, healthcareRegion, healthcareDataset, dicomStoreID) + + runGetDatasetToolInvokeTest(t, datasetWant) + runListFHIRStoresToolInvokeTest(t, fhirStoreWant) + runListDICOMStoresToolInvokeTest(t, dicomStoreWant) +} + +func TestHealthcareToolWithStoreRestriction(t *testing.T) { + sourceConfig := getHealthcareVars(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + healthcareService, err := newHealthcareService(ctx) + if err != nil { + t.Fatalf("failed to create healthcare service: %v", err) + } + + // Create stores + allowedFHIRStoreID := "fhir-store-allowed-" + uuid.New().String() + allowedDICOMStoreID := "dicom-store-allowed-" + uuid.New().String() + disallowedFHIRStoreID := "fhir-store-disallowed-" + uuid.New().String() + disallowedDICOMStoreID := "dicom-store-disallowed-" + uuid.New().String() + + teardownAllowedStores := setupHealthcareResources(t, ctx, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID) + defer teardownAllowedStores(t) + teardownDisallowedStores := setupHealthcareResources(t, ctx, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID) + defer teardownDisallowedStores(t) + + // Configure source with dataset restriction. + sourceConfig["allowedFhirStores"] = []string{allowedFHIRStoreID} + sourceConfig["allowedDicomStores"] = []string{allowedDICOMStoreID} + + // Configure tool + toolsConfig := map[string]any{ + "list-fhir-stores-restricted": map[string]any{ + "kind": "list-fhir-stores", + "source": "my-instance", + "description": "Tool to list fhir stores", + }, + "list-dicom-stores-restricted": map[string]any{ + "kind": "list-dicom-stores", + "source": "my-instance", + "description": "Tool to list dicom stores", + }, + } + + // Create config file + config := map[string]any{ + "sources": map[string]any{ + "my-instance": sourceConfig, + }, + "tools": toolsConfig, + } + + // Start server + cmd, cleanup, err := tests.StartCmd(ctx, config) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Run tests + runListFHIRStoresWithRestriction(t, allowedFHIRStoreID, disallowedFHIRStoreID) + runListDICOMStoresWithRestriction(t, allowedDICOMStoreID, disallowedDICOMStoreID) +} + +func newHealthcareService(ctx context.Context) (*healthcare.Service, error) { + creds, err := google.FindDefaultCredentials(ctx, healthcare.CloudHealthcareScope) + if err != nil { + return nil, fmt.Errorf("failed to find default credentials: %w", err) + } + + healthcareService, err := healthcare.NewService(ctx, option.WithCredentials(creds)) + if err != nil { + return nil, fmt.Errorf("failed to create healthcare service: %w", err) + } + return healthcareService, nil +} + +func setupHealthcareResources(t *testing.T, ctx context.Context, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) func(*testing.T) { + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", healthcareProject, healthcareRegion, datasetID) + var err error + + // Create FHIR store + fhirStore := &healthcare.FhirStore{Version: "R4"} + if fhirStore, err = service.Projects.Locations.Datasets.FhirStores.Create(datasetName, fhirStore).FhirStoreId(fhirStoreID).Do(); err != nil { + t.Fatalf("failed to create fhir store: %v", err) + } + + // Create DICOM store + dicomStore := &healthcare.DicomStore{} + if dicomStore, err = service.Projects.Locations.Datasets.DicomStores.Create(datasetName, dicomStore).DicomStoreId(dicomStoreID).Do(); err != nil { + t.Fatalf("failed to create dicom store: %v", err) + } + + return func(t *testing.T) { + if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil { + t.Logf("failed to delete fhir store: %v", err) + } + if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil { + t.Logf("failed to delete dicom store: %v", err) + } + } +} + +func getToolsConfig(sourceConfig map[string]any) map[string]any { + config := map[string]any{ + "sources": map[string]any{ + "my-instance": sourceConfig, + }, + "tools": map[string]any{ + "my-get-dataset-tool": map[string]any{ + "kind": getDatasetToolKind, + "source": "my-instance", + "description": "Tool to get a healthcare dataset", + }, + "my-list-fhir-stores-tool": map[string]any{ + "kind": listFHIRStoresToolKind, + "source": "my-instance", + "description": "Tool to list FHIR stores", + }, + "my-list-dicom-stores-tool": map[string]any{ + "kind": listDICOMStoresToolKind, + "source": "my-instance", + "description": "Tool to list DICOM stores", + }, + "my-client-auth-get-dataset-tool": map[string]any{ + "kind": getDatasetToolKind, + "source": "my-client-auth-source", + "description": "Tool to get a healthcare dataset", + }, + "my-client-auth-list-fhir-stores-tool": map[string]any{ + "kind": listFHIRStoresToolKind, + "source": "my-client-auth-source", + "description": "Tool to list FHIR stores", + }, + "my-client-auth-list-dicom-stores-tool": map[string]any{ + "kind": listDICOMStoresToolKind, + "source": "my-client-auth-source", + "description": "Tool to list DICOM stores", + }, + "my-auth-get-dataset-tool": map[string]any{ + "kind": getDatasetToolKind, + "source": "my-instance", + "description": "Tool to get a healthcare dataset with auth", + "authRequired": []string{ + "my-google-auth", + }, + }, + "my-auth-list-fhir-stores-tool": map[string]any{ + "kind": listFHIRStoresToolKind, + "source": "my-instance", + "description": "Tool to list FHIR stores with auth", + "authRequired": []string{ + "my-google-auth", + }, + }, + "my-auth-list-dicom-stores-tool": map[string]any{ + "kind": listDICOMStoresToolKind, + "source": "my-instance", + "description": "Tool to list DICOM stores with auth", + "authRequired": []string{ + "my-google-auth", + }, + }, + }, + "authServices": map[string]any{ + "my-google-auth": map[string]any{ + "kind": "google", + "clientId": tests.ClientId, + }, + }, + } + return config +} + +func addClientAuthSourceConfig(t *testing.T, config map[string]any) map[string]any { + sources, ok := config["sources"].(map[string]any) + if !ok { + t.Fatalf("unable to get sources from config") + } + sources["my-client-auth-source"] = map[string]any{ + "kind": healthcareSourceKind, + "project": healthcareProject, + "region": healthcareRegion, + "dataset": healthcareDataset, + "useClientOAuth": true, + } + config["sources"] = sources + return config +} + +func runGetDatasetToolInvokeTest(t *testing.T, want string) { + idToken, err := tests.GetGoogleIdToken(tests.ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + accessToken, err := sources.GetIAMAccessToken(t.Context()) + if err != nil { + t.Fatalf("error getting access token from ADC: %s", err) + } + accessToken = "Bearer " + accessToken + + invokeTcs := []struct { + name string + api string + requestHeader map[string]string + requestBody io.Reader + want string + isErr bool + }{ + { + name: "invoke my-get-dataset-tool", + api: "http://127.0.0.1:5000/api/tool/my-get-dataset-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-get-dataset-tool with auth", + api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-get-dataset-tool with client auth", + api: "http://127.0.0.1:5000/api/tool/my-get-dataset-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-get-dataset-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-auth-get-dataset-tool with invalid auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-tool/invoke", + requestHeader: map[string]string{"Authorization": "invalid-token"}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-client-auth-get-dataset-tool with client auth", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-get-dataset-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-client-auth-get-dataset-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-get-dataset-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-client-auth-get-dataset-tool with invalid auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-get-dataset-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) + if tc.isErr { + if status == http.StatusOK { + t.Errorf("expected error but got success") + } + return + } + if status != http.StatusOK { + t.Errorf("expected status OK but got %d", status) + } else if !strings.Contains(got, tc.want) { + t.Errorf("expected result to contain %q but got %q", tc.want, got) + } + }) + } +} + +func runListFHIRStoresToolInvokeTest(t *testing.T, want string) { + idToken, err := tests.GetGoogleIdToken(tests.ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + accessToken, err := sources.GetIAMAccessToken(t.Context()) + if err != nil { + t.Fatalf("error getting access token from ADC: %s", err) + } + accessToken = "Bearer " + accessToken + + invokeTcs := []struct { + name string + api string + requestHeader map[string]string + requestBody io.Reader + want string + isErr bool + }{ + { + name: "invoke my-list-fhir-stores-tool", + api: "http://127.0.0.1:5000/api/tool/my-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-fhir-stores-tool with auth", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-fhir-stores-tool with client auth", + api: "http://127.0.0.1:5000/api/tool/my-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-fhir-stores-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-auth-list-fhir-stores-tool with invalid auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{"Authorization": "invalid-token"}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-client-auth-list-fhir-stores-tool with client auth", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-client-auth-list-fhir-stores-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-client-auth-list-fhir-stores-tool with invalid auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-fhir-stores-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) + if tc.isErr { + if status == http.StatusOK { + t.Errorf("expected error but got success") + } + return + } + if status != http.StatusOK { + t.Errorf("expected status OK but got %d", status) + } else if !strings.Contains(got, tc.want) { + t.Errorf("expected result to contain %q but got %q", tc.want, got) + } + }) + } +} + +func runListDICOMStoresToolInvokeTest(t *testing.T, want string) { + idToken, err := tests.GetGoogleIdToken(tests.ClientId) + if err != nil { + t.Fatalf("error getting Google ID token: %s", err) + } + + accessToken, err := sources.GetIAMAccessToken(t.Context()) + if err != nil { + t.Fatalf("error getting access token from ADC: %s", err) + } + accessToken = "Bearer " + accessToken + + invokeTcs := []struct { + name string + api string + requestHeader map[string]string + requestBody io.Reader + want string + isErr bool + }{ + { + name: "invoke my-list-dicom-stores-tool", + api: "http://127.0.0.1:5000/api/tool/my-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-dicom-stores-tool with auth", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-dicom-stores-tool with client auth", + api: "http://127.0.0.1:5000/api/tool/my-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-auth-list-dicom-stores-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-auth-list-dicom-stores-tool with invalid auth token", + api: "http://127.0.0.1:5000/api/tool/my-auth-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{"Authorization": "invalid-token"}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-client-auth-list-dicom-stores-tool with client auth", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{"Authorization": accessToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: want, + isErr: false, + }, + { + name: "invoke my-client-auth-list-dicom-stores-tool without auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + { + name: "invoke my-client-auth-list-dicom-stores-tool with invalid auth token", + api: "http://127.0.0.1:5000/api/tool/my-client-auth-list-dicom-stores-tool/invoke", + requestHeader: map[string]string{"my-google-auth_token": idToken}, + requestBody: bytes.NewBuffer([]byte(`{}`)), + isErr: true, + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + got, status := runTest(t, tc.api, tc.requestHeader, tc.requestBody) + if tc.isErr { + if status == http.StatusOK { + t.Errorf("expected error but got success") + } + return + } + if status != http.StatusOK { + t.Errorf("expected status OK but got %d", status) + } else if !strings.Contains(got, tc.want) { + t.Errorf("expected result to contain %q but got %q", tc.want, got) + } + }) + } +} + +func runTest(t *testing.T, api string, requestHeader map[string]string, requestBody io.Reader) (string, int) { + resp, bodyBytes := tests.RunRequest(t, http.MethodPost, api, requestBody, requestHeader) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", resp.StatusCode + } + + var body map[string]interface{} + err := json.Unmarshal(bodyBytes, &body) + if err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + return got, http.StatusOK +} + +func runListFHIRStoresWithRestriction(t *testing.T, allowedFHIRStore, disallowedFHIRStore string) { + api := "http://127.0.0.1:5000/api/tool/list-fhir-stores-restricted/invoke" + got, status := runTest(t, api, map[string]string{"Content-type": "application/json"}, bytes.NewBuffer([]byte(`{}`))) + if status != http.StatusOK { + t.Fatalf("expected status OK but got %d", status) + } + + if !strings.Contains(got, allowedFHIRStore) { + t.Fatalf("expected %q to contain %q, but it did not", got, allowedFHIRStore) + } + if strings.Contains(got, disallowedFHIRStore) { + t.Fatalf("expected %q to NOT contain %q, but it did", got, disallowedFHIRStore) + } +} + +func runListDICOMStoresWithRestriction(t *testing.T, allowedDICOMStore, disallowedDICOMStore string) { + api := "http://127.0.0.1:5000/api/tool/list-dicom-stores-restricted/invoke" + got, status := runTest(t, api, map[string]string{"Content-type": "application/json"}, bytes.NewBuffer([]byte(`{}`))) + if status != http.StatusOK { + t.Fatalf("expected status OK but got %d", status) + } + + if !strings.Contains(got, allowedDICOMStore) { + t.Fatalf("expected %q to contain %q, but it did not", got, allowedDICOMStore) + } + if strings.Contains(got, disallowedDICOMStore) { + t.Fatalf("expected %q to NOT contain %q, but it did", got, disallowedDICOMStore) + } +}