diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index d43c5dffe040b..9d49a7c74ea4b 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -37,6 +37,8 @@
 if TYPE_CHECKING:
     from airflow.models.dag import DAG
     from airflow.models.operator import Operator
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.utils.task_group import MappedTaskGroup, TaskGroup
 
 # Callable objects contained by MapXComArg. We only accept callables from
 # the user, but deserialize them into strings in a serialized XComArg for
@@ -44,6 +46,29 @@
 MapCallables = Sequence[Union[Callable[[Any], Any], str]]
 
 
+def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
+    """Given two operators, find their innermost common mapped task group."""
+    if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
+        return None
+    parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()}
+    common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids)
+    return next(common_groups, None)
+
+
+def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
+    """Whether given operator is *further* mapped inside a task group."""
+    from airflow.utils.task_group import MappedTaskGroup
+
+    if operator.is_mapped:
+        return True
+    task_group = operator.task_group
+    while task_group is not None and task_group.group_id != container.group_id:
+        if isinstance(task_group, MappedTaskGroup):
+            return True
+        task_group = task_group.parent_group
+    return False
+
+
 class XComArg(ResolveMixin, DependencyMixin):
     """Reference to an XCom value pushed from another operator.
 
@@ -318,15 +343,99 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
         )
         return query.scalar()
 
+    def _get_map_indexes_to_pull(
+        self,
+        ti: TaskInstance,
+        ti_count: int | None,
+        *,
+        session: Session,
+    ) -> int | range | None:
+        """Infer the correct ``map_indexes`` to ``xcom_pull`` for resolution.
+
+        The bulk of the logic mainly exists to solve the problem described by
+        the following example, where 'val' must resolve to different values,
+        depending on where the reference is being used::
+
+            @task
+            def this_task(v):  # This task is ti.task.
+                return v * 2
+
+            @task_group
+            def tg1(inp):
+                val = referenced_task(inp)  # This task is self.operator.
+                this_task(val)  # When inp is 1, val here should resolve to 2.
+                return val
+
+            # This val is the same object returned by tg1.
+            val = tg1.expand(inp=[1, 2, 3])
+
+            @task_group
+            def tg2(inp):
+                another_task(inp, val)  # val here should resolve to [2, 4, 6].
+
+            tg2.expand(inp=["a", "b"])
+
+        The surrounding mapped task groups of ``self.operator`` and ``ti.task``
+        are inspected to find a common "ancestor". If such an ancestor is found,
+        we need to return specific map indexes to pull a partial value from
+        upstream XCom.
+
+        :param ti: The currently executing task instance, i.e. ``ti`` in the
+            template context.
+        :param ti_count: The total count of task instances this task was expanded
+            into by the scheduler, i.e. ``expanded_ti_count`` in the template context.
+        :return: Specific map index or map indexes to pull, or ``None`` if we
+            want the "whole" return value (i.e. no mapped task groups involved).
+        """
+        # Find the innermost common mapped task group between the current task
+        # and the referenced task. If the two do not share one, they are in
+        # different task mapping contexts (like another_task above), and we
+        # should use the "whole" value.
+        common_ancestor = _find_common_ancestor_mapped_group(ti.task, self.operator)
+        if common_ancestor is None:
+            return None
+
+        # This value should never be None since we already know the current task
+        # is in a mapped task group, and should have been expanded. The check
+        # exists mainly to satisfy Mypy.
+        if ti_count is None:
+            return None
+
+        # At this point we know the two tasks share a mapped task group, and we
+        # should use a "partial" value. Let's break down the mapped ti count
+        # between the ancestor and further expansion happened inside it.
+        ancestor_ti_count = common_ancestor.get_mapped_ti_count(ti.run_id, session=session)
+        ancestor_map_index = ti.map_index * ancestor_ti_count // ti_count
+
+        # If the task is NOT further expanded inside the common ancestor, we
+        # only want to reference one single ti. We must walk the actual DAG,
+        # and "ti_count == ancestor_ti_count" does not work, since the further
+        # expansion may be of length 1.
+        if not _is_further_mapped_inside(self.operator, common_ancestor):
+            return ancestor_map_index
+
+        # Otherwise we need a partial aggregation for values from selected task
+        # instances in the ancestor's expansion context.
+        further_count = ti_count // ancestor_ti_count
+        map_index_start = ancestor_map_index * further_count
+        return range(map_index_start, map_index_start + further_count)
+
     @provide_session
     def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+        ti = context["ti"]
         task_id = self.operator.task_id
-        result = context["ti"].xcom_pull(task_ids=task_id, key=str(self.key), default=NOTSET, session=session)
+        result = ti.xcom_pull(
+            task_ids=task_id,
+            map_indexes=self._get_map_indexes_to_pull(ti, context["expanded_ti_count"], session=session),
+            key=self.key,
+            default=NOTSET,
+            session=session,
+        )
         if not isinstance(result, ArgNotSet):
             return result
         if self.key == XCOM_RETURN_KEY:
             return None
-        raise XComNotFound(context["ti"].dag_id, task_id, self.key)
+        raise XComNotFound(ti.dag_id, task_id, self.key)
 
 
 def _get_callable_name(f: Callable | str) -> str:
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 58307f26bd6a3..ffee25f5e89a9 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -2051,3 +2051,51 @@ def tg(x, y):
         ("tg.task_2", 4, None),
         ("tg.task_2", 5, None),
     }
+
+
+def test_operator_mapped_task_group_receives_value(dag_maker, session):
+    with dag_maker(session=session):
+
+        @task
+        def t(value, *, ti=None):
+            results[(ti.task_id, ti.map_index)] = value
+            return value
+
+        @task_group
+        def tg(va):
+            # Each expanded group has one t1 and t2 each.
+            t1 = t.override(task_id="t1")(va)
+            t2 = t.override(task_id="t2")(t1)
+
+            with pytest.raises(NotImplementedError) as ctx:
+                t.override(task_id="t4").expand(value=va)
+            assert str(ctx.value) == "operator expansion in an expanded task group is not yet supported"
+
+            return t2
+
+        # The group is mapped by 3.
+        t2 = tg.expand(va=[["a", "b"], [4], ["z"]])
+
+        # Aggregates results from task group.
+        t.override(task_id="t3")(t2)
+
+    dr: DagRun = dag_maker.create_dagrun()
+
+    results = {}
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+    assert results == {("tg.t1", 0): ["a", "b"], ("tg.t1", 1): [4], ("tg.t1", 2): ["z"]}
+
+    results = {}
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+    assert results == {("tg.t2", 0): ["a", "b"], ("tg.t2", 1): [4], ("tg.t2", 2): ["z"]}
+
+    results = {}
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    for ti in decision.schedulable_tis:
+        ti.run()
+    assert len(results) == 1
+    assert list(results[("t3", -1)]) == [["a", "b"], [4], ["z"]]