Skip to content

Commit

Permalink
Add tests for disabling default caching var and flag
Browse files Browse the repository at this point in the history
Signed-off-by: ddalvi <[email protected]>
  • Loading branch information
DharmitD committed Sep 24, 2024
1 parent 600624d commit ad89b3e
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 29 deletions.
104 changes: 79 additions & 25 deletions sdk/python/kfp/cli/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from click import testing
from kfp.cli import cli
from kfp.cli import compile_
import yaml


class TestCliNounAliases(unittest.TestCase):
Expand Down Expand Up @@ -166,34 +167,87 @@ def test_deprecation_warning(self):
res.stdout.decode('utf-8'))


info_dict = cli.cli.to_info_dict(ctx=click.Context(cli.cli))
commands_dict = {
command: list(body.get('commands', {}).keys())
for command, body in info_dict['commands'].items()
}
noun_verb_list = [
(noun, verb) for noun, verbs in commands_dict.items() for verb in verbs
]
class TestKfpDslCompile(unittest.TestCase):

def invoke(self, args):
starting_args = ['dsl', 'compile']
args = starting_args + args
runner = testing.CliRunner()
return runner.invoke(
cli=cli.cli, args=args, catch_exceptions=False, obj={})

class TestSmokeTestAllCommandsWithHelp(parameterized.TestCase):
def create_pipeline_file(self):
pipeline_code = b"""
from kfp import dsl
@dsl.component
def my_component():
pass
@dsl.pipeline(name="tiny-pipeline")
def my_pipeline():
my_component_task = my_component()
"""
temp_pipeline = tempfile.NamedTemporaryFile(suffix='.py', delete=False)
temp_pipeline.write(pipeline_code)
temp_pipeline.flush()
return temp_pipeline

def load_output_yaml(self, output_file):
with open(output_file, 'r') as f:
return yaml.safe_load(f)

def test_compile_with_caching_flag_enabled(self):
temp_pipeline = self.create_pipeline_file()
output_file = 'test_output.yaml'

result = self.invoke(
['--py', temp_pipeline.name, '--output', output_file])
self.assertEqual(result.exit_code, 0)

@classmethod
def setUpClass(cls):
cls.runner = testing.CliRunner()

cls.vals = [('run', 'list')]

@parameterized.parameters(*noun_verb_list)
def test(self, noun: str, verb: str):
with mock.patch('kfp.cli.cli.client.Client'):
result = self.runner.invoke(
args=[noun, verb, '--help'],
cli=cli.cli,
catch_exceptions=False,
obj={})
self.assertTrue(result.output.startswith('Usage: '))
self.assertEqual(result.exit_code, 0)
output_data = self.load_output_yaml(output_file)
self.assertIn('root', output_data)
self.assertIn('tasks', output_data['root']['dag'])
for task in output_data['root']['dag']['tasks'].values():
self.assertIn('cachingOptions', task)
caching_options = task['cachingOptions']
self.assertEqual(caching_options.get('enableCache'), True)

def test_compile_with_caching_flag_disabled(self):
temp_pipeline = self.create_pipeline_file()
output_file = 'test_output.yaml'

result = self.invoke([
'--py', temp_pipeline.name, '--output', output_file,
'--disable-execution-caching-by-default'
])
self.assertEqual(result.exit_code, 0)

output_data = self.load_output_yaml(output_file)
self.assertIn('root', output_data)
self.assertIn('tasks', output_data['root']['dag'])
for task in output_data['root']['dag']['tasks'].values():
self.assertIn('cachingOptions', task)
caching_options = task['cachingOptions']
self.assertEqual(caching_options, {})

def test_compile_with_caching_disabled_env_var(self):
temp_pipeline = self.create_pipeline_file()
output_file = 'test_output.yaml'

os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT'] = 'true'
result = self.invoke(
['--py', temp_pipeline.name, '--output', output_file])
self.assertEqual(result.exit_code, 0)
del os.environ['KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT']

output_data = self.load_output_yaml(output_file)
self.assertIn('root', output_data)
self.assertIn('tasks', output_data['root']['dag'])
for task in output_data['root']['dag']['tasks'].values():
self.assertIn('cachingOptions', task)
caching_options = task['cachingOptions']
self.assertEqual(caching_options, {})


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions sdk/python/kfp/cli/compile_.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from kfp.dsl import graph_component
from kfp.dsl.pipeline_context import Pipeline


def is_pipeline_func(func: Callable) -> bool:
"""Checks if a function is a pipeline function.
Expand Down
53 changes: 53 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,59 @@ def my_pipeline() -> NamedTuple('Outputs', [
task = print_and_return(text='Hello')


class TestCompilePipelineCaching(unittest.TestCase):

def test_compile_pipeline_with_caching_enabled(self):
"""Test pipeline compilation with caching enabled."""

@dsl.component
def my_component():
pass

@dsl.pipeline(name='tiny-pipeline')
def my_pipeline():
my_task = my_component()
my_task.set_caching_options(True)

with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=output_yaml)

with open(output_yaml, 'r') as f:
pipeline_spec = yaml.safe_load(f)

task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
caching_options = task_spec['cachingOptions']

self.assertTrue(caching_options['enableCache'])

def test_compile_pipeline_with_caching_disabled(self):
"""Test pipeline compilation with caching disabled."""

@dsl.component
def my_component():
pass

@dsl.pipeline(name='tiny-pipeline')
def my_pipeline():
my_task = my_component()
my_task.set_caching_options(False)

with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=output_yaml)

with open(output_yaml, 'r') as f:
pipeline_spec = yaml.safe_load(f)

task_spec = pipeline_spec['root']['dag']['tasks']['my-component']
caching_options = task_spec.get('cachingOptions', {})

self.assertEqual(caching_options, {})


class V2NamespaceAliasTest(unittest.TestCase):
"""Test that imports of both modules and objects are aliased (e.g. all
import path variants work)."""
Expand Down
3 changes: 2 additions & 1 deletion sdk/python/kfp/dsl/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def __call__(self, *args, **kwargs) -> pipeline_task.PipelineTask:
args=task_inputs,
execute_locally=pipeline_context.Pipeline.get_default_pipeline() is
None,
execution_caching_default=pipeline_context.Pipeline.get_execution_caching_default(),
execution_caching_default=pipeline_context.Pipeline
.get_execution_caching_default(),
)

@property
Expand Down
7 changes: 4 additions & 3 deletions sdk/python/kfp/dsl/pipeline_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@
"""Definition for Pipeline."""

import functools
import os
from typing import Callable, Optional

from kfp.dsl import component_factory
from kfp.dsl import pipeline_task
from kfp.dsl import tasks_group
from kfp.dsl import utils

import os


def pipeline(func: Optional[Callable] = None,
*,
Expand Down Expand Up @@ -107,7 +106,9 @@ def get_default_pipeline():
# or the env var KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT.
# align with click's treatment of env vars for boolean flags.
# per click doc, "1", "true", "t", "yes", "y", and "on" are all converted to True
_execution_caching_default = not str(os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT')).strip().lower() in {"1", "true", "t", "yes", "y", "on"}
_execution_caching_default = not str(
os.getenv('KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT')).strip().lower(
) in {'1', 'true', 't', 'yes', 'y', 'on'}

@staticmethod
def get_execution_caching_default():
Expand Down

0 comments on commit ad89b3e

Please sign in to comment.