Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,10 @@ def parse_time_mapped_ti_count(self) -> Optional[int]:
if not isinstance(value, MAPPABLE_LITERAL_TYPES):
# None literal type encountered, so give up
return None
total += len(value)
if total == 0:
total = len(value)
else:
total *= len(value)
Comment on lines +798 to +801
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could have total = 1 at the beginning of the function instead…?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried that, and that's how I found the .expand() bug. If we can assume we will always have at least something with a length, we could, but this felt more defensive.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let’s do this right now and maybe change it when we fix the expand bug. That way we can keep main working correctly.

return total

@cache
Expand Down
36 changes: 36 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2689,6 +2689,7 @@ def show(a, b):
ti.run()

show_task = dag.get_task("show")
assert show_task.parse_time_mapped_ti_count is None
mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session)
assert num == len(mapped_tis) == 4

Expand All @@ -2697,6 +2698,41 @@ def show(a, b):
ti.run()
assert outputs == [(1, 1), (1, 2), (2, 1), (2, 2)]

def test_map_literal_cross_product(self, dag_maker, session):
"""Test a mapped task with literal cross product args expand properly."""
outputs = []

with dag_maker(dag_id="product_same_types", session=session) as dag:

@dag.task
def show(a, b):
outputs.append((a, b))

show.expand(a=[2, 4, 8], b=[5, 10])

dag_run = dag_maker.create_dagrun()

show_task = dag.get_task("show")
assert show_task.parse_time_mapped_ti_count == 6
mapped_tis, num = show_task.expand_mapped_task(dag_run.run_id, session=session)
assert len(mapped_tis) == 0 # Expanded at parse!
assert num == 6

tis = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.task_id == 'show',
TaskInstance.run_id == dag_run.run_id,
)
.order_by(TaskInstance.map_index)
.all()
)
for ti in tis:
ti.refresh_from_task(show_task)
ti.run()
assert outputs == [(2, 5), (2, 10), (4, 5), (4, 10), (8, 5), (8, 10)]

def test_map_in_group(self, tmp_path: pathlib.Path, dag_maker, session):
out = tmp_path.joinpath("out")
out.touch()
Expand Down