diff --git a/kedro_vertexai/config.py b/kedro_vertexai/config.py index 21284e6..d32a9f9 100644 --- a/kedro_vertexai/config.py +++ b/kedro_vertexai/config.py @@ -2,7 +2,7 @@ import os from importlib import import_module from inspect import signature -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional from pydantic import BaseModel, validator from pydantic.networks import IPvAnyAddress @@ -218,7 +218,7 @@ class ScheduleConfig(BaseModel): start_time: Optional[str] = None end_time: Optional[str] = None allow_queueing: Optional[bool] = False - max_run_count: Optional[Union[int, None]] = None + max_run_count: Optional[int] = None max_concurrent_run_count: Optional[int] = 1 @@ -238,7 +238,7 @@ class RunConfig(BaseModel): node_selectors: Optional[Dict[str, Dict[str, str]]] = {} dynamic_config_providers: Optional[List[DynamicConfigProviderConfig]] = [] mlflow: Optional[MLFlowVertexAIConfig] = None - schedules: Dict[str, ScheduleConfig] + schedules: Optional[Dict[str, ScheduleConfig]] = None def resources_for(self, node: str, tags: Optional[set] = None): default_config = self.resources["__default__"].dict() diff --git a/tests/test_config.py b/tests/test_config.py index b0629cc..847b06d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -31,6 +31,15 @@ size: 3Gi access_modes: [ReadWriteOnce] keep: True +# schedules: +# default_schedule: +# cron_expression: "0 * * * *" +# timezone: Etc/UTC +# start_time: null +# end_time: null +# allow_queueing: false +# max_run_count: null +# max_concurrent_run_count: 1 mlflow: request_header_provider_params: service_account: test@example.com diff --git a/tests/test_vertex_ai_client.py b/tests/test_vertex_ai_client.py index aae7985..925014f 100644 --- a/tests/test_vertex_ai_client.py +++ b/tests/test_vertex_ai_client.py @@ -9,9 +9,7 @@ class TestVertexAIClient(unittest.TestCase): - @patch("kedro_vertexai.client.CloudSchedulerClient") - def create_client(self, cloud_scheduler_client_mock): - self.cloud_scheduler_client_mock = cloud_scheduler_client_mock.return_value + def create_client(self): config = PluginConfig.parse_obj( { "project_id": "PROJECT_ID",