Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 109 additions & 2 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,36 @@
if TYPE_CHECKING:
from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
from airflow.utils.task_group import MappedTaskGroup, TaskGroup

# Callable objects contained by MapXComArg. We only accept callables from
# the user, but deserialize them into strings in a serialized XComArg for
# safety (those callables are arbitrary user code).
MapCallables = Sequence[Union[Callable[[Any], Any], str]]


def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
"""Given two operators, find their innermost common mapped task group."""
if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
return None
parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()}
common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids)
return next(common_groups, None)


def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
"""Whether given operator is *further* mapped inside a task group."""
if operator.is_mapped:
return True
task_group = operator.task_group
while task_group is not None and task_group.group_id != container.group_id:
if isinstance(task_group, MappedTaskGroup):
return True
task_group = task_group.parent_group
return False


class XComArg(ResolveMixin, DependencyMixin):
"""Reference to an XCom value pushed from another operator.

Expand Down Expand Up @@ -318,15 +341,99 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
)
return query.scalar()

def _get_map_indexes_to_pull(
self,
ti: TaskInstance,
ti_count: int | None,
*,
session: Session,
) -> int | range | None:
"""Infer the correct ``map_indexes`` to ``xcom_pull`` for resolution.

The bulk of the logic mainly exists to solve the problem described by
the following example, where 'val' must resolve to different values,
depending on where the reference is being used::

@task
def this_task(v): # This task is ti.task.
return v * 2

@task_group
def tg1(inp):
val = referenced_task(inp) # This task is self.operator.
this_task(val) # When inp is 1, val here should resolve to 2.
return val

# This val is the same object returned by tg1.
val = tg1.expand(inp=[1, 2, 3])

@task_group
def tg2(inp):
another_task(inp, val) # val here should resolve to [2, 4, 6].

tg2.expand(inp=["a", "b"])

The surrounding mapped task groups of ``self.operator`` and ``ti.task``
are inspected to find a common "ancestor". If such an ancestor is found,
we need to return specific map indexes to pull a partial value from
upstream XCom.

:param ti: The currently executing task instance, i.e. ``ti`` in the
template context.
:param ti_count: The total count of task instance this task was expanded
by the scheduler, i.e. ``expanded_ti_count`` in the template context.
:return: Specific map index or map indexes to pull, or ``None`` if we
want to "whole" return value (i.e. no mapped task groups involved).
"""
# Find the innermost common mapped task group between the current task
# If the current task and the referenced task does not have a common
# mapped task group, the two are in different task mapping contexts
# (like another_task above), and we should use the "whole" value.
common_ancestor = _find_common_ancestor_mapped_group(ti.task, self.operator)
if common_ancestor is None:
return None

# This value should never be None since we already know the current task
# is in a mapped task group, and should have been expanded. The check
# exists mainly to satisfy Mypy.
if ti_count is None:
return None

# At this point we know the two tasks share a mapped task group, and we
# should use a "partial" value. Let's break down the mapped ti count
# between the ancestor and further expansion happened inside it.
ancestor_ti_count = common_ancestor.get_mapped_ti_count(ti.run_id, session=session)
ancestor_map_index = ti.map_index * ancestor_ti_count // ti_count

# If the task is NOT further expanded inside the common ancestor, we
# only want to reference one single ti. We must walk the actual DAG,
# and "ti_count == ancestor_ti_count" does not work, since the further
# expansion may be of length 1.
if not _is_further_mapped_inside(self.operator, common_ancestor):
return ancestor_map_index

# Otherwise we need a partial aggregation for values from selected task
# instances in the ancestor's expansion context.
further_count = ti_count // ancestor_ti_count
map_index_start = ancestor_map_index * further_count
return range(map_index_start, map_index_start + further_count)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
ti = context["ti"]
task_id = self.operator.task_id
result = context["ti"].xcom_pull(task_ids=task_id, key=str(self.key), default=NOTSET, session=session)
result = ti.xcom_pull(
task_ids=task_id,
map_indexes=self._get_map_indexes_to_pull(ti, context["expanded_ti_count"], session=session),
key=self.key,
default=NOTSET,
session=session,
)
if not isinstance(result, ArgNotSet):
return result
if self.key == XCOM_RETURN_KEY:
return None
raise XComNotFound(context["ti"].dag_id, task_id, self.key)
raise XComNotFound(ti.dag_id, task_id, self.key)


def _get_callable_name(f: Callable | str) -> str:
Expand Down
48 changes: 48 additions & 0 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,3 +2051,51 @@ def tg(x, y):
("tg.task_2", 4, None),
("tg.task_2", 5, None),
}


def test_operator_mapped_task_group_receives_value(dag_maker, session):
with dag_maker(session=session):

@task
def t(value, *, ti=None):
results[(ti.task_id, ti.map_index)] = value
return value

@task_group
def tg(va):
# Each expanded group has one t1 and t2 each.
t1 = t.override(task_id="t1")(va)
t2 = t.override(task_id="t2")(t1)

with pytest.raises(NotImplementedError) as ctx:
t.override(task_id="t4").expand(value=va)
assert str(ctx.value) == "operator expansion in an expanded task group is not yet supported"

return t2

# The group is mapped by 3.
t2 = tg.expand(va=[["a", "b"], [4], ["z"]])

# Aggregates results from task group.
t.override(task_id="t3")(t2)

dr: DagRun = dag_maker.create_dagrun()

results = {}
decision = dr.task_instance_scheduling_decisions(session=session)
for ti in decision.schedulable_tis:
ti.run()
assert results == {("tg.t1", 0): ["a", "b"], ("tg.t1", 1): [4], ("tg.t1", 2): ["z"]}

results = {}
decision = dr.task_instance_scheduling_decisions(session=session)
for ti in decision.schedulable_tis:
ti.run()
assert results == {("tg.t2", 0): ["a", "b"], ("tg.t2", 1): [4], ("tg.t2", 2): ["z"]}

results = {}
decision = dr.task_instance_scheduling_decisions(session=session)
for ti in decision.schedulable_tis:
ti.run()
assert len(results) == 1
assert list(results[("t3", -1)]) == [["a", "b"], [4], ["z"]]