Skip to content
3 changes: 3 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def wrapper(*args, **kwargs) -> PipelineJob:
"tags": tags,
}
if _is_inside_dsl_pipeline_func():
# on_init/on_finalize are not supported for a pipeline component; reject them up front.
if job_settings.get("on_init") is not None or job_settings.get("on_finalize") is not None:
raise UserErrorException("On_init/on_finalize is not supported for pipeline component.")
# Build pipeline node instead of pipeline job if inside dsl.
built_pipeline = Pipeline(_from_component_func=True, **common_init_args)
if job_settings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,12 @@ def _validate_init_finalize_job(self) -> ValidationResult:
if job.type != "pipeline":
continue
if job.settings.on_init:
validation_result.append_warning(
validation_result.append_error(
yaml_path=f"jobs.{job_name}.settings.on_init",
message="On_init is not supported for subgraph.",
)
if job.settings.on_finalize:
validation_result.append_warning(
validation_result.append_error(
yaml_path=f"jobs.{job_name}.settings.on_finalize",
message="On_finalize is not supported for subgraph",
)
Expand All @@ -319,12 +319,16 @@ def _is_isolated_job(_validate_job_name: str) -> bool:
# Condition 1: none of the validated job's inputs is a data binding under `parent.jobs`.
_validate_job = self.jobs[_validate_job_name]
for _input_name in _validate_job.inputs:
if not hasattr(_validate_job.inputs[_input_name]._data, "_data_binding"):
continue
_data_binding = _validate_job.inputs[_input_name]._data._data_binding()
if is_data_binding_expression(_data_binding, ["parent", "jobs"]):
return False
# Condition 2: no job's input is a data binding under `parent.jobs.<validated job name>`, i.e. nothing consumes its output.
for _job_name, _job in self.jobs.items():
for _input_name in _job.inputs:
if not hasattr(_job.inputs[_input_name]._data, "_data_binding"):
continue
_data_binding = _job.inputs[_input_name]._data._data_binding()
if is_data_binding_expression(_data_binding, ["parent", "jobs", _validate_job_name]):
return False
Expand Down
36 changes: 17 additions & 19 deletions sdk/ml/azure-ai-ml/tests/dsl/unittests/test_dsl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3856,6 +3856,7 @@ def test_init_finalize_job(self) -> None:
from azure.ai.ml.dsl import pipeline

def assert_pipeline_job_init_finalize_job(pipeline_job: PipelineJob):
assert pipeline_job._validate_init_finalize_job().passed
assert pipeline_job.settings.on_init == "init_job"
assert pipeline_job.settings.on_finalize == "finalize_job"
pipeline_job_dict = pipeline_job._to_rest_object().as_dict()
Expand Down Expand Up @@ -3951,6 +3952,21 @@ def init_finalize_with_invalid_connection_func(int_param: int, str_param: str):
set_pipeline_settings(on_init="init_job", on_finalize="finalize_job")
assert str(e.value) == "Please call `set_pipeline_settings` inside a `pipeline` decorated function."

# invalid case: set on_init for pipeline component
@dsl.pipeline
def subgraph_func():
node = self.component_func()
set_pipeline_settings(on_init=node) # set on_init for subgraph (pipeline component)

@dsl.pipeline
def subgraph_with_init_func():
subgraph_func()
self.component_func()

with pytest.raises(UserErrorException) as e:
subgraph_with_init_func()
assert str(e.value) == "On_init/on_finalize is not supported for pipeline component."

def test_init_finalize_job_with_subgraph(self, caplog) -> None:
from azure.ai.ml._internal.dsl import set_pipeline_settings

Expand All @@ -3959,7 +3975,6 @@ def test_init_finalize_job_with_subgraph(self, caplog) -> None:
def subgraph_func():
node = self.component_func()
node.compute = "cpu-cluster"
set_pipeline_settings(on_init=node) # will be ignored when in subgraph

@dsl.pipeline()
def subgraph_init_finalize_job_func():
Expand All @@ -3968,24 +3983,7 @@ def subgraph_init_finalize_job_func():
finalize_job = subgraph_func()
set_pipeline_settings(on_init=init_job, on_finalize=finalize_job)

# as we set on_init for subgraph, there should be warning of ignoring on_init setting
# update logger name to "Operation" to enable caplog capture logs
from azure.ai.ml.dsl import _pipeline_decorator

_pipeline_decorator.module_logger = logging.getLogger("Operation")
with caplog.at_level(logging.WARNING):
valid_pipeline = subgraph_init_finalize_job_func()
assert any(
[
(
"Job settings {'on_init': 'node'} on pipeline function 'subgraph_func' "
"are ignored when using inside PipelineJob."
)
== msg.replace(" ", "") # hack here, use replace to remove spaces in the middle
for msg in caplog.messages
]
)

valid_pipeline = subgraph_init_finalize_job_func()
assert valid_pipeline._customized_validate().passed
assert valid_pipeline.settings.on_init == "init_job"
assert valid_pipeline.settings.on_finalize == "finalize_job"
Expand Down